#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import base64
import json
import uuid
from typing import AsyncGenerator, Optional
import aiohttp
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import AudioContextWordTTSService, TTSService
from pipecat.transcriptions import language
from pipecat.transcriptions.language import Language
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
from pipecat.utils.tracing.service_decorators import traced_tts
try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Rime, you need to `pip install pipecat-ai[rime]`.")
raise Exception(f"Missing module: {e}")
def language_to_rime_language(language: Language) -> str:
    """Convert pipecat Language to Rime language code.

    Args:
        language: The pipecat Language enum value.

    Returns:
        str: Three-letter language code used by Rime (e.g., 'eng' for English).
    """
    rime_codes = {
        Language.EN: "eng",
        Language.DE: "ger",
        Language.ES: "spa",
        Language.FR: "fra",
    }
    # Unknown / unmapped languages fall back to English.
    return rime_codes.get(language, "eng")
class RimeTTSService(AudioContextWordTTSService):
    """Text-to-Speech service using Rime's websocket API.

    Uses Rime's websocket JSON API to convert text to speech with word-level timing
    information. Supports interruptions and maintains context across multiple messages
    within a turn.
    """

    class InputParams(BaseModel):
        """Configuration parameters for Rime TTS synthesis.

        NOTE(review): this nested model was reconstructed from the attribute
        accesses in ``__init__`` (``params.language``, ``params.speed_alpha``,
        ``params.reduce_latency``, ``params.pause_between_brackets``,
        ``params.phonemize_between_brackets``) — the original definition was
        missing from this file. Confirm the default values against the
        upstream pipecat source / Rime API documentation.
        """

        language: Optional[Language] = Language.EN
        speed_alpha: Optional[float] = 1.0
        reduce_latency: Optional[bool] = False
        pause_between_brackets: Optional[bool] = False
        phonemize_between_brackets: Optional[bool] = False

    def __init__(
        self,
        *,
        api_key: str,
        voice_id: str,
        url: str = "wss://users.rime.ai/ws2",
        model: str = "mistv2",
        sample_rate: Optional[int] = None,
        params: Optional[InputParams] = None,
        text_aggregator: Optional[BaseTextAggregator] = None,
        **kwargs,
    ):
        """Initialize Rime TTS service.

        Args:
            api_key: Rime API key for authentication.
            voice_id: ID of the voice to use.
            url: Rime websocket API endpoint.
            model: Model ID to use for synthesis.
            sample_rate: Audio sample rate in Hz.
            params: Additional configuration parameters.
            text_aggregator: Optional custom text aggregator. Defaults to one
                that keeps Rime's inline ``spell(...)`` directives intact.
        """
        # Initialize with parent class settings for proper frame handling
        super().__init__(
            aggregate_sentences=True,
            push_text_frames=False,
            push_stop_frames=True,
            pause_frame_processing=True,
            sample_rate=sample_rate,
            text_aggregator=text_aggregator or SkipTagsAggregator([("spell(", ")")]),
            **kwargs,
        )

        params = params or RimeTTSService.InputParams()

        # Store service configuration
        self._api_key = api_key
        self._url = url
        self._voice_id = voice_id
        self._model = model
        self._settings = {
            "speaker": voice_id,
            "modelId": model,
            "audioFormat": "pcm",
            # Placeholder; filled in at start() once the pipeline sample rate is known.
            "samplingRate": 0,
            "lang": self.language_to_service_language(params.language)
            if params.language
            else "eng",
            "speedAlpha": params.speed_alpha,
            "reduceLatency": params.reduce_latency,
            # Booleans are JSON-encoded ("true"/"false") because these values
            # end up in the websocket URL query string.
            "pauseBetweenBrackets": json.dumps(params.pause_between_brackets),
            "phonemizeBetweenBrackets": json.dumps(params.phonemize_between_brackets),
        }

        # State tracking
        self._context_id = None  # Tracks current turn
        self._receive_task = None
        self._cumulative_time = 0  # Accumulates time across messages

    def can_generate_metrics(self) -> bool:
        """Whether this service can produce TTFB/usage metrics (it can)."""
        return True

    def language_to_service_language(self, language: Language) -> str | None:
        """Convert pipecat language to Rime language code."""
        return language_to_rime_language(language)

    async def set_model(self, model: str):
        """Update the TTS model."""
        self._model = model
        await super().set_model(model)

    def _build_msg(self, text: str = "") -> dict:
        """Build a synthesis JSON message for the Rime API."""
        return {"text": text, "contextId": self._context_id}

    def _build_clear_msg(self) -> dict:
        """Build a clear-operation message (drops pending audio server-side)."""
        return {"operation": "clear"}

    def _build_eos_msg(self) -> dict:
        """Build an end-of-stream operation message."""
        return {"operation": "eos"}

    async def start(self, frame: StartFrame):
        """Start the service and establish the websocket connection."""
        await super().start(frame)
        # The real sample rate is only known once the pipeline has started.
        self._settings["samplingRate"] = self.sample_rate
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the service and close the connection."""
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the current operation and clean up."""
        await super().cancel(frame)
        await self._disconnect()

    async def _connect(self):
        """Establish the websocket connection and start the receive task."""
        await self._connect_websocket()
        if self._websocket and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self):
        """Close the websocket connection and clean up tasks."""
        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None
        await self._disconnect_websocket()

    async def _connect_websocket(self):
        """Connect to the Rime websocket API with the configured settings."""
        try:
            if self._websocket and self._websocket.open:
                return
            # Rime expects all synthesis settings as URL query parameters.
            params = "&".join(f"{k}={v}" for k, v in self._settings.items())
            url = f"{self._url}?{params}"
            headers = {"Authorization": f"Bearer {self._api_key}"}
            self._websocket = await websockets.connect(url, extra_headers=headers)
        except Exception as e:
            logger.error(f"{self} initialization error: {e}")
            self._websocket = None
            await self._call_event_handler("on_connection_error", f"{e}")

    async def _disconnect_websocket(self):
        """Close the websocket connection and reset connection state."""
        try:
            await self.stop_all_metrics()
            if self._websocket:
                # Tell Rime we're done before closing the socket.
                await self._websocket.send(json.dumps(self._build_eos_msg()))
                await self._websocket.close()
        except Exception as e:
            logger.error(f"{self} error closing websocket: {e}")
        finally:
            self._context_id = None
            self._websocket = None

    def _get_websocket(self):
        """Return the active websocket connection or raise if not connected."""
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
        """Handle interruption by clearing the current synthesis context."""
        await super()._handle_interruption(frame, direction)
        await self.stop_all_metrics()
        if self._context_id:
            await self._get_websocket().send(json.dumps(self._build_clear_msg()))
            self._context_id = None

    def _calculate_word_times(self, words: list, starts: list, ends: list) -> list:
        """Calculate word timing pairs with proper spacing and punctuation.

        Args:
            words: List of words from Rime.
            starts: List of start times for each word.
            ends: List of end times for each word.

        Returns:
            List of (word, timestamp) pairs with proper timing.
        """
        word_pairs = []
        for i, (word, start_time, _) in enumerate(zip(words, starts, ends)):
            if not word.strip():
                continue

            # Adjust timing by adding cumulative time from earlier messages
            # in the same turn.
            adjusted_start = start_time + self._cumulative_time

            # A token consisting solely of ,.!? is punctuation: append it to
            # the previous word instead of emitting it as its own entry.
            is_punctuation = not word.strip(",.!?")
            if is_punctuation and word_pairs:
                prev_word, prev_time = word_pairs[-1]
                word_pairs[-1] = (prev_word + word, prev_time)
            else:
                word_pairs.append((word, adjusted_start))

        return word_pairs

    async def flush_audio(self):
        """Flush any pending audio for the current context and end the turn."""
        if not self._context_id or not self._websocket:
            return
        logger.trace(f"{self}: flushing audio")
        # A bare space with no contextId nudges Rime to flush buffered audio.
        await self._get_websocket().send(json.dumps({"text": " "}))
        self._context_id = None

    async def _receive_messages(self):
        """Process incoming websocket messages."""
        async for message in self._get_websocket():
            msg = json.loads(message)
            # Ignore messages for contexts we no longer track (e.g. after an
            # interruption cleared the context).
            if not msg or not self.audio_context_available(msg["contextId"]):
                continue
            if msg["type"] == "chunk":
                # Process audio chunk
                await self.stop_ttfb_metrics()
                self.start_word_timestamps()
                frame = TTSAudioRawFrame(
                    audio=base64.b64decode(msg["data"]),
                    sample_rate=self.sample_rate,
                    num_channels=1,
                )
                await self.append_to_audio_context(msg["contextId"], frame)
            elif msg["type"] == "timestamps":
                # Process word timing information
                timestamps = msg.get("word_timestamps", {})
                words = timestamps.get("words", [])
                starts = timestamps.get("start", [])
                ends = timestamps.get("end", [])
                if words and starts:
                    word_pairs = self._calculate_word_times(words, starts, ends)
                    if word_pairs:
                        await self.add_word_timestamps(word_pairs)
                        # Carry the end time of the last word forward so the
                        # next message in this turn starts where this one ended.
                        self._cumulative_time = ends[-1] + self._cumulative_time
                        logger.debug(f"Updated cumulative time to: {self._cumulative_time}")
            elif msg["type"] == "error":
                logger.error(f"{self} error: {msg}")
                await self.push_frame(TTSStoppedFrame())
                await self.stop_all_metrics()
                await self.push_error(ErrorFrame(f"{self} error: {msg['message']}"))
                self._context_id = None

    async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
        """Push frame and handle end-of-turn conditions."""
        await super().push_frame(frame, direction)
        # End of turn: reset word-timestamp tracking. (The original also
        # matched StartInterruptionFrame in an outer check, but only
        # TTSStoppedFrame ever triggered the reset, so the flattened
        # condition is behaviorally identical.)
        if isinstance(frame, TTSStoppedFrame):
            await self.add_word_timestamps([("Reset", 0)])

    @traced_tts
    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text.

        Args:
            text: The text to convert to speech.

        Yields:
            Frames containing audio data and timing information.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")
        try:
            # Reconnect if the websocket dropped since the last call.
            if not self._websocket or self._websocket.closed:
                await self._connect()

            try:
                if not self._context_id:
                    # First message of a new turn: start metrics and timing.
                    await self.start_ttfb_metrics()
                    yield TTSStartedFrame()
                    self._cumulative_time = 0
                    self._context_id = str(uuid.uuid4())
                    await self.create_audio_context(self._context_id)

                msg = self._build_msg(text=text)
                await self._get_websocket().send(json.dumps(msg))
                await self.start_tts_usage_metrics(text)
            except Exception as e:
                logger.error(f"{self} error sending message: {e}")
                yield TTSStoppedFrame()
                # Best-effort reconnect so the next turn can proceed.
                await self._disconnect()
                await self._connect()
                return
            yield None
        except Exception as e:
            logger.error(f"{self} exception: {e}")
