
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD-2-Clause
#

"""Base classes for Speech-to-Text services with continuous and segmented processing."""

import io
import wave
from abc import abstractmethod
from typing import Any, AsyncGenerator, Dict, Mapping, Optional

from loguru import logger

from pipecat.frames.frames import (
    AudioRawFrame,
    Frame,
    StartFrame,
    STTMuteFrame,
    STTUpdateSettingsFrame,
    UserStartedSpeakingFrame,
    UserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
from pipecat.transcriptions.language import Language


class STTService(AIService):
    """Base class for speech-to-text services.

    Provides common functionality for STT services including audio passthrough,
    muting, settings management, and audio processing. Subclasses must implement
    the run_stt method to provide actual speech recognition.

    Args:
        audio_passthrough: Whether to pass audio frames downstream after
            processing. Defaults to True.
        sample_rate: The sample rate for audio input. If None, will be
            determined from the start frame.
        **kwargs: Additional arguments passed to the parent AIService.
    """

    def __init__(
        self,
        audio_passthrough: bool = True,
        # STT input sample rate
        sample_rate: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._audio_passthrough = audio_passthrough
        self._init_sample_rate = sample_rate
        self._sample_rate = 0
        self._settings: Dict[str, Any] = {}
        self._muted: bool = False

    @property
    def is_muted(self) -> bool:
        """Check if the STT service is currently muted.

        Returns:
            True if the service is muted and will not process audio.
        """
        return self._muted

    @property
    def sample_rate(self) -> int:
        """Get the current sample rate for audio processing.

        Returns:
            The sample rate in Hz.
        """
        return self._sample_rate

    async def set_model(self, model: str):
        """Set the speech recognition model.

        Args:
            model: The name of the model to use for speech recognition.
        """
        self.set_model_name(model)

    async def set_language(self, language: Language):
        """Set the language for speech recognition.

        Args:
            language: The language to use for speech recognition.
        """
        pass

    @abstractmethod
    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Run speech-to-text on the provided audio data.

        This method must be implemented by subclasses to provide actual
        speech recognition functionality.

        Args:
            audio: Raw audio bytes to transcribe.

        Yields:
            Frame: Frames containing transcription results (typically TextFrame).
        """
        pass

    async def start(self, frame: StartFrame):
        """Start the STT service.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)
        self._sample_rate = self._init_sample_rate or frame.audio_in_sample_rate

    async def _update_settings(self, settings: Mapping[str, Any]):
        logger.info(f"Updating STT settings: {settings}")
        for key, value in settings.items():
            if key in self._settings:
                logger.info(f"Updating STT setting {key} to: [{value}]")
                self._settings[key] = value
                if key == "language":
                    await self.set_language(value)
                elif key == "model":
                    self.set_model_name(value)
            else:
                logger.warning(f"Unknown setting for STT service: {key}")

    async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
        """Process an audio frame for speech recognition.

        Args:
            frame: The audio frame to process.
            direction: The direction of frame processing.
        """
        if self._muted:
            return

        await self.process_generator(self.run_stt(frame.audio))

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process frames, handling audio input, settings updates, and muting.

        Args:
            frame: The frame to process.
            direction: The direction of frame processing.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, AudioRawFrame):
            # In this service we accumulate audio internally and at the end we
            # push a TextFrame. We also push audio downstream in case someone
            # else needs it.
            await self.process_audio_frame(frame, direction)
            if self._audio_passthrough:
                await self.push_frame(frame, direction)
        elif isinstance(frame, STTUpdateSettingsFrame):
            await self._update_settings(frame.settings)
        elif isinstance(frame, STTMuteFrame):
            self._muted = frame.mute
            logger.debug(f"STT service {'muted' if frame.mute else 'unmuted'}")
        else:
            await self.push_frame(frame, direction)
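

# A minimal sketch of a custom continuous STT subclass (illustrative only, not
# part of pipecat): ``_ExampleSTTService`` and its canned result are
# hypothetical. It shows the expected shape of ``run_stt``: an async generator
# that yields transcription frames for the raw audio bytes it receives.
class _ExampleSTTService(STTService):
    """Sketch only: yields a fixed transcription for any non-empty audio."""

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        from pipecat.frames.frames import TranscriptionFrame
        from pipecat.utils.time import time_now_iso8601

        if audio:
            # A real service would send ``audio`` to its recognizer here.
            yield TranscriptionFrame("hello world", "", time_now_iso8601())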


class SegmentedSTTService(STTService):
    """STT service that processes speech in segments using VAD events.

    Uses Voice Activity Detection (VAD) events to detect speech segments and
    runs speech-to-text only on those segments, rather than continuously.
    Requires VAD to be enabled in the pipeline to function properly.

    Maintains a small audio buffer to account for the delay between actual
    speech start and VAD detection.

    Args:
        sample_rate: The sample rate for audio input. If None, will be
            determined from the start frame.
        **kwargs: Additional arguments passed to the parent STTService.
    """

    def __init__(self, *, sample_rate: Optional[int] = None, **kwargs):
        super().__init__(sample_rate=sample_rate, **kwargs)
        self._content = None
        self._wave = None
        self._audio_buffer = bytearray()
        self._audio_buffer_size_1s = 0
        self._user_speaking = False

    async def start(self, frame: StartFrame):
        """Start the segmented STT service and initialize audio buffer.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)
        # One second of 16-bit mono audio is sample_rate * 2 bytes.
        self._audio_buffer_size_1s = self.sample_rate * 2

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process frames, handling VAD events and audio segmentation."""
        await super().process_frame(frame, direction)

        if isinstance(frame, UserStartedSpeakingFrame):
            await self._handle_user_started_speaking(frame)
        elif isinstance(frame, UserStoppedSpeakingFrame):
            await self._handle_user_stopped_speaking(frame)

    async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
        # Ignore emulated VAD events; only real speech starts a segment.
        if frame.emulated:
            return

        self._user_speaking = True

    async def _handle_user_stopped_speaking(self, frame: UserStoppedSpeakingFrame):
        if frame.emulated:
            return

        self._user_speaking = False

        # Wrap the buffered segment in a 16-bit mono WAV container so run_stt
        # receives a complete, self-describing audio file.
        content = io.BytesIO()
        wav = wave.open(content, "wb")
        wav.setsampwidth(2)
        wav.setnchannels(1)
        wav.setframerate(self.sample_rate)
        wav.writeframes(self._audio_buffer)
        wav.close()
        content.seek(0)
        await self.process_generator(self.run_stt(content.read()))

        # Start clean.
        self._audio_buffer.clear()

    async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
        """Process audio frames by buffering them for segmented transcription.

        Continuously buffers audio, growing the buffer while the user is
        speaking and keeping only a small buffer when not speaking to account
        for VAD detection delay.

        Args:
            frame: The audio frame to process.
            direction: The direction of frame processing.
        """
        # If the user is speaking the audio buffer will keep growing.
        self._audio_buffer += frame.audio

        # If the user is not speaking we keep just a little bit of audio.
        if not self._user_speaking and len(self._audio_buffer) > self._audio_buffer_size_1s:
            discarded = len(self._audio_buffer) - self._audio_buffer_size_1s
            self._audio_buffer = self._audio_buffer[discarded:]
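

# A minimal sketch of a segmented subclass (illustrative only, not part of
# pipecat): ``_ExampleSegmentedSTTService`` is hypothetical. Unlike the
# continuous case, ``run_stt`` here receives one complete utterance as WAV
# bytes (header included), produced by _handle_user_stopped_speaking above.
# Instead of calling a real recognizer, it just reports the segment duration.
class _ExampleSegmentedSTTService(SegmentedSTTService):
    """Sketch only: emits the duration of each detected speech segment."""

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        from pipecat.frames.frames import TranscriptionFrame
        from pipecat.utils.time import time_now_iso8601

        with wave.open(io.BytesIO(audio), "rb") as wav:
            seconds = wav.getnframes() / float(wav.getframerate())
        yield TranscriptionFrame(f"[{seconds:.2f}s segment]", "", time_now_iso8601())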