#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import uuid
from typing import AsyncGenerator, Literal, Optional
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import InterruptibleTTSService
from pipecat.transcriptions.language import Language
from pipecat.utils.tracing.service_decorators import traced_tts
# Fish Audio's live API speaks msgpack over a websocket; both packages are
# optional extras, so fail early with an installation hint if either is missing.
try:
    import ormsgpack
    import websockets
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`.")
    # Chain the original error so its traceback is preserved as __cause__.
    raise Exception(f"Missing module: {e}") from e
# Output formats accepted by this service's ``output_format`` argument; the
# chosen value is sent to Fish Audio as the ``format`` setting.
FishAudioOutputFormat = Literal["opus", "mp3", "pcm", "wav"]
class FishAudioTTSService(InterruptibleTTSService):
    """Text-to-speech service using Fish Audio's live websocket endpoint.

    Text is sent to ``wss://api.fish.audio/v1/tts/live`` as msgpack-encoded
    events (``start``, ``text``, ``flush``, ``stop``). Audio arrives
    asynchronously on the same socket and is pushed downstream as
    ``TTSAudioRawFrame`` objects by a background receive task.
    """

    class InputParams(BaseModel):
        """Optional synthesis parameters for :class:`FishAudioTTSService`.

        Parameters:
            language: Language hint. NOTE(review): this service does not
                currently forward it to Fish Audio.
            latency: Value sent as the ``latency`` setting.
            prosody_speed: Value sent as ``prosody.speed``.
            prosody_volume: Value sent as ``prosody.volume``.
        """

        language: Optional[Language] = Language.EN
        latency: Optional[str] = "normal"
        prosody_speed: Optional[float] = 1.0
        prosody_volume: Optional[int] = 0

    def __init__(
        self,
        *,
        api_key: str,
        model: str,  # This is the reference_id
        output_format: FishAudioOutputFormat = "pcm",
        sample_rate: Optional[int] = None,
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Create the service. No connection is opened until :meth:`start`.

        Args:
            api_key: Fish Audio API key, sent as a ``Bearer`` authorization header.
            model: Fish Audio ``reference_id`` of the voice/model to use.
            output_format: Value sent as the ``format`` setting. NOTE(review):
                received audio is always wrapped in ``TTSAudioRawFrame``, so
                formats other than ``"pcm"`` presumably produce encoded bytes
                in raw frames. Confirm before using them.
            sample_rate: Forwarded to the base class. The resolved
                ``self.sample_rate`` is copied into the settings in :meth:`start`.
            params: Synthesis parameters; ``InputParams()`` defaults when omitted.
            **kwargs: Forwarded to ``InterruptibleTTSService``.
        """
        super().__init__(
            push_stop_frames=True,
            pause_frame_processing=True,
            sample_rate=sample_rate,
            **kwargs,
        )

        params = params or FishAudioTTSService.InputParams()

        self._api_key = api_key
        self._base_url = "wss://api.fish.audio/v1/tts/live"
        self._websocket = None
        self._receive_task = None
        # Id of the in-flight request; None means the next run_tts call starts a new one.
        self._request_id = None
        self._started = False

        # Sent verbatim (inside "request") in the websocket "start" event.
        self._settings = {
            "sample_rate": 0,  # placeholder; filled in by start()
            "latency": params.latency,
            "format": output_format,
            "prosody": {
                "speed": params.prosody_speed,
                "volume": params.prosody_volume,
            },
            "reference_id": model,
        }

        self.set_model_name(model)

    def can_generate_metrics(self) -> bool:
        """Return True: this service records TTFB and TTS usage metrics."""
        return True

    async def set_model(self, model: str):
        """Switch the Fish Audio ``reference_id`` used for synthesis.

        NOTE(review): ``reference_id`` is only transmitted in the ``start``
        event sent by ``_connect_websocket``. An already-open connection
        presumably keeps the previous voice until it is re-established. Confirm.
        """
        self._settings["reference_id"] = model
        await super().set_model(model)
        logger.info(f"Switching TTS model to: [{model}]")

    async def start(self, frame: StartFrame):
        """Start the service, record the resolved sample rate, then connect."""
        await super().start(frame)
        self._settings["sample_rate"] = self.sample_rate
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the service and tear down the websocket connection."""
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the service and tear down the websocket connection."""
        await super().cancel(frame)
        await self._disconnect()

    async def _connect(self):
        """Open the websocket and, if connected, start the receive task once."""
        await self._connect_websocket()

        if self._websocket and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self):
        """Cancel the receive task (if any), then close the websocket."""
        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        await self._disconnect_websocket()

    async def _connect_websocket(self):
        """Connect to Fish Audio and send the ``start`` event with the settings.

        Returns immediately if the socket is already open. Failures are not
        raised: they are logged, ``self._websocket`` is reset to None, and the
        ``on_connection_error`` event handler is called.
        """
        try:
            if self._websocket and self._websocket.open:
                return

            logger.debug("Connecting to Fish Audio")
            headers = {"Authorization": f"Bearer {self._api_key}"}
            self._websocket = await websockets.connect(self._base_url, extra_headers=headers)

            # Send initial start message with ormsgpack
            start_message = {"event": "start", "request": {"text": "", **self._settings}}
            await self._websocket.send(ormsgpack.packb(start_message))
            logger.debug("Sent start message to Fish Audio")
        except Exception as e:
            logger.error(f"Fish Audio initialization error: {e}")
            self._websocket = None
            await self._call_event_handler("on_connection_error", f"{e}")

    async def _disconnect_websocket(self):
        """Send ``stop`` and close the socket; always reset connection state.

        Errors during shutdown are logged rather than raised.
        """
        try:
            await self.stop_all_metrics()

            if self._websocket:
                logger.debug("Disconnecting from Fish Audio")
                # Send stop event with ormsgpack
                stop_message = {"event": "stop"}
                await self._websocket.send(ormsgpack.packb(stop_message))
                await self._websocket.close()
        except Exception as e:
            logger.error(f"Error closing websocket: {e}")
        finally:
            self._request_id = None
            self._started = False
            self._websocket = None

    async def flush_audio(self):
        """Flush any buffered audio by sending a flush event to Fish Audio."""
        logger.trace(f"{self}: Flushing audio buffers")
        if not self._websocket or self._websocket.closed:
            return
        flush_message = {"event": "flush"}
        await self._get_websocket().send(ormsgpack.packb(flush_message))

    def _get_websocket(self):
        """Return the current websocket, raising if there is none."""
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
        """Stop metrics and forget the current request after an interruption.

        Clearing ``_request_id`` makes the next ``run_tts`` call begin a new
        request (new ``TTSStartedFrame`` and TTFB measurement).
        """
        await super()._handle_interruption(frame, direction)
        await self.stop_all_metrics()
        self._request_id = None

    async def _receive_messages(self):
        """Consume websocket messages and push synthesized audio downstream.

        Binary messages are decoded with msgpack. ``audio`` events become mono
        ``TTSAudioRawFrame`` objects at ``self.sample_rate``. Other messages
        and events are ignored. Per-message errors are logged so that one bad
        message does not end the loop.
        """
        async for message in self._get_websocket():
            try:
                if not isinstance(message, bytes):
                    continue
                msg = ormsgpack.unpackb(message)
                if not isinstance(msg, dict) or msg.get("event") != "audio":
                    continue

                audio_data = msg.get("audio")
                # Only process larger chunks to remove msgpack overhead.
                # NOTE(review): this also drops genuine audio payloads of
                # <= 1024 bytes (e.g. a short trailing chunk). Confirm intended.
                if audio_data and len(audio_data) > 1024:
                    frame = TTSAudioRawFrame(audio_data, self.sample_rate, 1)
                    await self.push_frame(frame)
                    await self.stop_ttfb_metrics()
            except Exception as e:
                logger.error(f"Error processing message: {e}")

    @traced_tts
    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Send ``text`` to Fish Audio and request immediate synthesis.

        Audio is not yielded here; it is delivered asynchronously by
        ``_receive_messages``.

        Yields:
            ``TTSStartedFrame`` on the first chunk of a new request.
            ``TTSStoppedFrame`` if sending fails; the connection is then
            rebuilt.
            ``None`` on normal completion.
            ``ErrorFrame`` on any other failure.
        """
        logger.debug(f"{self}: Generating Fish TTS: [{text}]")
        try:
            if not self._websocket or self._websocket.closed:
                await self._connect()

            if not self._request_id:
                # First chunk of a new request: begin TTFB timing and announce it.
                await self.start_ttfb_metrics()
                yield TTSStartedFrame()
                self._request_id = str(uuid.uuid4())

            text_message = {
                "event": "text",
                "text": text,
            }
            try:
                await self._get_websocket().send(ormsgpack.packb(text_message))
                # Record usage exactly once per chunk actually sent (previously
                # the first chunk of each request was counted twice).
                await self.start_tts_usage_metrics(text)

                # Send flush event to force audio generation
                flush_message = {"event": "flush"}
                await self._get_websocket().send(ormsgpack.packb(flush_message))
            except Exception as e:
                logger.error(f"{self} error sending message: {e}")
                yield TTSStoppedFrame()
                await self._disconnect()
                await self._connect()
            yield None
        except Exception as e:
            logger.error(f"Error generating TTS: {e}")
            yield ErrorFrame(f"Error: {str(e)}")