#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import io
import time
import wave
from typing import Awaitable, Callable, Optional
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InputAudioRawFrame,
OutputAudioRawFrame,
StartFrame,
StartInterruptionFrame,
TransportMessageFrame,
TransportMessageUrgentFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.serializers.base_serializer import FrameSerializer
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
try:
    import websockets
except ModuleNotFoundError as e:
    # `websockets` ships with an optional extra: log how to install it, then
    # abort the import of this module.
    logger.error(f"Exception: {e}")
    logger.error("In order to use websockets, you need to `pip install pipecat-ai[websocket]`.")
    raise Exception(f"Missing module: {e}")
class WebsocketServerParams(TransportParams):
    """Configuration for the websocket server transport.

    Extends ``TransportParams`` with websocket-specific options.
    """

    # When True, outgoing audio is wrapped in a WAV container (header + PCM
    # data) before it is handed to the serializer.
    add_wav_header: bool = False
    # Converts frames into the payloads sent over the websocket. Without a
    # serializer the output transport sends nothing.
    serializer: Optional[FrameSerializer] = None
    # Not consumed by the output transport; presumably a duration in seconds
    # after which `on_session_timeout` fires -- TODO confirm against the
    # input side.
    session_timeout: Optional[int] = None
class WebsocketServerCallbacks(BaseModel):
    """Bundle of async callbacks describing websocket server events.

    ``WebsocketServerTransport`` fills these with its own ``_on_*`` methods.
    The three connection callbacks receive the affected websocket;
    ``on_websocket_ready`` takes no arguments.
    """

    # Awaited with the websocket of a client that connected.
    on_client_connected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]]
    # Awaited with the websocket of a client that disconnected.
    on_client_disconnected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]]
    # Awaited with the websocket whose session timed out (presumably driven
    # by `WebsocketServerParams.session_timeout` -- TODO confirm).
    on_session_timeout: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]]
    # Awaited without arguments; presumably once the server is accepting
    # connections -- TODO confirm against the code that invokes it.
    on_websocket_ready: Callable[[], Awaitable[None]]
