#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import json
import time
from typing import Any, List, Literal, Optional, Union
from loguru import logger
from pydantic import BaseModel, TypeAdapter
from pipecat.utils.base_object import BaseObject
try:
from aiortc import (
MediaStreamTrack,
RTCConfiguration,
RTCIceServer,
RTCPeerConnection,
RTCSessionDescription,
)
from aiortc.rtcrtpreceiver import RemoteStreamTrack
from av.frame import Frame
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use the SmallWebRTC, you need to `pip install pipecat-ai[webrtc]`.")
raise Exception(f"Missing module: {e}")
# Value of the "type" field that marks a data-channel message as a signalling
# message rather than an application message.
SIGNALLING_TYPE = "signalling"
# Keys of the connection's `_track_getters` mapping for the audio and video
# input tracks; a track-status message selects one via its `receiver_index`.
AUDIO_TRANSCEIVER_INDEX = 0
VIDEO_TRANSCEIVER_INDEX = 1
[docs]
class TrackStatusMessage(BaseModel):
    """Signalling payload that switches an input track on or off.

    Attributes:
        type: Message discriminator, always ``"trackStatus"``.
        receiver_index: Key used to pick the input track whose state changes.
        enabled: Enabled state to apply to the selected track.
    """

    type: Literal["trackStatus"]
    receiver_index: int
    enabled: bool
[docs]
class RenegotiateMessage(BaseModel):
    """Signalling payload asking the remote peer to renegotiate.

    Attributes:
        type: Message discriminator, defaults to ``"renegotiate"``.
    """

    type: Literal["renegotiate"] = "renegotiate"
[docs]
class PeerLeftMessage(BaseModel):
    """Signalling payload sent to the remote peer when this side disconnects.

    Attributes:
        type: Message discriminator, defaults to ``"peerLeft"``.
    """

    type: Literal["peerLeft"] = "peerLeft"
[docs]
class SignallingMessage:
    """Namespace grouping the signalling message types by direction."""

    # Message types validated when a signalling message is received. Declared
    # as a Union so that further message types can be added in the future.
    Inbound = Union[TrackStatusMessage]
    # Message types sent to the remote peer (note the lowercase attribute name).
    outbound = Union[RenegotiateMessage]
[docs]
class SmallWebRTCTrack:
    """Wrapper around a media track that can be enabled or disabled.

    Attributes and methods that are not defined on the wrapper itself are
    forwarded to the wrapped track.
    """

    def __init__(self, track: MediaStreamTrack):
        """Wrap ``track``; the wrapper starts out enabled.

        Args:
            track: The underlying media track that calls are delegated to.
        """
        self._track = track
        self._enabled = True

    def set_enabled(self, enabled: bool) -> None:
        """Enable or disable frame delivery through ``recv``.

        Args:
            enabled: ``True`` to deliver frames, ``False`` to suppress them.
        """
        self._enabled = enabled

    def is_enabled(self) -> bool:
        """Return whether the track is currently enabled."""
        return self._enabled

    async def discard_old_frames(self):
        """Drop every frame currently buffered by the wrapped remote track.

        Only ``RemoteStreamTrack`` instances are handled. This relies on the
        track's private ``_queue`` attribute being an ``asyncio.Queue``; if
        that is not the case a warning is logged and nothing is discarded.
        """
        remote_track = self._track
        if not isinstance(remote_track, RemoteStreamTrack):
            return

        frame_queue = getattr(remote_track, "_queue", None)
        if not isinstance(frame_queue, asyncio.Queue):
            # Use the module logger (not print) so the warning reaches the
            # same sink as every other diagnostic in this module.
            logger.warning("_queue does not exist or has changed in aiortc.")
            return

        logger.debug("Discarding old frames")
        while not frame_queue.empty():
            frame_queue.get_nowait()  # Remove the oldest frame
            frame_queue.task_done()

    async def recv(self) -> Optional[Frame]:
        """Return the next frame from the wrapped track.

        Returns:
            The received frame, or ``None`` when the track is disabled (in
            which case the wrapped track is not read at all).
        """
        if not self._enabled:
            return None
        return await self._track.recv()

    def __getattr__(self, name):
        """Forward unknown attribute/method lookups to the wrapped track.

        Raises:
            AttributeError: If ``_track`` itself is missing, or the wrapped
                track does not have the requested attribute.
        """
        # __getattr__ only runs when normal lookup fails. If `_track` is the
        # missing attribute (e.g. the instance was created without running
        # __init__), reading self._track below would recurse without end.
        if name == "_track":
            raise AttributeError(name)
        return getattr(self._track, name)
