#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""OpenAI LLM service implementation with context aggregators."""
import json
from dataclasses import dataclass
from typing import Any, Optional
from pipecat.frames.frames import (
FunctionCallCancelFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
UserImageRawFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMAssistantContextAggregator,
LLMUserAggregatorParams,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai.base_llm import BaseOpenAILLMService
@dataclass
class OpenAIContextAggregatorPair:
    """Pair of OpenAI context aggregators for user and assistant messages.

    Parameters:
        _user: User context aggregator for processing user messages.
        _assistant: Assistant context aggregator for processing assistant messages.
    """

    _user: "OpenAIUserContextAggregator"
    _assistant: "OpenAIAssistantContextAggregator"

    def user(self) -> "OpenAIUserContextAggregator":
        """Get the user context aggregator.

        Returns:
            The user context aggregator instance.
        """
        return self._user

    def assistant(self) -> "OpenAIAssistantContextAggregator":
        """Get the assistant context aggregator.

        Returns:
            The assistant context aggregator instance.
        """
        return self._assistant
class OpenAILLMService(BaseOpenAILLMService):
    """OpenAI LLM service implementation.

    Provides a complete OpenAI LLM service with context aggregation support.
    Uses the BaseOpenAILLMService for core functionality and adds OpenAI-specific
    context aggregator creation.

    Args:
        model: The OpenAI model name to use. Defaults to "gpt-4.1".
        params: Input parameters for model configuration.
        **kwargs: Additional arguments passed to the parent BaseOpenAILLMService.
    """

    def __init__(
        self,
        *,
        model: str = "gpt-4.1",
        params: Optional[BaseOpenAILLMService.InputParams] = None,
        **kwargs,
    ):
        super().__init__(model=model, params=params, **kwargs)

    def create_context_aggregator(
        self,
        context: OpenAILLMContext,
        *,
        # NOTE: these default instances are created once at definition time and
        # shared across calls; safe only if the params objects are treated as
        # read-only (project-wide convention — do not mutate them).
        user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
        assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
    ) -> OpenAIContextAggregatorPair:
        """Create OpenAI-specific context aggregators.

        Creates a pair of context aggregators optimized for OpenAI's message format,
        including support for function calls, tool usage, and image handling.

        Args:
            context: The LLM context to create aggregators for.
            user_params: Parameters for user message aggregation.
            assistant_params: Parameters for assistant message aggregation.

        Returns:
            OpenAIContextAggregatorPair: A pair of context aggregators, one for
            the user and one for the assistant, encapsulated in an
            OpenAIContextAggregatorPair.
        """
        # Ensure the context serializes messages with this service's adapter
        # before the aggregators start reading/writing it.
        context.set_llm_adapter(self.get_llm_adapter())
        user = OpenAIUserContextAggregator(context, params=user_params)
        assistant = OpenAIAssistantContextAggregator(context, params=assistant_params)
        return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
class OpenAIUserContextAggregator(LLMUserContextAggregator):
    """OpenAI-specific user context aggregator.

    Handles aggregation of user messages for OpenAI LLM services.
    Inherits all functionality from the base LLMUserContextAggregator.
    """
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
    """OpenAI-specific assistant context aggregator.

    Handles aggregation of assistant messages for OpenAI LLM services,
    with specialized support for OpenAI's function calling format,
    tool usage tracking, and image message handling.
    """

    async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
        """Handle a function call in progress.

        Adds the function call to the context with an IN_PROGRESS status
        to track ongoing function execution.

        Args:
            frame: Frame containing function call progress information.
        """
        # Record the assistant's tool-call request in OpenAI's message format...
        self._context.add_message(
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": frame.tool_call_id,
                        "function": {
                            "name": frame.function_name,
                            "arguments": json.dumps(frame.arguments),
                        },
                        "type": "function",
                    }
                ],
            }
        )
        # ...followed by a placeholder "tool" reply; its content is overwritten
        # later by _update_function_call_result when the call finishes.
        self._context.add_message(
            {
                "role": "tool",
                "content": "IN_PROGRESS",
                "tool_call_id": frame.tool_call_id,
            }
        )

    async def handle_function_call_result(self, frame: FunctionCallResultFrame):
        """Handle the result of a function call.

        Updates the context with the function call result, replacing any
        previous IN_PROGRESS status.

        Args:
            frame: Frame containing the function call result.
        """
        # A falsy result (None, empty dict, etc.) is reported as the sentinel
        # string "COMPLETED" rather than serialized JSON.
        if frame.result:
            result = json.dumps(frame.result)
            await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
        else:
            await self._update_function_call_result(
                frame.function_name, frame.tool_call_id, "COMPLETED"
            )

    async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
        """Handle a cancelled function call.

        Updates the context to mark the function call as cancelled.

        Args:
            frame: Frame containing the function call cancellation information.
        """
        await self._update_function_call_result(
            frame.function_name, frame.tool_call_id, "CANCELLED"
        )

    async def _update_function_call_result(
        self, function_name: str, tool_call_id: str, result: Any
    ):
        # Replace the placeholder content of the matching "tool" message.
        # Use .get() so a tool message without a "tool_call_id" key is skipped
        # instead of raising KeyError.
        for message in self._context.messages:
            if message["role"] == "tool" and message.get("tool_call_id") == tool_call_id:
                message["content"] = result

    async def handle_user_image_frame(self, frame: UserImageRawFrame):
        """Handle a user image frame from a function call request.

        Marks the associated function call as completed and adds the image
        to the context for processing.

        Args:
            frame: Frame containing the user image and request context.
        """
        # The image fulfills the pending tool call, so close it out first...
        await self._update_function_call_result(
            frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
        )
        # ...then append the image itself as a context message.
        self._context.add_image_frame_message(
            format=frame.format,
            size=frame.size,
            image=frame.image,
            text=frame.request.context,
        )