#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import base64
import json
import warnings
from typing import Any, AsyncGenerator, Dict, Optional
import aiohttp
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
TranslationFrame,
)
from pipecat.services.gladia.config import GladiaInputParams
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Gladia, you need to `pip install pipecat-ai[gladia]`.")
raise Exception(f"Missing module: {e}")
[docs]
def language_to_gladia_language(language: Language) -> Optional[str]:
    """Map a pipecat ``Language`` to the language code sent to Gladia.

    A direct entry in the table below wins. Otherwise the enum's string value
    is reduced to its base part (the text before the first ``-``, lowercased)
    and that base code is returned if it is one of the codes in the table.

    Args:
        language: The Language enum value to convert.

    Returns:
        The Gladia language code string, or None when neither the enum member
        nor its base code appears in the table.
    """
    gladia_codes = {
        Language.AF: "af",
        Language.AM: "am",
        Language.AR: "ar",
        Language.AS: "as",
        Language.AZ: "az",
        Language.BA: "ba",
        Language.BE: "be",
        Language.BG: "bg",
        Language.BN: "bn",
        Language.BO: "bo",
        Language.BR: "br",
        Language.BS: "bs",
        Language.CA: "ca",
        Language.CS: "cs",
        Language.CY: "cy",
        Language.DA: "da",
        Language.DE: "de",
        Language.EL: "el",
        Language.EN: "en",
        Language.ES: "es",
        Language.ET: "et",
        Language.EU: "eu",
        Language.FA: "fa",
        Language.FI: "fi",
        Language.FO: "fo",
        Language.FR: "fr",
        Language.GL: "gl",
        Language.GU: "gu",
        Language.HA: "ha",
        Language.HAW: "haw",
        Language.HE: "he",
        Language.HI: "hi",
        Language.HR: "hr",
        Language.HT: "ht",
        Language.HU: "hu",
        Language.HY: "hy",
        Language.ID: "id",
        Language.IS: "is",
        Language.IT: "it",
        Language.JA: "ja",
        Language.JV: "jv",
        Language.KA: "ka",
        Language.KK: "kk",
        Language.KM: "km",
        Language.KN: "kn",
        Language.KO: "ko",
        Language.LA: "la",
        Language.LB: "lb",
        Language.LN: "ln",
        Language.LO: "lo",
        Language.LT: "lt",
        Language.LV: "lv",
        Language.MG: "mg",
        Language.MI: "mi",
        Language.MK: "mk",
        Language.ML: "ml",
        Language.MN: "mn",
        Language.MR: "mr",
        Language.MS: "ms",
        Language.MT: "mt",
        Language.MY_MR: "mymr",
        Language.NE: "ne",
        Language.NL: "nl",
        Language.NN: "nn",
        Language.NO: "no",
        Language.OC: "oc",
        Language.PA: "pa",
        Language.PL: "pl",
        Language.PS: "ps",
        Language.PT: "pt",
        Language.RO: "ro",
        Language.RU: "ru",
        Language.SA: "sa",
        Language.SD: "sd",
        Language.SI: "si",
        Language.SK: "sk",
        Language.SL: "sl",
        Language.SN: "sn",
        Language.SO: "so",
        Language.SQ: "sq",
        Language.SR: "sr",
        Language.SU: "su",
        Language.SV: "sv",
        Language.SW: "sw",
        Language.TA: "ta",
        Language.TE: "te",
        Language.TG: "tg",
        Language.TH: "th",
        Language.TK: "tk",
        Language.TL: "tl",
        Language.TR: "tr",
        Language.TT: "tt",
        Language.UK: "uk",
        Language.UR: "ur",
        Language.UZ: "uz",
        Language.VI: "vi",
        Language.YI: "yi",
        Language.YO: "yo",
        Language.ZH: "zh",
    }

    direct_match = gladia_codes.get(language)
    if direct_match:
        return direct_match

    # No direct entry: fall back to the base part of a variant (e.g. es-ES -> es).
    base_part = str(language.value).split("-")[0].lower()
    if base_part in gladia_codes.values():
        return base_part
    return None