# Public alias for aiortc's RTCIceServer, so that users of this module can
# refer to ICE servers without RTCIceServer itself being exposed.
IceServer = RTCIceServer
[docs]
class SmallWebRTCConnection(BaseObject):
def __init__(self, ice_servers: Optional[Union[List[str], List[IceServer]]] = None):
    """Create the connection object and its first peer connection.

    Args:
        ice_servers: ICE servers to use, given either as URL strings or as
            ``IceServer`` objects. ``None`` or an empty list means no ICE
            servers.

    Raises:
        TypeError: If ``ice_servers`` is neither all strings nor all
            ``IceServer`` objects.
    """
    super().__init__()

    if not ice_servers:
        resolved_servers: List[IceServer] = []
    elif all(isinstance(entry, IceServer) for entry in ice_servers):
        resolved_servers = ice_servers
    elif all(isinstance(entry, str) for entry in ice_servers):
        resolved_servers = [IceServer(urls=url) for url in ice_servers]
    else:
        raise TypeError("ice_servers must be either List[str] or List[RTCIceServer]")
    self.ice_servers: List[IceServer] = resolved_servers

    self._connect_invoked = False
    self._track_map = {}
    # Maps the receiver index carried by track-status messages to the
    # accessor of the matching input track.
    self._track_getters = {
        AUDIO_TRANSCEIVER_INDEX: self.audio_input_track,
        VIDEO_TRANSCEIVER_INDEX: self.video_input_track,
    }

    self._initialize()

    # Only the events registered below can have handlers attached by users:
    # application messages, track lifecycle, then the connection states.
    supported_events = (
        "app-message",
        "track-started",
        "track-ended",
        "connecting",
        "connected",
        "disconnected",
        "closed",
        "failed",
        "new",
    )
    for event_name in supported_events:
        self._register_event_handler(event_name)
@property
def pc(self) -> RTCPeerConnection:
    """The peer connection currently held by this object."""
    return self._pc
@property
def pc_id(self) -> str:
    """Identifier of the peer connection (taken from this object's name)."""
    return self._pc_id
def _initialize(self):
    """Create a new peer connection and reset all per-connection state."""
    logger.debug("Initializing new peer connection")

    self._answer: Optional[RTCSessionDescription] = None
    self._pc = RTCPeerConnection(RTCConfiguration(iceServers=self.ice_servers))
    self._pc_id = self.name

    # Data-channel and liveness state start from scratch for every peer connection.
    self._data_channel = None
    self._renegotiation_in_progress = False
    self._last_received_time = None
    # Outgoing messages waiting for the data channel to open.
    self._message_queue = []
    # Incoming app messages received before the connection counted as connected.
    self._pending_app_messages = []

    self._setup_listeners()
def _setup_listeners(self):
    """Attach the data-channel, connection-state, ICE and track listeners to the peer connection."""
    peer = self._pc

    @peer.on("datachannel")
    def on_datachannel(channel):
        self._data_channel = channel

        @channel.on("open")
        async def on_open():
            # Everything queued while the channel was not open is sent now, oldest first.
            logger.debug("Data channel is open, flushing queued messages")
            while self._message_queue:
                queued = self._message_queue.pop(0)
                self._data_channel.send(queued)

        @channel.on("message")
        async def on_message(message):
            try:
                # aiortc gives us no way to notice that we got disconnected, so the
                # keep-alive "ping" messages are timestamped and used for that purpose.
                if isinstance(message, str) and message.startswith("ping"):
                    self._last_received_time = time.time()
                    return

                payload = json.loads(message)
                if payload["type"] == SIGNALLING_TYPE and payload.get("message"):
                    self._handle_signalling_message(payload["message"])
                elif self.is_connected():
                    await self._call_event_handler("app-message", payload)
                else:
                    # Held back until connect() flushes the pending app messages.
                    logger.debug("Client not connected. Queuing app-message.")
                    self._pending_app_messages.append(payload)
            except Exception as e:
                logger.exception(f"Error parsing JSON message {message}, {e}")

    # aiortc provides this listener but has no "disconnected" status, so this
    # event is not triggered when we lose the connection.
    @peer.on("connectionstatechange")
    async def on_connectionstatechange():
        await self._handle_new_connection_state()

    # Same limitation as above: losing the connection does not trigger this
    # event, so it is only used for logging.
    @peer.on("iceconnectionstatechange")
    async def on_iceconnectionstatechange():
        logger.debug(
            f"ICE connection state is {self._pc.iceConnectionState}, connection is {self._pc.connectionState}"
        )

    @peer.on("icegatheringstatechange")
    async def on_icegatheringstatechange():
        logger.debug(f"ICE gathering state is {self._pc.iceGatheringState}")

    @peer.on("track")
    async def on_track(track):
        logger.debug(f"Track {track.kind} received")
        await self._call_event_handler("track-started", track)

        @track.on("ended")
        async def on_ended():
            logger.debug(f"Track {track.kind} ended")
            await self._call_event_handler("track-ended", track)
