Source code for pipecat.processors.aggregators.llm_response

#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
from abc import abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Set

from loguru import logger

from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
from pipecat.frames.frames import (
    BotInterruptionFrame,
    BotStartedSpeakingFrame,
    BotStoppedSpeakingFrame,
    CancelFrame,
    EmulateUserStartedSpeakingFrame,
    EmulateUserStoppedSpeakingFrame,
    EndFrame,
    Frame,
    FunctionCallCancelFrame,
    FunctionCallInProgressFrame,
    FunctionCallResultFrame,
    FunctionCallsStartedFrame,
    InputAudioRawFrame,
    InterimTranscriptionFrame,
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMMessagesAppendFrame,
    LLMMessagesFrame,
    LLMMessagesUpdateFrame,
    LLMSetToolChoiceFrame,
    LLMSetToolsFrame,
    LLMTextFrame,
    OpenAILLMContextAssistantTimestampFrame,
    StartFrame,
    StartInterruptionFrame,
    TextFrame,
    TranscriptionFrame,
    UserImageRawFrame,
    UserStartedSpeakingFrame,
    UserStoppedSpeakingFrame,
)
from pipecat.processors.aggregators.openai_llm_context import (
    OpenAILLMContext,
    OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.time import time_now_iso8601


[docs]
@dataclass
class LLMUserAggregatorParams:
    """Parameters for user response aggregators."""

    aggregation_timeout: float = 0.5


[docs]
@dataclass
class LLMAssistantAggregatorParams:
    """Parameters for assistant response aggregators."""

    expect_stripped_words: bool = True
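

# Usage sketch (illustrative only, not part of the original module): both
# params classes are plain dataclasses, so overriding a default is a single
# keyword argument. The values shown are hypothetical.
#
#   user_params = LLMUserAggregatorParams(aggregation_timeout=1.0)
#   assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=False)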


[docs]
class LLMFullResponseAggregator(FrameProcessor):
    """An LLM aggregator that aggregates a full LLM completion.

    It aggregates LLM text frames (tokens) received between
    `LLMFullResponseStartFrame` and `LLMFullResponseEndFrame`. Every full
    completion is returned via the "on_completion" event handler:

        @aggregator.event_handler("on_completion")
        async def on_completion(
            aggregator: LLMFullResponseAggregator,
            completion: str,
            completed: bool,
        )
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._aggregation = ""
        self._started = False
        self._register_event_handler("on_completion")

[docs]
    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        if isinstance(frame, StartInterruptionFrame):
            await self._call_event_handler("on_completion", self._aggregation, False)
            self._aggregation = ""
            self._started = False
        elif isinstance(frame, LLMFullResponseStartFrame):
            await self._handle_llm_start(frame)
        elif isinstance(frame, LLMFullResponseEndFrame):
            await self._handle_llm_end(frame)
        elif isinstance(frame, LLMTextFrame):
            await self._handle_llm_text(frame)

        await self.push_frame(frame, direction)

    async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
        self._started = True

    async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
        await self._call_event_handler("on_completion", self._aggregation, True)
        self._started = False
        self._aggregation = ""

    async def _handle_llm_text(self, frame: TextFrame):
        if not self._started:
            return
        self._aggregation += frame.text
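

# Usage sketch (illustrative only, not part of the original module): wiring
# the "on_completion" event handler described in the class docstring. The
# `aggregator` name is hypothetical.
#
#   aggregator = LLMFullResponseAggregator()
#
#   @aggregator.event_handler("on_completion")
#   async def on_completion(
#       aggregator: LLMFullResponseAggregator, completion: str, completed: bool
#   ):
#       # `completed` is False when the completion was cut short by an
#       # interruption.
#       logger.info(f"Completion (finished={completed}): {completion}")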


[docs]
class BaseLLMResponseAggregator(FrameProcessor):
    """This is the base class for all LLM response aggregators.

    These aggregators process incoming frames and aggregate content until they
    are ready to push the aggregation. In the case of a user, an aggregation
    might be a full transcription received from the STT service.

    The LLM response aggregators also keep a store (e.g. a message list or an
    LLM context) of the current conversation, that is, they store the messages
    said by the user or by the bot.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    @abstractmethod
    def messages(self) -> List[dict]:
        """Returns the messages from the current conversation."""
        pass

    @property
    @abstractmethod
    def role(self) -> str:
        """Returns the role (e.g. user, assistant...) for this aggregator."""
        pass

[docs]
    @abstractmethod
    def add_messages(self, messages):
        """Add the given messages to the conversation."""
        pass

[docs]
    @abstractmethod
    def set_messages(self, messages):
        """Reset the conversation with the given messages."""
        pass

[docs]
    @abstractmethod
    def set_tools(self, tools):
        """Set LLM tools to be used in the current conversation."""
        pass

[docs]
    @abstractmethod
    def set_tool_choice(self, tool_choice):
        """Set the tool choice. This should modify the LLM context."""
        pass

[docs]
    @abstractmethod
    async def reset(self):
        """Reset the internals of this aggregator.

        This should not modify the internal messages.
        """
        pass

[docs]
    @abstractmethod
    async def handle_aggregation(self, aggregation: str):
        """Add the given aggregation to the aggregator.

        The aggregator can use a simple list of messages or a context. It does
        not push any frames.
        """
        pass

[docs]
    @abstractmethod
    async def push_aggregation(self):
        """Push the current aggregation.

        For example, in the case of context aggregation this might push a new
        context frame.
        """
        pass
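

# Subclassing sketch (illustrative only, assuming a plain message-list store):
# a concrete aggregator implements the storage hooks and decides in
# push_aggregation() which frame carries the result. The class name is
# hypothetical, and the remaining abstract methods (add_messages,
# set_messages, set_tools, set_tool_choice, reset) are omitted for brevity.
#
#   class ListResponseAggregator(BaseLLMResponseAggregator):
#       def __init__(self, role: str, **kwargs):
#           super().__init__(**kwargs)
#           self._role = role
#           self._messages: List[dict] = []
#
#       @property
#       def messages(self) -> List[dict]:
#           return self._messages
#
#       @property
#       def role(self) -> str:
#           return self._role
#
#       async def handle_aggregation(self, aggregation: str):
#           self._messages.append({"role": self._role, "content": aggregation})
#
#       async def push_aggregation(self):
#           await self.push_frame(LLMMessagesFrame(self._messages))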


[docs]
class LLMContextResponseAggregator(BaseLLMResponseAggregator):
    """This is a base LLM aggregator that uses an LLM context to store the
    conversation. It pushes `OpenAILLMContextFrame` as an aggregation frame.
    """

    def __init__(self, *, context: OpenAILLMContext, role: str, **kwargs):
        super().__init__(**kwargs)
        self._context = context
        self._role = role
        self._aggregation: str = ""

    @property
    def messages(self) -> List[dict]:
        return self._context.get_messages()

    @property
    def role(self) -> str:
        return self._role

    @property
    def context(self):
        return self._context

[docs]
    def get_context_frame(self) -> OpenAILLMContextFrame:
        return OpenAILLMContextFrame(context=self._context)

[docs]
    async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM):
        frame = self.get_context_frame()
        await self.push_frame(frame, direction)

[docs]
    def add_messages(self, messages):
        self._context.add_messages(messages)

[docs]
    def set_messages(self, messages):
        self._context.set_messages(messages)

[docs]
    def set_tools(self, tools: List):
        self._context.set_tools(tools)

[docs]
    def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict):
        self._context.set_tool_choice(tool_choice)

[docs]
    async def reset(self):
        self._aggregation = ""
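

# Usage sketch (illustrative only): subclasses snapshot the shared context and
# push it as a frame, which is how an aggregation reaches the LLM service. The
# `aggregator` name is hypothetical.
#
#   frame = aggregator.get_context_frame()  # OpenAILLMContextFrame wrapper
#   await aggregator.push_context_frame()  # downstream by default
#   await aggregator.push_context_frame(FrameDirection.UPSTREAM)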


[docs]
class LLMUserContextAggregator(LLMContextResponseAggregator):
    """This is a user LLM aggregator that uses an LLM context to store the
    conversation. It aggregates transcriptions from the STT service and has
    logic to handle the multiple scenarios where transcriptions are received
    between VAD events (`UserStartedSpeakingFrame` and
    `UserStoppedSpeakingFrame`), outside of them, or with no VAD events at all.
    """

    def __init__(
        self,
        context: OpenAILLMContext,
        *,
        params: Optional[LLMUserAggregatorParams] = None,
        **kwargs,
    ):
        super().__init__(context=context, role="user", **kwargs)
        self._params = params or LLMUserAggregatorParams()

        if "aggregation_timeout" in kwargs:
            import warnings

            with warnings.catch_warnings():
                warnings.simplefilter("always")
                warnings.warn(
                    "Parameter 'aggregation_timeout' is deprecated, use 'params' instead.",
                    DeprecationWarning,
                )
            self._params.aggregation_timeout = kwargs["aggregation_timeout"]

        self._user_speaking = False
        self._bot_speaking = False
        self._was_bot_speaking = False
        self._emulating_vad = False
        self._seen_interim_results = False
        self._waiting_for_aggregation = False

        self._aggregation_event = asyncio.Event()
        self._aggregation_task = None

[docs]
    async def reset(self):
        await super().reset()
        self._was_bot_speaking = False
        self._seen_interim_results = False
        self._waiting_for_aggregation = False
        for strategy in self._interruption_strategies:
            await strategy.reset()

[docs]
    async def handle_aggregation(self, aggregation: str):
        self._context.add_message({"role": self.role, "content": aggregation})

[docs]
    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        if isinstance(frame, StartFrame):
            # Push StartFrame before start(), because we want StartFrame to be
            # processed by every processor before any other frame is processed.
            await self.push_frame(frame, direction)
            await self._start(frame)
        elif isinstance(frame, EndFrame):
            # Push EndFrame before stop(), because stop() waits on the task to
            # finish and the task finishes when EndFrame is processed.
            await self.push_frame(frame, direction)
            await self._stop(frame)
        elif isinstance(frame, CancelFrame):
            await self._cancel(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, InputAudioRawFrame):
            await self._handle_input_audio(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, UserStartedSpeakingFrame):
            await self._handle_user_started_speaking(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, UserStoppedSpeakingFrame):
            await self._handle_user_stopped_speaking(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, BotStartedSpeakingFrame):
            await self._handle_bot_started_speaking(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, BotStoppedSpeakingFrame):
            await self._handle_bot_stopped_speaking(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, TranscriptionFrame):
            await self._handle_transcription(frame)
        elif isinstance(frame, InterimTranscriptionFrame):
            await self._handle_interim_transcription(frame)
        elif isinstance(frame, LLMMessagesAppendFrame):
            self.add_messages(frame.messages)
        elif isinstance(frame, LLMMessagesUpdateFrame):
            self.set_messages(frame.messages)
        elif isinstance(frame, LLMSetToolsFrame):
            self.set_tools(frame.tools)
        elif isinstance(frame, LLMSetToolChoiceFrame):
            self.set_tool_choice(frame.tool_choice)
        else:
            await self.push_frame(frame, direction)

    async def _process_aggregation(self):
        """Process the current aggregation and push it downstream."""
        aggregation = self._aggregation
        await self.reset()
        await self.handle_aggregation(aggregation)
        frame = OpenAILLMContextFrame(self._context)
        await self.push_frame(frame)

[docs]
    async def push_aggregation(self):
        """Push the current aggregation based on interruption strategies and conditions."""
        if len(self._aggregation) > 0:
            if self.interruption_strategies and self._bot_speaking:
                should_interrupt = await self._should_interrupt_based_on_strategies()

                if should_interrupt:
                    logger.debug(
                        "Interruption conditions met - pushing BotInterruptionFrame and aggregation"
                    )
                    await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
                    await self._process_aggregation()
                else:
                    logger.debug("Interruption conditions not met - not pushing aggregation")
                    # Don't process the aggregation, just reset it.
                    await self.reset()
            else:
                # No interruption config - normal behavior (always push the aggregation).
                await self._process_aggregation()
        # Handles the case where neither the user nor the bot is speaking, and
        # the bot was previously speaking before the user interruption.
        # Normally, when the user stops speaking, new text is expected, which
        # triggers the bot to respond. However, if no new text is received,
        # this safeguard ensures the bot doesn't hang indefinitely while
        # waiting to speak again.
        elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
            logger.warning("User stopped speaking but no new aggregation received.")
            # Reset it so we don't trigger this twice.
            self._was_bot_speaking = False
            # TODO: We are not enabling this for now because some STT services
            # can take as long as 2 seconds to return a transcription, so we
            # need more tests and should probably make this feature
            # configurable, disabled by default. It would just push the same
            # previous context to be processed again:
            # await self.push_frame(OpenAILLMContextFrame(self._context))

    async def _should_interrupt_based_on_strategies(self) -> bool:
        """Check if interruption should occur based on the configured strategies."""

        async def should_interrupt(strategy: BaseInterruptionStrategy):
            await strategy.append_text(self._aggregation)
            return await strategy.should_interrupt()

        return any([await should_interrupt(s) for s in self._interruption_strategies])

    async def _start(self, frame: StartFrame):
        self._create_aggregation_task()

    async def _stop(self, frame: EndFrame):
        await self._cancel_aggregation_task()

    async def _cancel(self, frame: CancelFrame):
        await self._cancel_aggregation_task()

    async def _handle_input_audio(self, frame: InputAudioRawFrame):
        for s in self.interruption_strategies:
            await s.append_audio(frame.audio, frame.sample_rate)

    async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
        self._user_speaking = True
        self._waiting_for_aggregation = True
        self._was_bot_speaking = self._bot_speaking

        # If we get a non-emulated UserStartedSpeakingFrame but we are in the
        # middle of emulating VAD, let's stop emulating VAD (i.e. don't send
        # the EmulateUserStoppedSpeakingFrame).
        if not frame.emulated and self._emulating_vad:
            self._emulating_vad = False

    async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
        self._user_speaking = False

        # We just stopped speaking. Let's see if there's some aggregation to
        # push. If the last thing we saw is an interim transcription, let's
        # wait to push the aggregation as we will probably get a final
        # transcription.
        if len(self._aggregation) > 0:
            if not self._seen_interim_results:
                await self.push_aggregation()
        # Handles the case where neither the user nor the bot is speaking, and
        # the bot was previously speaking before the user interruption. In
        # this case we reset the aggregation timer.
        elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
            # Reset aggregation timer.
            self._aggregation_event.set()

    async def _handle_bot_started_speaking(self, _: BotStartedSpeakingFrame):
        self._bot_speaking = True

    async def _handle_bot_stopped_speaking(self, _: BotStoppedSpeakingFrame):
        self._bot_speaking = False

    async def _handle_transcription(self, frame: TranscriptionFrame):
        text = frame.text

        # Make sure we really have some text.
        if not text.strip():
            return

        self._aggregation += f" {text}" if self._aggregation else text
        # We just got a final result, so let's reset interim results.
        self._seen_interim_results = False
        # Reset aggregation timer.
        self._aggregation_event.set()

    async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
        self._seen_interim_results = True

    def _create_aggregation_task(self):
        if not self._aggregation_task:
            self._aggregation_task = self.create_task(self._aggregation_task_handler())

    async def _cancel_aggregation_task(self):
        if self._aggregation_task:
            await self.cancel_task(self._aggregation_task)
            self._aggregation_task = None

    async def _aggregation_task_handler(self):
        while True:
            try:
                await asyncio.wait_for(
                    self._aggregation_event.wait(), self._params.aggregation_timeout
                )
                await self._maybe_emulate_user_speaking()
            except asyncio.TimeoutError:
                if not self._user_speaking:
                    await self.push_aggregation()

                # If we are emulating VAD we still need to send the user
                # stopped speaking frame.
                if self._emulating_vad:
                    await self.push_frame(
                        EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM
                    )
                    self._emulating_vad = False
            finally:
                self._aggregation_event.clear()

    async def _maybe_emulate_user_speaking(self):
        """Emulate the user speaking if we got a transcription that was not
        detected by VAD. Only do that if the bot is not speaking.
        """
        # Check if we received a transcription but VAD was not able to detect
        # voice (e.g. when you whisper a short utterance). In that case, we
        # need to emulate VAD (i.e. user started/stopped speaking), but we
        # only do it if the bot is not speaking. If the bot is speaking and we
        # really have a short utterance, we don't want to interrupt the bot.
        if not self._user_speaking and not self._waiting_for_aggregation:
            if self._bot_speaking:
                # If we reached this case and the bot is speaking, let's
                # ignore what the user said.
                logger.debug("Ignoring user speaking emulation, bot is speaking.")
                await self.reset()
            else:
                # The bot is not speaking, so let's trigger user speaking
                # emulation.
                await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM)
                self._emulating_vad = True
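

# Usage sketch (illustrative only, not part of the original module): a user
# aggregator wraps a shared context and typically sits between STT and the
# LLM service in a pipeline. The variable names are hypothetical.
#
#   context = OpenAILLMContext(
#       [{"role": "system", "content": "You are a helpful assistant."}]
#   )
#   user_aggregator = LLMUserContextAggregator(
#       context,
#       params=LLMUserAggregatorParams(aggregation_timeout=0.8),
#   )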


[docs]
class LLMAssistantContextAggregator(LLMContextResponseAggregator):
    """This is an assistant LLM aggregator that uses an LLM context to store
    the conversation. It aggregates text frames received between
    `LLMFullResponseStartFrame` and `LLMFullResponseEndFrame`.
    """

    def __init__(
        self,
        context: OpenAILLMContext,
        *,
        params: Optional[LLMAssistantAggregatorParams] = None,
        **kwargs,
    ):
        super().__init__(context=context, role="assistant", **kwargs)
        self._params = params or LLMAssistantAggregatorParams()

        if "expect_stripped_words" in kwargs:
            import warnings

            with warnings.catch_warnings():
                warnings.simplefilter("always")
                warnings.warn(
                    "Parameter 'expect_stripped_words' is deprecated, use 'params' instead.",
                    DeprecationWarning,
                )
            self._params.expect_stripped_words = kwargs["expect_stripped_words"]

        self._started = 0
        self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {}
        self._context_updated_tasks: Set[asyncio.Task] = set()

    @property
    def has_function_calls_in_progress(self) -> bool:
        """Check if there are any function calls currently in progress.

        Returns:
            bool: True if function calls are in progress, False otherwise.
        """
        return bool(self._function_calls_in_progress)

[docs]
    async def handle_aggregation(self, aggregation: str):
        self._context.add_message({"role": "assistant", "content": aggregation})

[docs]
    async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
        pass

[docs]
    async def handle_function_call_result(self, frame: FunctionCallResultFrame):
        pass

[docs]
    async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
        pass

[docs]
    async def handle_user_image_frame(self, frame: UserImageRawFrame):
        pass

[docs]
    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        if isinstance(frame, StartInterruptionFrame):
            await self._handle_interruptions(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, LLMFullResponseStartFrame):
            await self._handle_llm_start(frame)
        elif isinstance(frame, LLMFullResponseEndFrame):
            await self._handle_llm_end(frame)
        elif isinstance(frame, TextFrame):
            await self._handle_text(frame)
        elif isinstance(frame, LLMMessagesAppendFrame):
            self.add_messages(frame.messages)
        elif isinstance(frame, LLMMessagesUpdateFrame):
            self.set_messages(frame.messages)
        elif isinstance(frame, LLMSetToolsFrame):
            self.set_tools(frame.tools)
        elif isinstance(frame, LLMSetToolChoiceFrame):
            self.set_tool_choice(frame.tool_choice)
        elif isinstance(frame, FunctionCallsStartedFrame):
            await self._handle_function_calls_started(frame)
        elif isinstance(frame, FunctionCallInProgressFrame):
            await self._handle_function_call_in_progress(frame)
        elif isinstance(frame, FunctionCallResultFrame):
            await self._handle_function_call_result(frame)
        elif isinstance(frame, FunctionCallCancelFrame):
            await self._handle_function_call_cancel(frame)
        elif isinstance(frame, UserImageRawFrame) and frame.request and frame.request.tool_call_id:
            await self._handle_user_image_frame(frame)
        elif isinstance(frame, BotStoppedSpeakingFrame):
            await self.push_aggregation()
            await self.push_frame(frame, direction)
        else:
            await self.push_frame(frame, direction)

[docs]
    async def push_aggregation(self):
        if not self._aggregation:
            return

        aggregation = self._aggregation.strip()
        await self.reset()

        if aggregation:
            await self.handle_aggregation(aggregation)

        # Push the context frame.
        await self.push_context_frame()

        # Push a timestamp frame with the current time.
        timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
        await self.push_frame(timestamp_frame)

    async def _handle_interruptions(self, frame: StartInterruptionFrame):
        await self.push_aggregation()
        self._started = 0
        await self.reset()

    async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame):
        function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls]
        logger.debug(f"{self} FunctionCallsStartedFrame: {function_names}")
        for function_call in frame.function_calls:
            self._function_calls_in_progress[function_call.tool_call_id] = None

    async def _handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
        logger.debug(
            f"{self} FunctionCallInProgressFrame: [{frame.function_name}:{frame.tool_call_id}]"
        )
        await self.handle_function_call_in_progress(frame)
        self._function_calls_in_progress[frame.tool_call_id] = frame

    async def _handle_function_call_result(self, frame: FunctionCallResultFrame):
        logger.debug(
            f"{self} FunctionCallResultFrame: [{frame.function_name}:{frame.tool_call_id}]"
        )
        if frame.tool_call_id not in self._function_calls_in_progress:
            logger.warning(
                f"FunctionCallResultFrame tool_call_id [{frame.tool_call_id}] is not running"
            )
            return

        del self._function_calls_in_progress[frame.tool_call_id]

        properties = frame.properties

        await self.handle_function_call_result(frame)

        run_llm = False

        # Run inference if the function call result requires it.
        if frame.result:
            if properties and properties.run_llm is not None:
                # If the tool call result has a run_llm property, use it.
                run_llm = properties.run_llm
            elif frame.run_llm is not None:
                # If the frame indicates we should run the LLM, do it.
                run_llm = frame.run_llm
            else:
                # If this is the last function call in progress, run the LLM.
                run_llm = not bool(self._function_calls_in_progress)

        if run_llm:
            await self.push_context_frame(FrameDirection.UPSTREAM)

        # Call the `on_context_updated` callback once the function call result
        # is added to the context. Also, run this in a separate task to make
        # sure we don't block the pipeline.
        if properties and properties.on_context_updated:
            task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated"
            task = self.create_task(properties.on_context_updated(), task_name)
            self._context_updated_tasks.add(task)
            task.add_done_callback(self._context_updated_task_finished)

    async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
        logger.debug(
            f"{self} FunctionCallCancelFrame: [{frame.function_name}:{frame.tool_call_id}]"
        )
        if frame.tool_call_id not in self._function_calls_in_progress:
            return

        if self._function_calls_in_progress[frame.tool_call_id].cancel_on_interruption:
            await self.handle_function_call_cancel(frame)

        del self._function_calls_in_progress[frame.tool_call_id]

    async def _handle_user_image_frame(self, frame: UserImageRawFrame):
        logger.debug(
            f"{self} UserImageRawFrame: [{frame.request.function_name}:{frame.request.tool_call_id}]"
        )
        if frame.request.tool_call_id not in self._function_calls_in_progress:
            logger.warning(
                f"UserImageRawFrame tool_call_id [{frame.request.tool_call_id}] is not running"
            )
            return

        del self._function_calls_in_progress[frame.request.tool_call_id]

        await self.handle_user_image_frame(frame)
        await self.push_aggregation()
        await self.push_context_frame(FrameDirection.UPSTREAM)

    async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
        self._started += 1

    async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
        self._started -= 1
        await self.push_aggregation()

    async def _handle_text(self, frame: TextFrame):
        if not self._started:
            return

        if self._params.expect_stripped_words:
            self._aggregation += f" {frame.text}" if self._aggregation else frame.text
        else:
            self._aggregation += frame.text

    def _context_updated_task_finished(self, task: asyncio.Task):
        self._context_updated_tasks.discard(task)
        # The task is finished, so this should exit immediately. We need to do
        # this because otherwise the task manager would report a dangling task
        # if we don't remove it.
        asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())
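

# Usage sketch (illustrative only, assuming pipecat's Pipeline class and
# reusing the `context` and `user_aggregator` names from the sketch above):
# the assistant aggregator usually shares its context with a user aggregator
# and is placed at the end of the pipeline so it records what the bot
# actually said. The `transport`, `stt`, `llm`, and `tts` processors are
# hypothetical placeholders.
#
#   assistant_aggregator = LLMAssistantContextAggregator(context)
#   pipeline = Pipeline([
#       transport.input(),
#       stt,
#       user_aggregator,
#       llm,
#       tts,
#       transport.output(),
#       assistant_aggregator,
#   ])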


[docs]
class LLMUserResponseAggregator(LLMUserContextAggregator):
    def __init__(
        self,
        messages: Optional[List[dict]] = None,
        *,
        params: Optional[LLMUserAggregatorParams] = None,
        **kwargs,
    ):
        super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs)

[docs]
    async def push_aggregation(self):
        if len(self._aggregation) > 0:
            await self.handle_aggregation(self._aggregation)

            # Reset the aggregation. Reset it before pushing it down,
            # otherwise if the task gets cancelled we won't be able to clean
            # things up.
            await self.reset()

            frame = LLMMessagesFrame(self._context.messages)
            await self.push_frame(frame)


[docs]
class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
    def __init__(
        self,
        messages: Optional[List[dict]] = None,
        *,
        params: Optional[LLMAssistantAggregatorParams] = None,
        **kwargs,
    ):
        super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs)

[docs]
    async def push_aggregation(self):
        if len(self._aggregation) > 0:
            await self.handle_aggregation(self._aggregation)

            # Reset the aggregation. Reset it before pushing it down,
            # otherwise if the task gets cancelled we won't be able to clean
            # things up.
            await self.reset()

            frame = LLMMessagesFrame(self._context.messages)
            await self.push_frame(frame)
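

# Usage sketch (illustrative only): the *ResponseAggregator variants build
# their own OpenAILLMContext from a plain message list and push
# LLMMessagesFrame instead of OpenAILLMContextFrame, which suits pipelines
# that exchange raw message lists. The `messages` value is hypothetical.
#
#   messages = [{"role": "system", "content": "You are a helpful assistant."}]
#   user_aggregator = LLMUserResponseAggregator(messages)
#   assistant_aggregator = LLMAssistantResponseAggregator(messages)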