# Deprecation warning for nested InputParams
class _InputParamsDescriptor:
    """Backward-compatibility descriptor that resolves to ``GladiaInputParams``.

    Each read through the descriptor emits a DeprecationWarning and then
    returns the ``GladiaInputParams`` class itself.
    """

    def __get__(self, obj, objtype=None):
        deprecation_message = (
            "GladiaSTTService.InputParams is deprecated and will be removed in a future version. "
            "Import and use GladiaInputParams directly instead."
        )
        # stacklevel=2 attributes the warning to the code that read the attribute.
        warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
        return GladiaInputParams
[docs]
class GladiaSTTService(STTService):
    """Speech-to-Text service backed by Gladia's live transcription API.

    The service opens a session over HTTP, streams audio to the returned
    WebSocket URL for real-time transcription, and supports multiple
    languages, custom vocabulary, and various processing options.

    For complete API documentation, see: https://docs.gladia.io/api-reference/v2/live/init
    """

    # Maintain backward compatibility
    InputParams = _InputParamsDescriptor()

    def __init__(
        self,
        *,
        api_key: str,
        url: str = "https://api.gladia.io/v2/live",
        confidence: float = 0.5,
        sample_rate: Optional[int] = None,
        model: str = "solaria-1",
        params: Optional[GladiaInputParams] = None,
        max_reconnection_attempts: int = 5,
        reconnection_delay: float = 1.0,
        max_buffer_size: int = 1024 * 1024 * 20,  # 20MB default buffer
        **kwargs,
    ):
        """Set up configuration and connection state; no network I/O happens here.

        Args:
            api_key: Gladia API key.
            url: Gladia API URL used to initialize a live session.
            confidence: Minimum confidence a transcript or translation needs
                before a frame is pushed.
            sample_rate: Audio sample rate in Hz.
            model: Model to use ("solaria-1").
            params: Additional configuration parameters; defaults to a fresh
                GladiaInputParams when omitted.
            max_reconnection_attempts: Maximum number of reconnection attempts.
            reconnection_delay: Initial delay between reconnection attempts
                (doubled on each further attempt).
            max_buffer_size: Maximum size of the audio buffer in bytes.
            **kwargs: Additional arguments passed to the STTService.
        """
        super().__init__(sample_rate=sample_rate, **kwargs)

        params = params if params else GladiaInputParams()

        # The single-language field still works but callers should migrate.
        if params.language is not None:
            warnings.warn(
                "The 'language' parameter is deprecated and will be removed in a future version. "
                "Use 'language_config' instead.",
                DeprecationWarning,
                stacklevel=2,
            )

        # Service configuration
        self._api_key = api_key
        self._url = url
        self.set_model_name(model)
        self._confidence = confidence
        self._params = params
        self._settings = {}

        # Live connection handles, populated once a session is running
        self._websocket = None
        self._session_url = None
        self._connection_active = False
        self._connection_task = None
        self._receive_task = None
        self._keepalive_task = None

        # Reconnection policy and progress
        self._should_reconnect = True
        self._max_reconnection_attempts = max_reconnection_attempts
        self._reconnection_delay = reconnection_delay
        self._reconnection_attempts = 0

        # Audio retained until acknowledged, guarded by a lock
        self._audio_buffer = bytearray()
        self._bytes_sent = 0
        self._max_buffer_size = max_buffer_size
        self._buffer_lock = asyncio.Lock()
[docs]
def can_generate_metrics(self) -> bool:
    """Report that this service produces metrics.

    Returns:
        Always True.
    """
    supports_metrics = True
    return supports_metrics
[docs]
def language_to_service_language(self, language: Language) -> Optional[str]:
    """Translate a pipecat Language enum into Gladia's language code.

    Args:
        language: The Language enum value to convert.

    Returns:
        Whatever ``language_to_gladia_language`` returns for ``language``.
    """
    gladia_code = language_to_gladia_language(language)
    return gladia_code