async def _create_answer(self, sdp: str, type: str):
    """Apply the remote offer and store the local answer in ``self._answer``.

    Args:
        sdp: SDP text of the remote session description.
        type: Type of the remote session description.
    """
    remote_offer = RTCSessionDescription(sdp=sdp, type=type)
    await self._pc.setRemoteDescription(remote_offer)

    # aiortc does not respect the SDP asking for sendrecv transceivers,
    # so that direction is forced here.
    self.force_transceivers_to_send_recv()

    # The answer created here has no ICE candidates yet; they are gathered
    # later, after setLocalDescription.
    logger.debug("Creating answer")
    created_answer = await self._pc.createAnswer()
    await self._pc.setLocalDescription(created_answer)

    logger.debug("Setting the answer after the local description is created")
    self._answer = self._pc.localDescription
[docs]
async def initialize(self, sdp: str, type: str):
    """Process the remote offer and prepare the local answer.

    Args:
        sdp: SDP text of the remote session description.
        type: Type of the remote session description.
    """
    await self._create_answer(sdp, type)
[docs]
async def connect(self):
    """Mark the connection as ready for use and replay what was held back.

    After this call ``is_connected`` may report ``True``. If the connection
    is already established, the "connected" event is fired, app messages
    that were queued in the meantime are delivered, stale video frames are
    discarded and the peer is asked to renegotiate.
    """
    self._connect_invoked = True

    # Nothing to replay unless we are already connected.
    if not self.is_connected():
        return

    await self._call_event_handler("connected")

    logger.debug("Flushing pending app-messages")
    # Detach the queue before delivering so each message is handed over
    # exactly once, even if connect() is invoked again later.
    pending_messages, self._pending_app_messages = self._pending_app_messages, []
    for message in pending_messages:
        await self._call_event_handler("app-message", message)

    # We renegotiate here because we have likely lost the first video
    # frames, and aiortc does not handle that well.
    video_input_track = self.video_input_track()
    if video_input_track:
        await video_input_track.discard_old_frames()
    self.ask_to_renegotiate()
[docs]
async def renegotiate(self, sdp: str, type: str, restart_pc: bool = False):
    """Create a new answer for a renegotiation offer.

    Args:
        sdp: SDP text of the remote session description.
        type: Type of the remote session description.
        restart_pc: When true, fire "disconnected", remove the listeners
            from the current peer connection, close it and create a new
            peer connection before answering.
    """
    logger.debug(f"Renegotiating {self._pc_id}")

    if restart_pc:
        await self._call_event_handler("disconnected")
        logger.debug("Closing old peer connection")
        # Remove the listeners first to prevent the bot from closing.
        self._pc.remove_all_listeners()
        await self._close()
        # A brand new peer connection is used in this case.
        self._initialize()

    await self._create_answer(sdp, type)

    # Maybe this should be refactored to wait for a message from the client
    # once the renegotiation is completed, or to look at the peer connection
    # listeners; a fixed delay is good enough for now.
    async def delayed_task():
        await asyncio.sleep(2)
        self._renegotiation_in_progress = False

    # The event loop only keeps weak references to tasks, so hold a strong
    # reference until the task finishes; otherwise it may be garbage
    # collected before it clears the in-progress flag.
    reset_task = asyncio.create_task(delayed_task())
    pending_resets = getattr(self, "_renegotiation_reset_tasks", None)
    if pending_resets is None:
        pending_resets = self._renegotiation_reset_tasks = set()
    pending_resets.add(reset_task)
    reset_task.add_done_callback(pending_resets.discard)
[docs]
def force_transceivers_to_send_recv(self):
    """Set the direction of every transceiver of the peer connection to ``sendrecv``."""
    current_transceivers = self._pc.getTransceivers()
    for entry in current_transceivers:
        entry.direction = "sendrecv"
[docs]
def replace_audio_track(self, track):
    """Replace the track sent on the audio transceiver.

    Logs a warning and does nothing else when the audio transceiver (or its
    sender) is not available.

    Args:
        track: The new outgoing audio track.
    """
    logger.debug(f"Replacing audio track {track.kind}")
    # Transceivers always appear in creation order for both peers. For now
    # only two transceivers are considered: one for audio, one for video.
    # The shared index constant is used instead of a literal so this stays
    # in step with the rest of the module.
    transceivers = self._pc.getTransceivers()
    if (
        len(transceivers) > AUDIO_TRANSCEIVER_INDEX
        and transceivers[AUDIO_TRANSCEIVER_INDEX].sender
    ):
        transceivers[AUDIO_TRANSCEIVER_INDEX].sender.replaceTrack(track)
    else:
        logger.warning("Audio transceiver not found. Cannot replace audio track.")