class WebsocketServerOutputTransport(BaseOutputTransport):
    """Output side of the websocket server transport.

    Frames are converted to payloads by the configured serializer and sent to
    the single connected websocket client. Audio writes are paced with a
    sleep so that the network connection behaves like an audio device rather
    than receiving audio as fast as it is produced.
    """

    def __init__(self, transport: BaseTransport, params: WebsocketServerParams, **kwargs):
        """Create the output transport.

        Args:
            transport: The owning transport; it is cleaned up together with
                this output transport.
            params: Websocket server parameters (serializer, WAV header
                option, ...).
            **kwargs: Forwarded to ``BaseOutputTransport``.
        """
        super().__init__(params, **kwargs)

        self._transport = transport
        self._params = params

        # Currently connected client, or None when nobody is connected.
        self._websocket: Optional[websockets.WebSocketServerProtocol] = None

        # write_audio_frame() is called quickly, as soon as we get audio
        # (e.g. from the TTS), and since this is just a network connection we
        # would be sending it too quickly. Instead, we want to block to
        # emulate an audio device, this is what the send interval is. It will
        # be computed on StartFrame.
        self._send_interval = 0
        self._next_send_time = 0

        # Whether we have seen a StartFrame already.
        self._initialized = False

    async def set_client_connection(self, websocket: Optional[websockets.WebSocketServerProtocol]):
        """Set (or clear) the websocket that outgoing frames are sent to.

        Any previously stored connection is closed first.

        Args:
            websocket: The new client connection, or ``None`` to clear the
                current one (the transport does this on client disconnect).
        """
        if self._websocket:
            await self._websocket.close()
            # Only warn when a new client actually replaces the old one;
            # clearing the connection on disconnect is not a conflict.
            if websocket is not None:
                logger.warning("Only one client allowed, using new connection")
        self._websocket = websocket

    async def start(self, frame: StartFrame):
        """Handle the StartFrame: set up the serializer and audio pacing.

        Only the first StartFrame performs the setup; later ones return right
        after the base class handling.
        """
        await super().start(frame)

        if self._initialized:
            return

        self._initialized = True

        if self._params.serializer:
            await self._params.serializer.setup(frame)
        # Pacing interval between audio writes. NOTE(review): the division by
        # 2 presumably converts a chunk size in bytes to 16-bit samples --
        # confirm against BaseOutputTransport.audio_chunk_size.
        self._send_interval = (self.audio_chunk_size / self.sample_rate) / 2
        await self.set_transport_ready(frame)

    async def stop(self, frame: EndFrame):
        """Stop the transport and offer the EndFrame to the serializer."""
        await super().stop(frame)
        await self._write_frame(frame)

    async def cancel(self, frame: CancelFrame):
        """Cancel the transport and offer the CancelFrame to the serializer."""
        await super().cancel(frame)
        await self._write_frame(frame)

    async def cleanup(self):
        """Clean up this output transport and then the owning transport."""
        await super().cleanup()
        await self._transport.cleanup()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process a frame; interruptions are also written to the client."""
        await super().process_frame(frame, direction)

        if isinstance(frame, StartInterruptionFrame):
            await self._write_frame(frame)
            # Reset the pacing clock so the next audio write is not delayed.
            self._next_send_time = 0

    async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
        """Send a transport message frame to the connected client."""
        await self._write_frame(frame)

    async def write_audio_frame(self, frame: OutputAudioRawFrame):
        """Send one audio frame to the client, paced like an audio device.

        The audio is re-wrapped with this transport's sample rate and channel
        count, optionally placed in a WAV container, written to the client
        and followed by the pacing sleep. With no client connected only the
        pacing sleep happens.
        """
        if not self._websocket:
            # Simulate audio playback with a sleep.
            await self._write_audio_sleep()
            return

        frame = OutputAudioRawFrame(
            audio=frame.audio,
            sample_rate=self.sample_rate,
            num_channels=self._params.audio_out_channels,
        )

        if self._params.add_wav_header:
            with io.BytesIO() as buffer:
                with wave.open(buffer, "wb") as wf:
                    wf.setsampwidth(2)
                    wf.setnchannels(frame.num_channels)
                    wf.setframerate(frame.sample_rate)
                    wf.writeframes(frame.audio)
                # Read the buffer only after the wave writer is closed so the
                # WAV header is finalized.
                wav_frame = OutputAudioRawFrame(
                    buffer.getvalue(),
                    sample_rate=frame.sample_rate,
                    num_channels=frame.num_channels,
                )
                frame = wav_frame

        await self._write_frame(frame)

        # Simulate audio playback with a sleep.
        await self._write_audio_sleep()

    async def _write_frame(self, frame: Frame):
        """Serialize a frame and send it to the client, if possible.

        Nothing is sent without a serializer, when the serializer yields no
        payload, or when no client is connected. Errors are logged rather
        than raised so that a failing send does not break the pipeline.
        """
        if not self._params.serializer:
            return

        try:
            payload = await self._params.serializer.serialize(frame)
            if payload and self._websocket:
                await self._websocket.send(payload)
        except Exception as e:
            logger.error(f"{self} exception sending data: {e.__class__.__name__} ({e})")

    async def _write_audio_sleep(self):
        """Sleep until the next send slot and schedule the following one."""
        # Simulate a clock.
        current_time = time.monotonic()
        sleep_duration = max(0, self._next_send_time - current_time)
        await asyncio.sleep(sleep_duration)
        if sleep_duration == 0:
            # We are late (or just starting): restart the clock from now.
            self._next_send_time = time.monotonic() + self._send_interval
        else:
            # On schedule: advance from the previous slot to avoid drift.
            self._next_send_time += self._send_interval
class WebsocketServerTransport(BaseTransport):
    """Websocket server transport.

    Holds the server configuration, lazily creates the output transport and
    exposes the ``on_client_connected``, ``on_client_disconnected``,
    ``on_session_timeout`` and ``on_websocket_ready`` events to user code.
    """

    def __init__(
        self,
        params: WebsocketServerParams,
        host: str = "localhost",
        port: int = 8765,
        input_name: Optional[str] = None,
        output_name: Optional[str] = None,
    ):
        """Create the transport.

        Args:
            params: Websocket server parameters shared with the output
                transport.
            host: Host name stored for the websocket server.
            port: Port stored for the websocket server.
            input_name: Optional name forwarded to ``BaseTransport``.
            output_name: Optional name forwarded to ``BaseTransport`` and
                used for the output transport.
        """
        super().__init__(input_name=input_name, output_name=output_name)

        self._host = host
        self._port = port
        self._params = params

        # Internal callbacks that forward websocket events to the handlers
        # registered by the user.
        self._callbacks = WebsocketServerCallbacks(
            on_client_connected=self._on_client_connected,
            on_client_disconnected=self._on_client_disconnected,
            on_session_timeout=self._on_session_timeout,
            on_websocket_ready=self._on_websocket_ready,
        )

        self._input: Optional[WebsocketServerInputTransport] = None
        self._output: Optional[WebsocketServerOutputTransport] = None
        self._websocket: Optional[websockets.WebSocketServerProtocol] = None

        # Register supported handlers. The user will only be able to register
        # these handlers.
        for event_name in (
            "on_client_connected",
            "on_client_disconnected",
            "on_session_timeout",
            "on_websocket_ready",
        ):
            self._register_event_handler(event_name)

    def output(self) -> WebsocketServerOutputTransport:
        """Return the output transport, creating it on first use."""
        if self._output:
            return self._output
        self._output = WebsocketServerOutputTransport(self, self._params, name=self._output_name)
        return self._output

    async def _on_client_connected(self, websocket):
        """Hand the new client to the output transport and notify handlers."""
        if not self._output:
            logger.error("A WebsocketServerTransport output is missing in the pipeline")
            return
        await self._output.set_client_connection(websocket)
        await self._call_event_handler("on_client_connected", websocket)

    async def _on_client_disconnected(self, websocket):
        """Clear the output transport's client and notify handlers."""
        if not self._output:
            logger.error("A WebsocketServerTransport output is missing in the pipeline")
            return
        await self._output.set_client_connection(None)
        await self._call_event_handler("on_client_disconnected", websocket)

    async def _on_session_timeout(self, websocket):
        """Forward a session timeout to the registered handlers."""
        await self._call_event_handler("on_session_timeout", websocket)

    async def _on_websocket_ready(self):
        """Forward the websocket-ready event to the registered handlers."""
        await self._call_event_handler("on_websocket_ready")