def _prepare_settings(self) -> Dict[str, Any]:
settings = {
"encoding": self._params.encoding or "wav/pcm",
"bit_depth": self._params.bit_depth or 16,
"sample_rate": self.sample_rate,
"channels": self._params.channels or 1,
"model": self._model_name,
}
# Add custom_metadata if provided
if self._params.custom_metadata:
settings["custom_metadata"] = self._params.custom_metadata
# Add endpointing parameters if provided
if self._params.endpointing is not None:
settings["endpointing"] = self._params.endpointing
if self._params.maximum_duration_without_endpointing is not None:
settings["maximum_duration_without_endpointing"] = (
self._params.maximum_duration_without_endpointing
)
# Add language configuration (prioritize language_config over deprecated language)
if self._params.language_config:
settings["language_config"] = self._params.language_config.model_dump(exclude_none=True)
elif self._params.language: # Backward compatibility for deprecated parameter
language_code = self.language_to_service_language(self._params.language)
if language_code:
settings["language_config"] = {
"languages": [language_code],
"code_switching": False,
}
# Add pre_processing configuration if provided
if self._params.pre_processing:
settings["pre_processing"] = self._params.pre_processing.model_dump(exclude_none=True)
# Add realtime_processing configuration if provided
if self._params.realtime_processing:
settings["realtime_processing"] = self._params.realtime_processing.model_dump(
exclude_none=True
)
# Add messages_config if provided
if self._params.messages_config:
settings["messages_config"] = self._params.messages_config.model_dump(exclude_none=True)
# Store settings for tracing
self._settings = settings
return settings
[docs]
async def start(self, frame: StartFrame):
    """Start the Gladia STT websocket connection.

    Launches the connection handler task unless one already exists.

    Args:
        frame: The StartFrame forwarded to the base class.
    """
    await super().start(frame)

    if not self._connection_task:
        self._should_reconnect = True
        self._connection_task = self.create_task(self._connection_handler())
[docs]
async def stop(self, frame: EndFrame):
    """Stop the Gladia STT websocket connection.

    Disables reconnection, asks Gladia to stop recording, then cancels the
    connection task and releases the websocket and its helper tasks.

    Args:
        frame: The EndFrame forwarded to the base class.
    """
    await super().stop(frame)
    self._should_reconnect = False
    try:
        await self._send_stop_recording()
    finally:
        # Tear down even when the stop message could not be delivered (for
        # example the socket closed between the open check and the send), so
        # the connection task and websocket are never left behind. Any error
        # from the send still propagates to the caller afterwards.
        if self._connection_task:
            await self.cancel_task(self._connection_task)
            self._connection_task = None
        await self._cleanup_connection()
[docs]
async def cancel(self, frame: CancelFrame):
    """Cancel the Gladia STT websocket connection.

    Disables reconnection, cancels the connection task if there is one, and
    releases the websocket and its helper tasks. No stop message is sent.

    Args:
        frame: The CancelFrame forwarded to the base class.
    """
    await super().cancel(frame)
    self._should_reconnect = False

    running_task = self._connection_task
    if running_task:
        await self.cancel_task(running_task)
        self._connection_task = None

    await self._cleanup_connection()
[docs]
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
    """Buffer an audio chunk and forward it to Gladia when connected.

    The chunk is always appended to the internal buffer, which is capped at
    the configured maximum size by dropping the oldest bytes. It is sent
    immediately only while the websocket is open. This generator produces no
    frames itself; it yields a single ``None``.

    Args:
        audio: Audio bytes to transcribe.

    Yields:
        None.
    """
    await self.start_ttfb_metrics()
    await self.start_processing_metrics()

    async with self._buffer_lock:
        self._audio_buffer.extend(audio)

        overflow = len(self._audio_buffer) - self._max_buffer_size
        if overflow > 0:
            # Over the cap: discard the oldest bytes.
            # NOTE(review): _bytes_sent is decremented here, but the ack
            # handler assigns it the acknowledged end byte of the stream.
            # Confirm which meaning is intended before relying on it after a trim.
            self._audio_buffer = self._audio_buffer[overflow:]
            self._bytes_sent = max(0, self._bytes_sent - overflow)
            logger.warning(f"Audio buffer exceeded max size, trimmed {overflow} bytes")

    # While disconnected the chunk simply stays in the buffer.
    socket_ready = self._connection_active and self._websocket and not self._websocket.closed
    if socket_ready:
        try:
            await self._send_audio(audio)
        except websockets.exceptions.ConnectionClosed as e:
            logger.warning(f"Websocket closed while sending audio chunk: {e}")
            self._connection_active = False

    yield None