[docs]
def replace_video_track(self, track):
    """Replace the track sent on the video transceiver.

    Logs a warning and does nothing else when the video transceiver (or its
    sender) is not available.

    Args:
        track: The new outgoing video track.
    """
    logger.debug(f"Replacing video track {track.kind}")
    # Transceivers always appear in creation order for both peers. For now
    # only two transceivers are considered: one for audio, one for video.
    # The shared index constant is used instead of a literal so this stays
    # in step with the rest of the module.
    transceivers = self._pc.getTransceivers()
    if (
        len(transceivers) > VIDEO_TRANSCEIVER_INDEX
        and transceivers[VIDEO_TRANSCEIVER_INDEX].sender
    ):
        transceivers[VIDEO_TRANSCEIVER_INDEX].sender.replaceTrack(track)
    else:
        logger.warning("Video transceiver not found. Cannot replace video track.")
[docs]
async def disconnect(self):
    """Send a peer-left signalling message, then close the connection."""
    farewell = {"type": SIGNALLING_TYPE, "message": PeerLeftMessage().model_dump()}
    self.send_app_message(farewell)
    await self._close()
async def _close(self):
    """Close the peer connection (if any) and drop queued messages and tracks."""
    peer = self._pc
    if peer:
        await peer.close()

    # Neither outgoing nor pending incoming messages survive a close.
    self._message_queue.clear()
    self._pending_app_messages.clear()
    self._track_map = {}
[docs]
def get_answer(self):
    """Return the stored local answer as a plain dict.

    Returns:
        A dict with the keys ``sdp``, ``type`` and ``pc_id``, or ``None``
        when no answer has been stored yet.
    """
    answer = self._answer
    if answer:
        return {"sdp": answer.sdp, "type": answer.type, "pc_id": self._pc_id}
    return None
async def _handle_new_connection_state(self):
    """Fire the event named after the peer connection's new state.

    A "connected" state is not reported until ``connect`` has been invoked.
    A "failed" state additionally closes the peer connection.
    """
    state = self._pc.connectionState

    waiting_for_connect = not self._connect_invoked
    if waiting_for_connect and state == "connected":
        # We wait until the pipeline is ready before triggering the event.
        return

    logger.debug(f"Connection state changed to: {state}")
    await self._call_event_handler(state)

    if state != "failed":
        return
    logger.warning("Connection failed, closing peer connection.")
    await self._close()
# aiortc provides a connection-state listener, but it never reports a "disconnected" status,
# so self._pc.connectionState alone cannot tell us that the peer is gone.
# That is why we keep track of our own connection state.
[docs]
def is_connected(self):
    """Return whether the connection is currently considered connected.

    Returns:
        ``False`` until ``connect`` has been invoked. After that, ``True``
        when the last keep-alive ping arrived less than 3 seconds ago, or,
        if no ping was ever received, when aiortc reports "connected".
    """
    # Until the transport has invoked connect() we act as if not connected.
    if not self._connect_invoked:
        return False

    last_ping = self._last_received_time
    if last_ping is not None:
        # Connected only if the last ping is less than 3 seconds old.
        return (time.time() - last_ping) < 3

    # No ping was ever received, probably because the client has not
    # created a data channel, so we trust aiortc in this case.
    return self._pc.connectionState == "connected"
[docs]
def send_app_message(self, message: Any):
    """Send ``message`` as JSON over the data channel, or queue it.

    Args:
        message: JSON-serializable payload. When the data channel is missing
            or not open, the serialized message is appended to the outgoing
            queue instead of being sent.
    """
    serialized = json.dumps(message)
    channel = self._data_channel

    if not (channel and channel.readyState == "open"):
        logger.debug("Data channel not ready, queuing message")
        self._message_queue.append(serialized)
        return

    channel.send(serialized)
[docs]
def ask_to_renegotiate(self):
    """Ask the remote peer to renegotiate, unless a request is already in progress."""
    if not self._renegotiation_in_progress:
        self._renegotiation_in_progress = True
        request = {"type": SIGNALLING_TYPE, "message": RenegotiateMessage().model_dump()}
        self.send_app_message(request)
def _handle_signalling_message(self, message):
    """Validate an inbound signalling message and act on it.

    A track-status message enables or disables the input track selected by
    its ``receiver_index``; an index with no usable track is ignored.

    Args:
        message: The raw signalling payload taken from the data-channel message.
    """
    logger.debug(f"Signalling message received: {message}")
    parsed = TypeAdapter(SignallingMessage.Inbound).validate_python(message)

    if not isinstance(parsed, TrackStatusMessage):
        return

    track_getter = self._track_getters.get(parsed.receiver_index)
    track = track_getter() if track_getter else None
    if track:
        track.set_enabled(parsed.enabled)