# Source code for pipecat.services.assemblyai.stt

#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import json
from typing import Any, AsyncGenerator, Dict, Optional
from urllib.parse import urlencode

from loguru import logger

from pipecat import __version__ as pipecat_version
from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    Frame,
    InterimTranscriptionFrame,
    StartFrame,
    TranscriptionFrame,
    UserStartedSpeakingFrame,
    UserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt

from .models import (
    AssemblyAIConnectionParams,
    BaseMessage,
    BeginMessage,
    TerminationMessage,
    TurnMessage,
)

try:
    import websockets
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error('In order to use AssemblyAI, you need to `pip install "pipecat-ai[assemblyai]"`.')
    raise Exception(f"Missing module: {e}")


[docs] class AssemblyAISTTService(STTService): def __init__( self, *, api_key: str, language: Language = Language.EN, # AssemblyAI only supports English api_endpoint_base_url: str = "wss://streaming.assemblyai.com/v3/ws", connection_params: AssemblyAIConnectionParams = AssemblyAIConnectionParams(), vad_force_turn_endpoint: bool = True, **kwargs, ): self._api_key = api_key self._language = language self._api_endpoint_base_url = api_endpoint_base_url self._connection_params = connection_params self._vad_force_turn_endpoint = vad_force_turn_endpoint super().__init__(sample_rate=self._connection_params.sample_rate, **kwargs) self._websocket = None self._termination_event = asyncio.Event() self._received_termination = False self._connected = False self._receive_task = None self._audio_buffer = bytearray() self._chunk_size_ms = 50 self._chunk_size_bytes = 0
[docs] def can_generate_metrics(self) -> bool: return True
[docs] async def start(self, frame: StartFrame): await super().start(frame) self._chunk_size_bytes = int(self._chunk_size_ms * self._sample_rate * 2 / 1000) await self._connect()
[docs] async def stop(self, frame: EndFrame): await super().stop(frame) await self._disconnect()
[docs] async def cancel(self, frame: CancelFrame): await super().cancel(frame) await self._disconnect()
[docs] async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: self._audio_buffer.extend(audio) while len(self._audio_buffer) >= self._chunk_size_bytes: chunk = bytes(self._audio_buffer[: self._chunk_size_bytes]) self._audio_buffer = self._audio_buffer[self._chunk_size_bytes :] await self._websocket.send(chunk) yield None
[docs] async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) if isinstance(frame, UserStartedSpeakingFrame): await self.start_ttfb_metrics() elif isinstance(frame, UserStoppedSpeakingFrame): if self._vad_force_turn_endpoint: await self._websocket.send(json.dumps({"type": "ForceEndpoint"})) await self.start_processing_metrics()
@traced_stt async def _trace_transcription(self, transcript: str, is_final: bool, language: Language): """Record transcription event for tracing.""" pass def _build_ws_url(self) -> str: """Build WebSocket URL with query parameters using urllib.parse.urlencode.""" params = { k: str(v).lower() if isinstance(v, bool) else v for k, v in self._connection_params.model_dump().items() if v is not None } if params: query_string = urlencode(params) return f"{self._api_endpoint_base_url}?{query_string}" return self._api_endpoint_base_url async def _connect(self): try: ws_url = self._build_ws_url() headers = { "Authorization": self._api_key, "User-Agent": f"AssemblyAI/1.0 (integration=Pipecat/{pipecat_version})", } self._websocket = await websockets.connect( ws_url, extra_headers=headers, ) self._connected = True self._receive_task = self.create_task(self._receive_task_handler()) except Exception as e: logger.error(f"Failed to connect to AssemblyAI: {e}") self._connected = False raise async def _disconnect(self): """Disconnect from AssemblyAI WebSocket and wait for termination message.""" if not self._connected or not self._websocket: return try: self._termination_event.clear() self._received_termination = False if len(self._audio_buffer) > 0: await self._websocket.send(bytes(self._audio_buffer)) self._audio_buffer.clear() try: await self._websocket.send(json.dumps({"type": "Terminate"})) try: await asyncio.wait_for( self._termination_event.wait(), timeout=5.0, ) except asyncio.TimeoutError: logger.warning("Timed out waiting for termination message from server") except Exception as e: logger.warning(f"Error during termination handshake: {e}") if self._receive_task: await self.cancel_task(self._receive_task) await self._websocket.close() except Exception as e: logger.error(f"Error during disconnect: {e}") finally: self._websocket = None self._connected = False self._receive_task = None async def _receive_task_handler(self): """Handle incoming WebSocket messages.""" try: while 
self._connected: try: message = await self._websocket.recv() self.start_watchdog() data = json.loads(message) await self._handle_message(data) except websockets.exceptions.ConnectionClosedOK: break except Exception as e: logger.error(f"Error processing WebSocket message: {e}") break finally: self.reset_watchdog() except Exception as e: logger.error(f"Fatal error in receive handler: {e}") def _parse_message(self, message: Dict[str, Any]) -> BaseMessage: """Parse a raw message into the appropriate message type.""" msg_type = message.get("type") if msg_type == "Begin": return BeginMessage.model_validate(message) elif msg_type == "Turn": return TurnMessage.model_validate(message) elif msg_type == "Termination": return TerminationMessage.model_validate(message) else: raise ValueError(f"Unknown message type: {msg_type}") async def _handle_message(self, message: Dict[str, Any]): """Handle AssemblyAI WebSocket messages.""" try: parsed_message = self._parse_message(message) if isinstance(parsed_message, BeginMessage): logger.debug( f"Session Begin: {parsed_message.id} (expires at {parsed_message.expires_at})" ) elif isinstance(parsed_message, TurnMessage): await self._handle_transcription(parsed_message) elif isinstance(parsed_message, TerminationMessage): await self._handle_termination(parsed_message) except Exception as e: logger.error(f"Error handling message: {e}") async def _handle_termination(self, message: TerminationMessage): """Handle termination message.""" self._received_termination = True self._termination_event.set() logger.info( f"Session Terminated: Audio Duration={message.audio_duration_seconds}s, " f"Session Duration={message.session_duration_seconds}s" ) await self.push_frame(EndFrame()) async def _handle_transcription(self, message: TurnMessage): """Handle transcription results.""" if not message.transcript: return await self.stop_ttfb_metrics() if message.end_of_turn and ( not self._connection_params.formatted_finals or message.turn_is_formatted ): await 
self.push_frame( TranscriptionFrame( message.transcript, "", # participant time_now_iso8601(), self._language, message, ) ) await self._trace_transcription(message.transcript, True, self._language) await self.stop_processing_metrics() else: await self.push_frame( InterimTranscriptionFrame( message.transcript, "", # participant time_now_iso8601(), self._language, message, ) )