async def _connection_handler(self):
    """Keep a Gladia WebSocket session alive, reconnecting while allowed.

    Each pass of the loop initializes a session if no session URL is held,
    connects to it, resends buffered audio, and then runs the receive and
    keepalive tasks until they finish. A closed connection goes through
    ``_maybe_reconnect`` (exponential backoff); any other error discards the
    session URL and retries after ``reconnection_delay`` seconds.
    """
    while self._should_reconnect:
        try:
            # Initialize session if needed
            if not self._session_url:
                settings = self._prepare_settings()
                response = await self._setup_gladia(settings)
                self._session_url = response["url"]
                self._reconnection_attempts = 0

            async with websockets.connect(self._session_url) as websocket:
                try:
                    self._websocket = websocket
                    self._connection_active = True
                    logger.info("Connected to Gladia WebSocket")

                    # Send buffered audio if any
                    await self._send_buffered_audio()

                    # Create the tasks through self.create_task, like the
                    # connection task in start(): _cleanup_connection cancels
                    # them with self.cancel_task, and the receive loop calls
                    # the service's watchdog helpers from inside its task.
                    self._receive_task = self.create_task(self._receive_task_handler())
                    self._keepalive_task = self.create_task(self._keepalive_task_handler())

                    # Wait for tasks to complete
                    await asyncio.gather(self._receive_task, self._keepalive_task)
                except websockets.exceptions.ConnectionClosed as e:
                    logger.warning(f"WebSocket connection closed: {e}")
                    self._connection_active = False

                    # Clean up tasks
                    if self._receive_task:
                        self._receive_task.cancel()
                    if self._keepalive_task:
                        self._keepalive_task.cancel()

                    # Back off before retrying; stop once attempts run out
                    if not await self._maybe_reconnect():
                        break
        except Exception as e:
            logger.error(f"Error in connection handler: {e}")
            self._connection_active = False
            if not self._should_reconnect:
                break
            # Reset session URL to get a new one
            self._session_url = None
            await asyncio.sleep(self._reconnection_delay)
async def _cleanup_connection(self):
"""Clean up connection resources."""
self._connection_active = False
if self._keepalive_task:
await self.cancel_task(self._keepalive_task)
self._keepalive_task = None
if self._websocket:
await self._websocket.close()
self._websocket = None
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None
async def _setup_gladia(self, settings: Dict[str, Any]):
    """POST the session settings to Gladia and return the decoded JSON reply.

    Args:
        settings: Session configuration, sent as the JSON request body.

    Returns:
        The parsed JSON response body.

    Raises:
        Exception: If the response status is not OK; the message carries the
            status code and the response text.
    """
    request_headers = {"X-Gladia-Key": self._api_key, "Content-Type": "application/json"}

    async with aiohttp.ClientSession() as session:
        async with session.post(self._url, headers=request_headers, json=settings) as response:
            if not response.ok:
                error_text = await response.text()
                logger.error(f"Gladia error: {response.status}: {error_text or response.reason}")
                raise Exception(
                    f"Failed to initialize Gladia session: {response.status} - {error_text}"
                )
            return await response.json()
@traced_stt
async def _handle_transcription(
    self, transcript: str, is_final: bool, language: Optional[str] = None
):
    """Stop the TTFB and processing metrics for a transcription result.

    The arguments are not used in the body; they are part of the signature
    wrapped by the ``traced_stt`` decorator.

    Args:
        transcript: The transcribed text.
        is_final: Whether the transcript is final.
        language: Language reported for the transcript, if any.
    """
    for stop_metric in (self.stop_ttfb_metrics, self.stop_processing_metrics):
        await stop_metric()
async def _send_audio(self, audio: bytes):
"""Send audio chunk with proper message format."""
if self._websocket and not self._websocket.closed:
data = base64.b64encode(audio).decode("utf-8")
message = {"type": "audio_chunk", "data": {"chunk": data}}
await self._websocket.send(json.dumps(message))
async def _send_buffered_audio(self):
    """Send the whole audio buffer as a single chunk, if it holds anything.

    The buffer lock is held for the duration of the send. The buffer itself
    is not modified here.
    """
    async with self._buffer_lock:
        if not self._audio_buffer:
            return
        logger.info(f"Sending {len(self._audio_buffer)} bytes of buffered audio")
        snapshot = bytes(self._audio_buffer)
        await self._send_audio(snapshot)
