Source code for pipecat.services.aws.stt

#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import json
import os
import random
import string
from typing import AsyncGenerator, Optional

from loguru import logger

from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    ErrorFrame,
    Frame,
    InterimTranscriptionFrame,
    StartFrame,
    TranscriptionFrame,
)
from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt

try:
    import websockets
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
    raise Exception(f"Missing module: {e}")


class AWSTranscribeSTTService(STTService):
    """Speech-to-text service that streams audio to AWS Transcribe over a WebSocket.

    Audio chunks are wrapped with ``build_event_message`` and sent on a
    presigned WebSocket URL. A background task decodes the responses and pushes
    ``InterimTranscriptionFrame`` / ``TranscriptionFrame`` (or ``ErrorFrame``)
    downstream.
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        region: Optional[str] = "us-east-1",
        sample_rate: int = 16000,
        language: Language = Language.EN,
        **kwargs,
    ):
        """Store settings and credentials; no network activity happens here.

        Args:
            api_key: AWS secret access key. Falls back to the
                ``AWS_SECRET_ACCESS_KEY`` environment variable.
            aws_access_key_id: AWS access key ID. Falls back to the
                ``AWS_ACCESS_KEY_ID`` environment variable.
            aws_session_token: Optional session token. Falls back to the
                ``AWS_SESSION_TOKEN`` environment variable.
            region: AWS region. When falsy, ``AWS_REGION`` is used, then
                ``"us-east-1"``.
            sample_rate: Sample rate in Hz reported to AWS. Any value other
                than 8000 or 16000 is replaced by 16000 with a warning.
            language: Language to transcribe.
            **kwargs: Forwarded to ``STTService``.
        """
        super().__init__(**kwargs)

        self._settings = {
            "sample_rate": sample_rate,
            "language": language,
            "media_encoding": "linear16",  # AWS expects raw PCM
            "number_of_channels": 1,
            "show_speaker_label": False,
            "enable_channel_identification": False,
        }

        # Validate sample rate - AWS Transcribe only supports 8000 Hz or 16000 Hz.
        # Only the setting reported to AWS is changed here; this method does
        # not resample any audio.
        if sample_rate not in [8000, 16000]:
            logger.warning(
                f"AWS Transcribe only supports 8000 Hz or 16000 Hz sample rates. Converting from {sample_rate} Hz to 16000 Hz."
            )
            self._settings["sample_rate"] = 16000

        self._credentials = {
            "aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
            "aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
            "aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
            "region": region or os.getenv("AWS_REGION", "us-east-1"),
        }

        self._ws_client = None
        # Serializes connection attempts made by start() and run_stt().
        self._connection_lock = asyncio.Lock()
        self._connecting = False
        self._receive_task = None

    def get_service_encoding(self, encoding: str) -> str:
        """Convert internal encoding format to AWS Transcribe format.

        Args:
            encoding: Internal encoding name (e.g. ``"linear16"``).

        Returns:
            The mapped AWS name, or ``encoding`` unchanged when no mapping exists.
        """
        encoding_map = {
            "linear16": "pcm",  # AWS expects "pcm" for 16-bit linear PCM
        }
        return encoding_map.get(encoding, encoding)

    async def start(self, frame: StartFrame):
        """Initialize the connection when the service starts.

        Makes up to three connection attempts, sleeping one second between them.

        Args:
            frame: The start frame, forwarded to the base class.

        Raises:
            RuntimeError: If no open WebSocket exists after all attempts. The
                last connection exception, if any, is attached as the cause.
        """
        await super().start(frame)
        logger.info("Starting AWS Transcribe service...")

        max_retries = 3
        last_error: Optional[Exception] = None

        for attempt in range(1, max_retries + 1):
            try:
                await self._connect()
                if self._ws_client and self._ws_client.open:
                    logger.info("Successfully established WebSocket connection")
                    return
                logger.warning("WebSocket connection not established after connect")
            except Exception as e:
                last_error = e
                logger.error(f"Failed to connect (attempt {attempt}/{max_retries}): {e}")

            if attempt < max_retries:
                await asyncio.sleep(1)  # Wait before retrying

        # Chain the underlying failure so callers can see why connecting failed.
        raise RuntimeError(
            "Failed to establish WebSocket connection after multiple attempts"
        ) from last_error

    async def stop(self, frame: EndFrame):
        """Stop the service and close the AWS Transcribe connection."""
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the service and close the AWS Transcribe connection."""
        await super().cancel(frame)
        await self._disconnect()

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Process audio data and send to AWS Transcribe.

        Transcription results are not yielded from here; they are pushed by the
        background receive loop. This generator only yields ``ErrorFrame``
        instances when reconnecting or sending fails.

        Args:
            audio: Audio chunk passed unchanged to ``build_event_message``.

        Yields:
            Non-fatal ``ErrorFrame`` objects describing failures.
        """
        try:
            # Ensure WebSocket is connected
            if not self._ws_client or not self._ws_client.open:
                logger.debug("WebSocket not connected, attempting to reconnect...")
                try:
                    await self._connect()
                except Exception as e:
                    logger.error(f"Failed to reconnect: {e}")
                    yield ErrorFrame("Failed to reconnect to AWS Transcribe", fatal=False)
                    return

            # Format the audio data according to AWS event stream format
            event_message = build_event_message(audio)

            # Send the formatted event message
            try:
                await self._ws_client.send(event_message)
                # Metrics are (re)started after every chunk that is sent.
                await self.start_processing_metrics()
                await self.start_ttfb_metrics()
            except websockets.exceptions.ConnectionClosed as e:
                logger.warning(f"Connection closed while sending: {e}")
                await self._disconnect()
                # Don't yield error here - we'll retry on next frame
            except Exception as e:
                logger.error(f"Error sending audio: {e}")
                yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
                await self._disconnect()

        except Exception as e:
            logger.error(f"Error in run_stt: {e}")
            yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
            await self._disconnect()

    async def _connect(self):
        """Connect to AWS Transcribe with connection state management.

        Returns immediately when an open connection with a receive task already
        exists. Otherwise tears down any stale client, builds a presigned URL,
        opens the WebSocket and starts the receive loop.

        Raises:
            ValueError: If the configured language has no AWS language code.
            Exception: Any error raised while presigning or connecting is
                logged, the connection is cleaned up, and the error re-raised.
        """
        if self._ws_client and self._ws_client.open and self._receive_task:
            logger.debug(f"{self} Already connected")
            return

        async with self._connection_lock:
            # Re-check under the lock: another caller may have finished
            # connecting while this one was waiting, and proceeding would
            # tear down that fresh connection.
            if self._ws_client and self._ws_client.open and self._receive_task:
                logger.debug(f"{self} Already connected")
                return

            if self._connecting:
                logger.debug(f"{self} Connection already in progress")
                return

            try:
                self._connecting = True
                logger.debug(f"{self} Starting connection process...")

                # Drop any stale client (and its receive task) before reconnecting.
                if self._ws_client:
                    await self._disconnect()

                language_code = self.language_to_service_language(
                    Language(self._settings["language"])
                )
                if not language_code:
                    raise ValueError(f"Unsupported language: {self._settings['language']}")

                # Generate random websocket key
                websocket_key = "".join(
                    random.choices(
                        string.ascii_uppercase + string.ascii_lowercase + string.digits, k=20
                    )
                )

                # Add required headers
                # NOTE(review): the websockets client normally generates its own
                # Sec-WebSocket-Key/Version; confirm these overrides are needed.
                extra_headers = {
                    "Origin": "https://localhost",
                    "Sec-WebSocket-Key": websocket_key,
                    "Sec-WebSocket-Version": "13",
                    "Connection": "keep-alive",
                }

                # Get presigned URL
                presigned_url = get_presigned_url(
                    region=self._credentials["region"],
                    credentials={
                        "access_key": self._credentials["aws_access_key_id"],
                        "secret_key": self._credentials["aws_secret_access_key"],
                        "session_token": self._credentials["aws_session_token"],
                    },
                    language_code=language_code,
                    media_encoding=self.get_service_encoding(
                        self._settings["media_encoding"]
                    ),  # Convert to AWS format
                    sample_rate=self._settings["sample_rate"],
                    number_of_channels=self._settings["number_of_channels"],
                    enable_partial_results_stabilization=True,
                    partial_results_stability="high",
                    show_speaker_label=self._settings["show_speaker_label"],
                    enable_channel_identification=self._settings["enable_channel_identification"],
                )

                # Log only the endpoint. The query string of a presigned URL is
                # built from the credentials above, so keep it out of the logs.
                endpoint = presigned_url.split("?", 1)[0]
                logger.debug(f"{self} Connecting to WebSocket with URL: {endpoint}...")

                # Connect with the required headers and settings
                self._ws_client = await websockets.connect(
                    presigned_url,
                    extra_headers=extra_headers,
                    subprotocols=["mqtt"],
                    ping_interval=None,
                    ping_timeout=None,
                    compression=None,
                )

                logger.debug(f"{self} WebSocket connected, starting receive task...")

                # Start receive task
                self._receive_task = self.create_task(self._receive_loop())

                logger.info(f"{self} Successfully connected to AWS Transcribe")

            except Exception as e:
                logger.error(f"{self} Failed to connect to AWS Transcribe: {e}")
                await self._disconnect()
                raise

            finally:
                self._connecting = False

    async def _disconnect(self):
        """Disconnect from AWS Transcribe.

        Cancels the receive task, sends an end-of-stream message when the
        socket is still open, closes the socket and clears the client
        reference. Errors while closing are logged and not propagated.
        """
        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        try:
            # Guard on the client itself so a missing connection is a no-op
            # instead of an AttributeError swallowed by the handler below.
            if self._ws_client:
                if self._ws_client.open:
                    # Send end-stream message
                    end_stream = {"message-type": "event", "event": "end"}
                    await self._ws_client.send(json.dumps(end_stream))
                await self._ws_client.close()
        except Exception as e:
            logger.warning(f"{self} Error closing WebSocket connection: {e}")
        finally:
            self._ws_client = None

    def language_to_service_language(self, language: Language) -> str | None:
        """Convert internal language enum to AWS Transcribe language code.

        Args:
            language: The language to look up.

        Returns:
            The AWS language code, or ``None`` when the language is not mapped.
        """
        language_map = {
            Language.EN: "en-US",
            Language.ES: "es-US",
            Language.FR: "fr-FR",
            Language.DE: "de-DE",
            Language.IT: "it-IT",
            Language.PT: "pt-BR",
            Language.JA: "ja-JP",
            Language.KO: "ko-KR",
            Language.ZH: "zh-CN",
            Language.PL: "pl-PL",
        }
        return language_map.get(language)

    @traced_stt
    async def _handle_transcription(
        self, transcript: str, is_final: bool, language: Optional[str] = None
    ):
        """No-op hook wrapped by ``traced_stt``.

        The body is intentionally empty; presumably the decorator records the
        transcription for tracing (confirm against ``traced_stt``).
        """
        pass

    async def _receive_loop(self):
        """Background task to receive and process messages from AWS Transcribe.

        Runs until the socket is missing or closed, or until any error occurs.
        Only the first result and its first alternative in each event are used.
        Partial results are pushed as ``InterimTranscriptionFrame`` and final
        ones as ``TranscriptionFrame``; AWS exception messages are pushed as
        non-fatal ``ErrorFrame``.
        """
        while True:
            if not self._ws_client or not self._ws_client.open:
                logger.warning(f"{self} WebSocket closed in receive loop")
                break

            try:
                response = await self._ws_client.recv()
                self.start_watchdog()

                headers, payload = decode_event(response)
                message_type = headers.get(":message-type")

                if message_type == "event":
                    # Process transcription results
                    results = payload.get("Transcript", {}).get("Results", [])
                    if results:
                        result = results[0]
                        alternatives = result.get("Alternatives", [])
                        if alternatives:
                            transcript = alternatives[0].get("Transcript", "")
                            # A result without an IsPartial flag is treated as partial.
                            is_final = not result.get("IsPartial", True)

                            if transcript:
                                await self.stop_ttfb_metrics()
                                if is_final:
                                    await self.push_frame(
                                        TranscriptionFrame(
                                            transcript,
                                            "",
                                            time_now_iso8601(),
                                            self._settings["language"],
                                            result=result,
                                        )
                                    )
                                    await self._handle_transcription(
                                        transcript,
                                        is_final,
                                        self._settings["language"],
                                    )
                                    await self.stop_processing_metrics()
                                else:
                                    await self.push_frame(
                                        InterimTranscriptionFrame(
                                            transcript,
                                            "",
                                            time_now_iso8601(),
                                            self._settings["language"],
                                            result=result,
                                        )
                                    )
                elif message_type == "exception":
                    error_msg = payload.get("Message", "Unknown error")
                    logger.error(f"{self} Exception from AWS: {error_msg}")
                    await self.push_frame(
                        ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False)
                    )
                else:
                    logger.debug(f"{self} Other message type received: {headers}")
                    logger.debug(f"{self} Payload: {payload}")
            except websockets.exceptions.ConnectionClosed as e:
                logger.error(
                    f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
                )
                break
            except Exception as e:
                logger.error(f"{self} Unexpected error in receive loop: {e}")
                break
            finally:
                # Runs after every iteration, including the ones that break.
                self.reset_watchdog()