class RimeHttpTTSService(TTSService):
    """Text-to-Speech service using Rime's HTTP API.

    Streams PCM audio from Rime's HTTP endpoint in chunks. For the "arcana"
    model, which does not support raw PCM, WAV output is requested and the
    header is stripped from the first chunk.
    """

    class InputParams(BaseModel):
        """Configuration parameters for Rime HTTP TTS synthesis.

        NOTE(review): this nested model was reconstructed from the attribute
        accesses in ``__init__`` (``params.language``, ``params.speed_alpha``,
        ``params.reduce_latency``, ``params.pause_between_brackets``,
        ``params.phonemize_between_brackets``, ``params.inline_speed_alpha``)
        — the original definition was missing from this file. Confirm the
        default values against the upstream pipecat source / Rime API docs.
        """

        language: Optional[Language] = Language.EN
        speed_alpha: Optional[float] = 1.0
        reduce_latency: Optional[bool] = False
        pause_between_brackets: Optional[bool] = False
        phonemize_between_brackets: Optional[bool] = False
        # Only sent to Rime when set (see __init__).
        inline_speed_alpha: Optional[str] = None

    def __init__(
        self,
        *,
        api_key: str,
        voice_id: str,
        aiohttp_session: aiohttp.ClientSession,
        model: str = "mistv2",
        sample_rate: Optional[int] = None,
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Initialize the Rime HTTP TTS service.

        Args:
            api_key: Rime API key for authentication.
            voice_id: ID of the voice to use.
            aiohttp_session: Shared aiohttp session used for HTTP requests.
            model: Model ID to use for synthesis.
            sample_rate: Audio sample rate in Hz.
            params: Additional configuration parameters.
        """
        super().__init__(sample_rate=sample_rate, **kwargs)

        params = params or RimeHttpTTSService.InputParams()

        self._api_key = api_key
        self._session = aiohttp_session
        self._base_url = "https://users.rime.ai/v1/rime-tts"
        self._settings = {
            "lang": self.language_to_service_language(params.language)
            if params.language
            else "eng",
            "speedAlpha": params.speed_alpha,
            "reduceLatency": params.reduce_latency,
            "pauseBetweenBrackets": params.pause_between_brackets,
            "phonemizeBetweenBrackets": params.phonemize_between_brackets,
        }
        self.set_voice(voice_id)
        self.set_model_name(model)

        # Only include inlineSpeedAlpha in the payload when explicitly set.
        if params.inline_speed_alpha:
            self._settings["inlineSpeedAlpha"] = params.inline_speed_alpha

    def can_generate_metrics(self) -> bool:
        """Whether this service can produce TTFB/usage metrics (it can)."""
        return True

    def language_to_service_language(self, language: Language) -> str | None:
        """Convert pipecat language to Rime language code."""
        return language_to_rime_language(language)

    @traced_tts
    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text via Rime's HTTP endpoint.

        Args:
            text: The text to convert to speech.

        Yields:
            TTSStartedFrame, TTSAudioRawFrame chunks, then TTSStoppedFrame;
            ErrorFrame on failure.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")

        headers = {
            "Accept": "audio/pcm",
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

        payload = self._settings.copy()
        payload["text"] = text
        payload["speaker"] = self._voice_id
        payload["modelId"] = self._model_name
        payload["samplingRate"] = self.sample_rate

        # Arcana does not support PCM audio
        if payload["modelId"] == "arcana":
            headers["Accept"] = "audio/wav"
            need_to_strip_wav_header = True
        else:
            need_to_strip_wav_header = False

        try:
            await self.start_ttfb_metrics()

            async with self._session.post(
                self._base_url, json=payload, headers=headers
            ) as response:
                if response.status != 200:
                    error_message = f"Rime TTS error: HTTP {response.status}"
                    logger.error(error_message)
                    yield ErrorFrame(error=error_message)
                    return

                await self.start_tts_usage_metrics(text)
                yield TTSStartedFrame()

                CHUNK_SIZE = self.chunk_size

                async for chunk in response.content.iter_chunked(CHUNK_SIZE):
                    # Drop the WAV header from the first chunk so only raw PCM
                    # is pushed downstream. Assumes a standard 44-byte RIFF
                    # header — TODO(review): confirm Rime never emits extended
                    # headers for arcana.
                    if need_to_strip_wav_header and chunk.startswith(b"RIFF"):
                        chunk = chunk[44:]
                        need_to_strip_wav_header = False

                    if len(chunk) > 0:
                        await self.stop_ttfb_metrics()
                        frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
                        yield frame
        except Exception as e:
            logger.exception(f"Error generating TTS: {e}")
            yield ErrorFrame(error=f"Rime TTS error: {str(e)}")
        finally:
            await self.stop_ttfb_metrics()
            yield TTSStoppedFrame()