async def _send_stop_recording(self):
if self._websocket and not self._websocket.closed:
await self._websocket.send(json.dumps({"type": "stop_recording"}))
async def _keepalive_task_handler(self):
    """Send periodic empty audio chunks to keep the connection alive.

    Loops while the connection is flagged active and ends as soon as the
    websocket is missing or closed. A closed connection is logged at debug
    level; any other exception is logged as an error and not re-raised.
    """
    try:
        while self._connection_active:
            # Send keepalive every 20 seconds (Gladia times out after 30 seconds)
            await asyncio.sleep(20)

            socket_open = self._websocket and not self._websocket.closed
            if not socket_open:
                logger.debug("Websocket closed, stopping keepalive")
                break

            # An empty audio chunk serves as the keepalive message.
            await self._send_audio(b"")
    except websockets.exceptions.ConnectionClosed:
        logger.debug("Connection closed during keepalive")
    except Exception as e:
        logger.error(f"Error in Gladia keepalive task: {e}")
async def _receive_task_handler(self):
    """Read websocket messages and dispatch each one by its ``type``.

    Acknowledged ``audio_chunk`` messages trim the audio buffer,
    ``transcript`` messages become transcription frames, and ``translation``
    messages become translation frames. Messages of any other type, or with
    no ``type`` key at all, are ignored. A closed connection ends the loop
    quietly; any other exception is logged and also ends the loop.
    """
    try:
        async for message in self._websocket:
            self.start_watchdog()
            content = json.loads(message)
            # .get: a message without "type" is skipped instead of raising
            # KeyError and terminating the whole receive loop.
            message_type = content.get("type")

            if message_type == "audio_chunk":
                if content.get("acknowledged"):
                    await self._on_audio_chunk_ack(content)
            elif message_type == "transcript":
                await self._on_transcript(content)
            elif message_type == "translation":
                await self._on_translation(content)

            self.reset_watchdog()
    except websockets.exceptions.ConnectionClosed:
        # Expected when closing the connection
        pass
    except Exception as e:
        logger.error(f"Error in Gladia WebSocket handler: {e}")
    finally:
        self.reset_watchdog()

async def _on_audio_chunk_ack(self, content: Dict[str, Any]):
    """Drop newly acknowledged bytes from the front of the audio buffer."""
    byte_range = content["data"]["byte_range"]
    async with self._buffer_lock:
        end_byte = byte_range[1]
        # Only move forward: an ack that ends at or before the last recorded
        # position leaves the buffer untouched.
        if end_byte > self._bytes_sent:
            trim_size = end_byte - self._bytes_sent
            self._audio_buffer = self._audio_buffer[trim_size:]
            self._bytes_sent = end_byte

async def _on_transcript(self, content: Dict[str, Any]):
    """Push a final or interim transcription frame for a confident transcript."""
    utterance = content["data"]["utterance"]
    confidence = utterance.get("confidence", 0)
    language = utterance["language"]
    transcript = utterance["text"]
    is_final = content["data"]["is_final"]

    if confidence >= self._confidence:
        if is_final:
            await self.push_frame(
                TranscriptionFrame(
                    transcript,
                    "",
                    time_now_iso8601(),
                    language,
                    result=content,
                )
            )
            # Metrics are stopped for final transcripts only.
            await self._handle_transcription(
                transcript=transcript,
                is_final=is_final,
                language=language,
            )
        else:
            await self.push_frame(
                InterimTranscriptionFrame(
                    transcript,
                    "",
                    time_now_iso8601(),
                    language,
                    result=content,
                )
            )

async def _on_translation(self, content: Dict[str, Any]):
    """Push a translation frame for a confident translation into another language."""
    translated_utterance = content["data"]["translated_utterance"]
    original_language = content["data"]["original_language"]
    translated_language = translated_utterance["language"]
    confidence = translated_utterance.get("confidence", 0)
    translation = translated_utterance["text"]

    # A "translation" whose language equals the original language is skipped.
    if translated_language != original_language and confidence >= self._confidence:
        await self.push_frame(
            TranslationFrame(translation, "", time_now_iso8601(), translated_language)
        )
async def _maybe_reconnect(self) -> bool:
"""Handle exponential backoff reconnection logic."""
if not self._should_reconnect:
return False
self._reconnection_attempts += 1
if self._reconnection_attempts > self._max_reconnection_attempts:
logger.error(f"Max reconnection attempts ({self._max_reconnection_attempts}) reached")
self._should_reconnect = False
return False
delay = self._reconnection_delay * (2 ** (self._reconnection_attempts - 1))
logger.info(
f"Reconnecting in {delay} seconds (attempt {self._reconnection_attempts}/{self._max_reconnection_attempts})"
)
await asyncio.sleep(delay)
return True