Source code for pipecat.services.whisper.stt

#
# Copyright (c) 2024-2025, Daily
#
# SPDX-License-Identifier: BSD-2-Clause
#

"""This module implements Whisper transcription with a locally-downloaded model."""

import asyncio
from enum import Enum
from typing import AsyncGenerator, Optional

import numpy as np
from loguru import logger
from typing_extensions import TYPE_CHECKING, override

from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
from pipecat.services.stt_service import SegmentedSTTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt

if TYPE_CHECKING:
    try:
        from faster_whisper import WhisperModel
    except ModuleNotFoundError as e:
        logger.error(f"Exception: {e}")
        logger.error("In order to use Whisper, you need to `pip install pipecat-ai[whisper]`.")
        raise Exception(f"Missing module: {e}")

    try:
        import mlx_whisper
    except ModuleNotFoundError as e:
        logger.error(f"Exception: {e}")
        logger.error("In order to use Whisper, you need to `pip install pipecat-ai[mlx-whisper]`.")
        raise Exception(f"Missing module: {e}")


class Model(Enum):
    """Class of basic Whisper model selection options.

    Available models:

        Multilingual models:
            TINY: Smallest multilingual model
            BASE: Basic multilingual model
            MEDIUM: Good balance for multilingual
            LARGE: Best quality multilingual
            DISTIL_LARGE_V2: Fast multilingual

        English-only models:
            DISTIL_MEDIUM_EN: Fast English-only
    """

    # Multilingual models
    TINY = "tiny"
    BASE = "base"
    MEDIUM = "medium"
    LARGE = "large-v3"
    DISTIL_LARGE_V2 = "Systran/faster-distil-whisper-large-v2"

    # English-only models
    DISTIL_MEDIUM_EN = "Systran/faster-distil-whisper-medium.en"

class MLXModel(Enum):
    """Class of MLX Whisper model selection options.

    Available models:

        Multilingual models:
            TINY: Smallest multilingual model
            MEDIUM: Good balance for multilingual
            LARGE_V3: Best quality multilingual
            LARGE_V3_TURBO: Fine-tuned, pruned Whisper large-v3; much faster, slightly lower quality
            DISTIL_LARGE_V3: Fast multilingual
            LARGE_V3_TURBO_Q4: LARGE_V3_TURBO, quantized to Q4
    """

    # Multilingual models
    TINY = "mlx-community/whisper-tiny"
    MEDIUM = "mlx-community/whisper-medium-mlx"
    LARGE_V3 = "mlx-community/whisper-large-v3-mlx"
    LARGE_V3_TURBO = "mlx-community/whisper-large-v3-turbo"
    DISTIL_LARGE_V3 = "mlx-community/distil-whisper-large-v3"
    LARGE_V3_TURBO_Q4 = "mlx-community/whisper-large-v3-turbo-q4"
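
# Editor's example (not part of the module): `model` arguments below accept either an
# enum member or its raw string value, since the services resolve enums via `.value`.
#
#   assert Model.LARGE.value == "large-v3"
#   assert MLXModel.LARGE_V3_TURBO.value == "mlx-community/whisper-large-v3-turbo"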

def language_to_whisper_language(language: Language) -> Optional[str]:
    """Maps pipecat Language enum to Whisper language codes.

    Args:
        language: A Language enum value representing the input language.

    Returns:
        str or None: The corresponding Whisper language code, or None if not supported.

    Note:
        Only includes languages officially supported by Whisper.
    """
    language_map = {
        # Arabic
        Language.AR: "ar",
        Language.AR_AE: "ar",
        Language.AR_BH: "ar",
        Language.AR_DZ: "ar",
        Language.AR_EG: "ar",
        Language.AR_IQ: "ar",
        Language.AR_JO: "ar",
        Language.AR_KW: "ar",
        Language.AR_LB: "ar",
        Language.AR_LY: "ar",
        Language.AR_MA: "ar",
        Language.AR_OM: "ar",
        Language.AR_QA: "ar",
        Language.AR_SA: "ar",
        Language.AR_SY: "ar",
        Language.AR_TN: "ar",
        Language.AR_YE: "ar",
        # Bengali
        Language.BN: "bn",
        Language.BN_BD: "bn",
        Language.BN_IN: "bn",
        # Czech
        Language.CS: "cs",
        Language.CS_CZ: "cs",
        # Danish
        Language.DA: "da",
        Language.DA_DK: "da",
        # German
        Language.DE: "de",
        Language.DE_AT: "de",
        Language.DE_CH: "de",
        Language.DE_DE: "de",
        # Greek
        Language.EL: "el",
        Language.EL_GR: "el",
        # English
        Language.EN: "en",
        Language.EN_AU: "en",
        Language.EN_CA: "en",
        Language.EN_GB: "en",
        Language.EN_HK: "en",
        Language.EN_IE: "en",
        Language.EN_IN: "en",
        Language.EN_KE: "en",
        Language.EN_NG: "en",
        Language.EN_NZ: "en",
        Language.EN_PH: "en",
        Language.EN_SG: "en",
        Language.EN_TZ: "en",
        Language.EN_US: "en",
        Language.EN_ZA: "en",
        # Spanish
        Language.ES: "es",
        Language.ES_AR: "es",
        Language.ES_BO: "es",
        Language.ES_CL: "es",
        Language.ES_CO: "es",
        Language.ES_CR: "es",
        Language.ES_CU: "es",
        Language.ES_DO: "es",
        Language.ES_EC: "es",
        Language.ES_ES: "es",
        Language.ES_GQ: "es",
        Language.ES_GT: "es",
        Language.ES_HN: "es",
        Language.ES_MX: "es",
        Language.ES_NI: "es",
        Language.ES_PA: "es",
        Language.ES_PE: "es",
        Language.ES_PR: "es",
        Language.ES_PY: "es",
        Language.ES_SV: "es",
        Language.ES_US: "es",
        Language.ES_UY: "es",
        Language.ES_VE: "es",
        # Persian
        Language.FA: "fa",
        Language.FA_IR: "fa",
        # Finnish
        Language.FI: "fi",
        Language.FI_FI: "fi",
        # French
        Language.FR: "fr",
        Language.FR_BE: "fr",
        Language.FR_CA: "fr",
        Language.FR_CH: "fr",
        Language.FR_FR: "fr",
        # Hindi
        Language.HI: "hi",
        Language.HI_IN: "hi",
        # Hungarian
        Language.HU: "hu",
        Language.HU_HU: "hu",
        # Indonesian
        Language.ID: "id",
        Language.ID_ID: "id",
        # Italian
        Language.IT: "it",
        Language.IT_IT: "it",
        # Japanese
        Language.JA: "ja",
        Language.JA_JP: "ja",
        # Korean
        Language.KO: "ko",
        Language.KO_KR: "ko",
        # Dutch
        Language.NL: "nl",
        Language.NL_BE: "nl",
        Language.NL_NL: "nl",
        # Polish
        Language.PL: "pl",
        Language.PL_PL: "pl",
        # Portuguese
        Language.PT: "pt",
        Language.PT_BR: "pt",
        Language.PT_PT: "pt",
        # Romanian
        Language.RO: "ro",
        Language.RO_RO: "ro",
        # Russian
        Language.RU: "ru",
        Language.RU_RU: "ru",
        # Slovak
        Language.SK: "sk",
        Language.SK_SK: "sk",
        # Swedish
        Language.SV: "sv",
        Language.SV_SE: "sv",
        # Thai
        Language.TH: "th",
        Language.TH_TH: "th",
        # Turkish
        Language.TR: "tr",
        Language.TR_TR: "tr",
        # Ukrainian
        Language.UK: "uk",
        Language.UK_UA: "uk",
        # Urdu
        Language.UR: "ur",
        Language.UR_IN: "ur",
        Language.UR_PK: "ur",
        # Vietnamese
        Language.VI: "vi",
        Language.VI_VN: "vi",
        # Chinese
        Language.ZH: "zh",
        Language.ZH_CN: "zh",
        Language.ZH_HK: "zh",
        Language.ZH_TW: "zh",
    }
    return language_map.get(language)
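
# Editor's example (illustrative only): regional variants collapse to Whisper's base
# language codes, and languages absent from the map return None.
#
#   assert language_to_whisper_language(Language.ES_MX) == "es"
#   assert language_to_whisper_language(Language.EN_GB) == "en"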

class WhisperSTTService(SegmentedSTTService):
    """Class to transcribe audio with a locally-downloaded Whisper model.

    This service uses Faster Whisper to perform speech-to-text transcription
    on audio segments. It supports multiple languages and various model sizes.

    Args:
        model: The Whisper model to use for transcription. Can be a Model enum or string.
        device: The device to run inference on ('cpu', 'cuda', or 'auto').
        compute_type: The compute type for inference ('default', 'int8', 'int8_float16', etc.).
        no_speech_prob: Probability threshold for filtering out non-speech segments.
        language: The default language for transcription.
        **kwargs: Additional arguments passed to SegmentedSTTService.

    Attributes:
        _device: The device used for inference.
        _compute_type: The compute type for inference.
        _no_speech_prob: Threshold for non-speech filtering.
        _model: The loaded Whisper model instance.
        _settings: Dictionary containing service settings.
    """

    def __init__(
        self,
        *,
        model: str | Model = Model.DISTIL_MEDIUM_EN,
        device: str = "auto",
        compute_type: str = "default",
        no_speech_prob: float = 0.4,
        language: Language = Language.EN,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._device: str = device
        self._compute_type = compute_type
        self.set_model_name(model if isinstance(model, str) else model.value)
        self._no_speech_prob = no_speech_prob
        self._model: Optional[WhisperModel] = None
        self._settings = {
            "language": language,
            "device": self._device,
            "compute_type": self._compute_type,
            "no_speech_prob": self._no_speech_prob,
        }
        self._load()
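
    # Editor's example (a sketch, not part of the module): constructing the service.
    # `device` and `compute_type` are forwarded to faster-whisper's WhisperModel, so
    # valid values are whatever that library accepts.
    #
    #   stt = WhisperSTTService(
    #       model=Model.DISTIL_MEDIUM_EN,
    #       device="cuda",
    #       compute_type="int8",
    #       language=Language.EN,
    #   )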

    def can_generate_metrics(self) -> bool:
        """Indicates whether this service can generate metrics.

        Returns:
            bool: True, as this service supports metric generation.
        """
        return True

    def language_to_service_language(self, language: Language) -> Optional[str]:
        """Convert from pipecat Language to Whisper language code.

        Args:
            language: The Language enum value to convert.

        Returns:
            str or None: The corresponding Whisper language code, or None if not supported.
        """
        return language_to_whisper_language(language)

    async def set_language(self, language: Language):
        """Set the language for transcription.

        Args:
            language: The Language enum value to use for transcription.
        """
        logger.info(f"Switching STT language to: [{language}]")
        self._settings["language"] = language
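
    # Editor's example (sketch): switching the language mid-session; subsequent
    # run_stt() calls pick up the new setting via language_to_service_language().
    #
    #   await stt.set_language(Language.FR)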

    def _load(self):
        """Loads the Whisper model.

        Note:
            If this is the first time this model is being run, it will take
            time to download from the Hugging Face model hub.
        """
        try:
            from faster_whisper import WhisperModel

            logger.debug("Loading Whisper model...")
            self._model = WhisperModel(
                self.model_name, device=self._device, compute_type=self._compute_type
            )
            logger.debug("Loaded Whisper model")
        except ModuleNotFoundError as e:
            logger.error(f"Exception: {e}")
            logger.error("In order to use Whisper, you need to `pip install pipecat-ai[whisper]`.")
            self._model = None

    @traced_stt
    async def _handle_transcription(
        self, transcript: str, is_final: bool, language: Optional[Language] = None
    ):
        """Handle a transcription result with tracing."""
        pass

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Transcribes given audio using Whisper.

        Args:
            audio: Raw audio bytes in 16-bit PCM format.

        Yields:
            Frame: Either a TranscriptionFrame containing the transcribed text
                or an ErrorFrame if transcription fails.

        Note:
            The audio is expected to be 16-bit signed PCM data. The service
            will normalize it to float32 in the range [-1, 1].
        """
        if not self._model:
            logger.error(f"{self} error: Whisper model not available")
            yield ErrorFrame("Whisper model not available")
            return

        await self.start_processing_metrics()
        await self.start_ttfb_metrics()

        # Divide by 32768 because we have signed 16-bit data.
        audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0

        whisper_lang = self.language_to_service_language(self._settings["language"])
        segments, _ = await asyncio.to_thread(
            self._model.transcribe, audio_float, language=whisper_lang
        )
        text: str = ""
        for segment in segments:
            if segment.no_speech_prob < self._no_speech_prob:
                text += f"{segment.text} "

        await self.stop_ttfb_metrics()
        await self.stop_processing_metrics()

        if text:
            await self._handle_transcription(text, True, self._settings["language"])
            logger.debug(f"Transcription: [{text}]")
            yield TranscriptionFrame(text, "", time_now_iso8601(), self._settings["language"])
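
# Editor's example (illustrative only): the int16-to-float32 normalization used in
# run_stt() above. Dividing by 32768, the magnitude of the most negative 16-bit
# sample, maps signed PCM into [-1.0, 1.0).
#
#   import numpy as np
#   pcm = np.array([-32768, 0, 32767], dtype=np.int16)
#   print(pcm.astype(np.float32) / 32768.0)  # [-1.0, 0.0, ~0.99997]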

class WhisperSTTServiceMLX(WhisperSTTService):
    """Subclass of `WhisperSTTService` with MLX Whisper model support.

    This service uses MLX Whisper to perform speech-to-text transcription
    on audio segments. It's optimized for Apple Silicon and supports
    multiple languages and quantizations.

    Args:
        model: The MLX Whisper model to use for transcription. Can be an MLXModel enum or string.
        no_speech_prob: Probability threshold for filtering out non-speech segments.
        language: The default language for transcription.
        temperature: Temperature for sampling. Can be a float or tuple of floats.
        **kwargs: Additional arguments passed to SegmentedSTTService.

    Attributes:
        _no_speech_prob: Threshold for non-speech filtering.
        _temperature: Temperature for sampling.
        _settings: Dictionary containing service settings.
    """

    def __init__(
        self,
        *,
        model: str | MLXModel = MLXModel.TINY,
        no_speech_prob: float = 0.6,
        language: Language = Language.EN,
        temperature: float = 0.0,
        **kwargs,
    ):
        # Skip WhisperSTTService.__init__ and call its parent directly.
        SegmentedSTTService.__init__(self, **kwargs)

        self.set_model_name(model if isinstance(model, str) else model.value)
        self._no_speech_prob = no_speech_prob
        self._temperature = temperature

        self._settings = {
            "language": language,
            "no_speech_prob": self._no_speech_prob,
            "temperature": self._temperature,
            "engine": "mlx",
        }
        # No need to call _load(), as MLX Whisper loads models on demand.

    @override
    def _load(self):
        """MLX Whisper loads models on demand, so this is a no-op."""
        pass

    @traced_stt
    async def _handle_transcription(
        self, transcript: str, is_final: bool, language: Optional[Language] = None
    ):
        """Handle a transcription result with tracing."""
        pass

    @override
    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Transcribes given audio using MLX Whisper.

        Args:
            audio: Raw audio bytes in 16-bit PCM format.

        Yields:
            Frame: Either a TranscriptionFrame containing the transcribed text
                or an ErrorFrame if transcription fails.

        Note:
            The audio is expected to be 16-bit signed PCM data. MLX Whisper
            will handle the conversion internally.
        """
        try:
            import mlx_whisper

            await self.start_processing_metrics()
            await self.start_ttfb_metrics()

            # Divide by 32768 because we have signed 16-bit data.
            audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0

            whisper_lang = self.language_to_service_language(self._settings["language"])
            chunk = await asyncio.to_thread(
                mlx_whisper.transcribe,
                audio_float,
                path_or_hf_repo=self.model_name,
                temperature=self._temperature,
                language=whisper_lang,
            )
            text: str = ""
            for segment in chunk.get("segments", []):
                # Drop likely hallucinations.
                if segment.get("compression_ratio", None) == 0.5555555555555556:
                    continue
                if segment.get("no_speech_prob", 0.0) < self._no_speech_prob:
                    text += f"{segment.get('text', '')} "

            if len(text.strip()) == 0:
                text = None

            await self.stop_ttfb_metrics()
            await self.stop_processing_metrics()

            if text:
                await self._handle_transcription(text, True, self._settings["language"])
                logger.debug(f"Transcription: [{text}]")
                yield TranscriptionFrame(text, "", time_now_iso8601(), self._settings["language"])
        except Exception as e:
            logger.exception(f"MLX Whisper transcription error: {e}")
            yield ErrorFrame(f"MLX Whisper transcription error: {str(e)}")
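
# Editor's example (a sketch, assuming Apple Silicon and the mlx-whisper extra
# installed via `pip install pipecat-ai[mlx-whisper]`):
#
#   stt = WhisperSTTServiceMLX(
#       model=MLXModel.LARGE_V3_TURBO,
#       language=Language.EN,
#       temperature=0.0,
#   )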