#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import io
import wave
from typing import AsyncGenerator, Optional
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
from pipecat.utils.tracing.service_decorators import traced_tts
try:
    from groq import AsyncGroq
except ModuleNotFoundError as e:
    # The Groq client ships as an optional extra (see the pip hint below), so
    # explain how to install it before aborting the import of this module.
    logger.error(f"Exception: {e}")
    logger.error("In order to use Groq, you need to `pip install pipecat-ai[groq]`.")
    # Chain explicitly so the original ModuleNotFoundError is kept as the cause.
    raise Exception(f"Missing module: {e}") from e
class GroqTTSService(TTSService):
    """Text-to-speech service that synthesizes audio through the Groq API.

    Text is sent to ``AsyncGroq.audio.speech.create`` and the returned audio is
    decoded with the standard-library ``wave`` module, then emitted as a
    ``TTSAudioRawFrame`` bracketed by ``TTSStartedFrame`` / ``TTSStoppedFrame``.
    """

    class InputParams(BaseModel):
        """Optional tuning parameters for Groq TTS.

        NOTE(review): ``__init__`` refers to ``GroqTTSService.InputParams`` but no
        definition was present in this file; the fields below are exactly the
        ones ``__init__`` reads. Confirm the field set and the defaults against
        the upstream definition.

        Attributes:
            language: Language to record in the service settings. When unset,
                ``__init__`` falls back to ``"en"``.
            speed: Speed value to record in the service settings (default
                presumably means normal speed -- TODO confirm).
        """

        language: Optional[Language] = None
        speed: Optional[float] = 1.0

    GROQ_SAMPLE_RATE = 48000  # Groq TTS only supports 48kHz sample rate

    def __init__(
        self,
        *,
        api_key: str,
        output_format: str = "wav",
        params: Optional[InputParams] = None,
        model_name: str = "playai-tts",
        voice_id: str = "Celeste-PlayAI",
        sample_rate: Optional[int] = GROQ_SAMPLE_RATE,
        **kwargs,
    ):
        """Create the service and its Groq client.

        Args:
            api_key: Groq API key used to build the ``AsyncGroq`` client.
            output_format: Value sent as ``response_format`` with each request.
                ``run_tts`` decodes the response with ``wave``, so only a WAV
                response can be turned into audio frames.
            params: Optional ``InputParams``; a default instance is used when
                omitted. Its ``language`` and ``speed`` are stored in the
                service settings.
            model_name: Groq TTS model sent with each request.
            voice_id: Voice sent with each request.
            sample_rate: Sample rate handed to the ``TTSService`` base class. A
                warning is logged when it differs from ``GROQ_SAMPLE_RATE``.
            **kwargs: Forwarded unchanged to ``TTSService.__init__``.
        """
        if sample_rate != self.GROQ_SAMPLE_RATE:
            logger.warning(f"Groq TTS only supports {self.GROQ_SAMPLE_RATE}Hz sample rate. ")

        super().__init__(
            pause_frame_processing=True,
            sample_rate=sample_rate,
            **kwargs,
        )

        params = params or GroqTTSService.InputParams()

        self._api_key = api_key
        self._model_name = model_name
        self._output_format = output_format
        self._voice_id = voice_id
        self._params = params

        # Not read anywhere in this class; presumably consumed by the
        # TTSService base class -- TODO confirm.
        self._settings = {
            "model": model_name,
            "voice_id": voice_id,
            "output_format": output_format,
            "language": str(params.language) if params.language else "en",
            "speed": params.speed,
            "sample_rate": sample_rate,
        }

        self._client = AsyncGroq(api_key=self._api_key)

    def can_generate_metrics(self) -> bool:
        """Report that this service produces metrics.

        Returns:
            Always ``True``.
        """
        return True

    @traced_tts
    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Synthesize ``text`` with Groq and yield the resulting frames.

        Args:
            text: The text to convert to speech.

        Yields:
            ``TTSStartedFrame`` first, then one ``TTSAudioRawFrame`` carrying the
            decoded audio (when the request succeeds and returns data), and
            finally ``TTSStoppedFrame``. Errors raised while requesting or
            decoding the audio are logged rather than propagated, so the
            stopped frame is always emitted.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")
        measuring_ttfb = True
        await self.start_ttfb_metrics()
        yield TTSStartedFrame()

        try:
            response = await self._client.audio.speech.create(
                model=self._model_name,
                voice=self._voice_id,
                response_format=self._output_format,
                input=text,
            )

            # Collect the whole body before decoding: only the first chunk of a
            # WAV response carries the RIFF header, so chunks cannot be opened
            # as independent WAV files.
            wav_buffer = io.BytesIO()
            async for data in response.iter_bytes():
                if measuring_ttfb:
                    # First bytes have arrived; stop the time-to-first-byte clock.
                    await self.stop_ttfb_metrics()
                    measuring_ttfb = False
                wav_buffer.write(data)

            if wav_buffer.tell() == 0:
                logger.warning(f"{self}: Groq returned no audio data")
            else:
                wav_buffer.seek(0)
                with wave.open(wav_buffer) as w:
                    channels = w.getnchannels()
                    frame_rate = w.getframerate()
                    audio = w.readframes(w.getnframes())
                yield TTSAudioRawFrame(audio, frame_rate, channels)
        except Exception as e:
            # Best effort: report the failure and still close the TTS turn below.
            logger.error(f"{self} exception: {e}")

        yield TTSStoppedFrame()