diff --git a/docs/voice/pipeline.md b/docs/voice/pipeline.md index 8cf5dafe..2aa7cad9 100644 --- a/docs/voice/pipeline.md +++ b/docs/voice/pipeline.md @@ -73,3 +73,293 @@ async for event in result.stream(): ### Interruptions The Agents SDK currently does not support any built-in interruptions support for [`StreamedAudioInput`][agents.voice.input.StreamedAudioInput]. Instead for every detected turn it will trigger a separate run of your workflow. If you want to handle interruptions inside your application you can listen to the [`VoiceStreamEventLifecycle`][agents.voice.events.VoiceStreamEventLifecycle] events. `turn_started` will indicate that a new turn was transcribed and processing is beginning. `turn_ended` will trigger after all the audio was dispatched for a respective turn. You could use these events to mute the microphone of the speaker when the model starts a turn and unmute it after you flushed all the related audio for a turn. + +Once the pipeline is done processing all turns, the `stream()` method will complete and the context manager will exit. + +## Real-time Voice Pipeline + +The SDK includes a `RealtimeVoicePipeline` designed for direct, bidirectional voice interaction with newer, real-time capable models like OpenAI's `gpt-4o-realtime-preview`. This pipeline differs significantly from the standard `VoicePipeline`: + +- **Direct Voice-to-Voice:** It sends your audio directly to the real-time LLM and receives audio back from the LLM. There are no separate STT (Speech-to-Text) or TTS (Text-to-Speech) steps managed by this pipeline. The LLM handles both transcription and speech generation internally. +- **Integrated Tool Calls:** If the LLM decides to use a tool, the pipeline will automatically execute it using the tools you provided during initialization and send the result back to the LLM. The pipeline emits `VoiceStreamEventToolCall` events so your application can log or display information about tool usage, but it does not need to perform any action in response to these events. +- **Continuous Streaming:** It's designed for continuous audio input and output, facilitating more natural conversational turn-taking. + +### Usage + +The `RealtimeVoicePipeline` follows a similar pattern to the standard `VoicePipeline`: + +1. Create a `StreamedAudioInput` instance +2. Configure a `VoicePipelineConfig` with real-time specific settings +3. Initialize the pipeline with a real-time model and any tools +4. Call `run()` to get a result that can be streamed +5. 
Process the events from the stream + +#### Basic example: + +```python +from agents.voice import ( + RealtimeVoicePipeline, + StreamedAudioInput, + VoicePipelineConfig +) +from agents.voice.models.sdk_realtime import SDKRealtimeLLM +from dataclasses import dataclass + +# Define a simple context class for state management (optional) +@dataclass +class MyAppContext: + """Context for the voice assistant.""" + user_name: str = "User" + interaction_count: int = 0 + +# Create the input, config, and model +input_stream = StreamedAudioInput() +config = VoicePipelineConfig( + realtime_settings={ + "turn_detection": "server_vad", # Use server-side voice activity detection + "system_message": "You are a helpful assistant.", + } +) +model = SDKRealtimeLLM(model_name="gpt-4o-realtime-preview") + +# Create an app context instance (optional) +app_context = MyAppContext() + +# Create the pipeline with tools and shared context +pipeline = RealtimeVoicePipeline( + model=model, + tools=[get_weather, get_time], + config=config, + shared_context=app_context, # Optional: shared state for context-aware tools +) + +# Start the pipeline +result = await pipeline.run(input_stream) + +# Process events from the pipeline +async for event in result.stream(): + # Handle different event types + if isinstance(event, VoiceStreamEventAudio): + # Play this audio to the user + play_audio(event.data) + elif isinstance(event, VoiceStreamEventToolCall): + # Log tool usage (execution is automatic) + log_tool_call(event.tool_name, event.arguments) + # Handle other event types... + +# Continuously send audio chunks to the pipeline +# There's no need to signal "end of audio" - the model handles turn-taking +while True: + audio_chunk = record_audio_chunk() + await input_stream.queue.put(audio_chunk) + + # If the application is closing, close the input + if stopping: + await input_stream.close() + break +``` + +### Using Shared Context with Tools + +The `RealtimeVoicePipeline` supports passing a shared context object to tools, allowing them to access and modify shared state across multiple interactions. This is useful for building more complex voice applications that need to maintain state, such as: + +- Tracking user preferences +- Maintaining conversation history +- Counting interactions +- Storing user information + +#### Setting up a shared context + +To use shared context with tools: + +1. Define a context class (typically a dataclass) to hold your application state +2. Create an instance of this class +3. Pass it to the `RealtimeVoicePipeline` using the `shared_context` parameter +4. Create tools that accept a `RunContextWrapper[YourContextType]` as their first parameter + +```python +from dataclasses import dataclass +from agents.run_context import RunContextWrapper +from agents.tool import function_tool + +# Define your context class +@dataclass +class MyAppContext: + """Context for the voice assistant.""" + user_name: str + interaction_count: int = 0 + +# Create a context-aware tool +@function_tool +def greet_user_and_count(context: RunContextWrapper[MyAppContext]) -> str: + """Greets the user by name and counts interactions.""" + # Access and modify the context + context.context.interaction_count += 1 + + return f"Hello {context.context.user_name}! This is interaction number {context.context.interaction_count}." 
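+
+# Note: `context.context` is the same MyAppContext instance that was passed to the
+# pipeline via `shared_context=...`, so the incremented count above persists
+# across later tool calls and turns.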
+ +# Create another context-aware tool +@function_tool +def get_user_details(context: RunContextWrapper[MyAppContext]) -> dict: + """Gets user details from the context.""" + return { + "user_name": context.context.user_name, + "interaction_count": context.context.interaction_count + } + +# Create your application context +app_context = MyAppContext(user_name="Alice", interaction_count=0) + +# Create the pipeline with shared context +pipeline = RealtimeVoicePipeline( + model=model, + tools=[get_weather, get_time, greet_user_and_count, get_user_details], + config=config, + shared_context=app_context, # Pass the context here +) +``` + +#### How it works + +1. The `RealtimeVoicePipeline` passes the shared context to its internal `ToolExecutor` +2. When the LLM calls a tool, the `ToolExecutor` checks if the tool's first parameter is named `context` +3. If it is, the executor wraps your context object in a `RunContextWrapper` and passes it to the tool +4. The tool can then access and modify your context object via `context.context` +5. Since all tools share the same context object, changes made by one tool are visible to other tools in future calls + +This mechanism allows your tools to maintain shared state across turns and interactions in your voice application, without needing to set up a separate state management system. + +#### Context-Aware vs. Standard Tools + +You can mix both context-aware and standard tools in the same `RealtimeVoicePipeline`: + +```python +# A standard tool (no context parameter) +@function_tool +def get_weather(city: str) -> dict: + """Gets the weather for the specified city.""" + return {"temperature": 72, "condition": "sunny"} + +# A context-aware tool (has context parameter) +@function_tool +def update_user_preference(context: RunContextWrapper[MyAppContext], preference: str, value: str) -> str: + """Updates a user preference in the context.""" + if not hasattr(context.context, "preferences"): + context.context.preferences = {} + context.context.preferences[preference] = value + return f"Updated {preference} to {value}" +``` + +**When to use standard tools:** + +- For stateless operations that don't need to remember information between calls +- For simple lookups or calculations based solely on the input parameters +- When integration with external APIs or services doesn't require user-specific state + +**When to use context-aware tools:** + +- When tools need to access or modify shared state +- For personalization features that adapt to the user +- To implement features that track usage or interactions +- When information gathered in one tool call needs to be available to another tool + +**Important notes:** + +- The first parameter of a context-aware tool must be named `context` and should have a type annotation of `RunContextWrapper[YourContextType]` +- Type hints are recommended but not required; the parameter name `context` is sufficient for the tool to be detected as context-aware +- The actual object inside `context.context` will be the instance you passed to `shared_context` when creating the pipeline +- All context-aware tools see the same context instance, so changes are immediately visible to all tools + +### Turn Detection Modes + +The realtime models can operate in different turn detection modes, controlled via the `turn_detection` setting: + +- `"server_vad"` (default): The server automatically detects when the user has stopped speaking using Voice Activity Detection and starts responding. 
+- `"manual"`: Your application explicitly signals when the user has finished speaking by calling `await llm_session.commit_audio_buffer()`. +- `None`: Same as `"server_vad"` - the server handles turn detection automatically. + +### Implementing Push-to-Talk + +In push-to-talk mode, the application sends audio only when the user activates a button or key: + +```python +# Start continuous silent audio (required for maintaining the connection) +async def send_continuous_audio(): + while True: + if push_to_talk_active: + # Send real audio when button is pressed + audio = get_microphone_audio() + else: + # Send silence when button is not pressed + audio = np.zeros(CHUNK_SIZE, dtype=np.int16) + + await input_stream.queue.put(audio) + await asyncio.sleep(CHUNK_DURATION) # Simulate real-time pacing + +# When user releases the push-to-talk button +async def on_push_to_talk_released(): + # Optional: For manual turn detection, commit the buffer + if turn_detection == "manual": + await llm_session.commit_audio_buffer() +``` + +### Event Handling + +When processing events from a `RealtimeVoicePipeline`, you'll handle these event types: + +- `VoiceStreamEventAudio`: Contains audio data from the LLM to play back to the user +- `VoiceStreamEventLifecycle`: Indicates session lifecycle events (e.g., "turn_started", "turn_ended", "session_ended") +- `VoiceStreamEventToolCall`: Provides information about tool calls being executed by the pipeline +- `VoiceStreamEventError`: Indicates an error condition + +### Key Differences & Important Notes + +- **Continuous Audio**: The realtime pipeline expects continuous audio input, not discrete turns ending with a `None` sentinel. Use `input_stream.close()` only when shutting down the pipeline entirely. +- **Event Types**: You'll receive `VoiceStreamEventToolCall` events for informational purposes when tools are used. The pipeline automatically executes registered tools and sends results back to the LLM - no action is needed from your application. +- **No Separate STT/TTS Events**: You will receive `VoiceStreamEventAudio` directly from the LLM. There are no separate events indicating STT transcription completion or explicit text-to-speech stages within this pipeline's event stream. +- **Configuration**: Real-time model specific settings (like assistant voice, system message, or turn detection mode) are passed via the `realtime_settings` dictionary within `VoicePipelineConfig`. +- **Audio Format**: The OpenAI realtime models currently require **16-bit PCM at a 24 kHz sample rate, mono, little-endian** for both _input_ and _output_ when you use the default `pcm16` format. Make sure your microphone capture (`StreamedAudioInput`) and speaker playback are configured for **24 kHz** to avoid chip-munk / slow-motion artefacts. + +```python +INPUT_SAMPLE_RATE = 24_000 # 24 kHz for mic capture +OUTPUT_SAMPLE_RATE = 24_000 # 24 kHz for TTS playback +``` + +Failing to match this sample-rate is the most common cause of distorted or "slow" audio. + +For complete working examples, see: + +- [`realtime_assistant.py`](https://github.com/openai/openai-agents-python/blob/main/examples/voice/realtime_assistant.py) - Basic example with simulated audio +- [`continuous_realtime_assistant.py`](https://github.com/openai/openai-agents-python/blob/main/examples/voice/continuous_realtime_assistant.py) - Example showing continuous streaming with push-to-talk simulation + +Note that these examples require approved access to the OpenAI `gpt-4o-realtime-preview` model. 
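+
+As a companion to the audio-format note above, the sketch below shows one way to open `sounddevice` capture and playback streams that match the required 24 kHz, mono, 16-bit PCM format. It mirrors the settings used in the example scripts; the chunk size and the loopback playback are purely illustrative:
+
+```python
+import sounddevice as sd
+
+SAMPLE_RATE = 24_000  # required by the realtime models for pcm16
+CHANNELS = 1          # mono
+CHUNK_S = 0.1         # 100 ms chunks
+
+mic = sd.InputStream(samplerate=SAMPLE_RATE, channels=CHANNELS, dtype="int16")
+speaker = sd.OutputStream(samplerate=SAMPLE_RATE, channels=CHANNELS, dtype="int16")
+mic.start()
+speaker.start()
+
+# Read one 100 ms chunk from the microphone and, for illustration, play it straight back.
+frames = int(SAMPLE_RATE * CHUNK_S)
+chunk, _overflowed = mic.read(frames)  # int16 array, shape (frames, CHANNELS)
+speaker.write(chunk)
+```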
+ +### New transcription events + +When you enable `input_audio_transcription` in the session configuration (the realtime pipeline does this automatically), the server can stream _your_ microphone audio back as text. Two new event types are surfaced by the SDK so you can inspect what the model thinks it heard: + +- `RealtimeEventInputAudioTranscriptionDelta` – incremental partial transcripts +- `RealtimeEventInputAudioTranscriptionCompleted` – the final transcript for that user turn + +```python +elif isinstance(event, RealtimeEventInputAudioTranscriptionDelta): + print("you (partial):", event.delta) +elif isinstance(event, RealtimeEventInputAudioTranscriptionCompleted): + print("you (final):", event.transcript) +``` + +These are invaluable for debugging cases where echo or background noise is being mis-interpreted by the model. + +### Echo & feedback mitigation + +If you hear the assistant repeatedly greeting you ("Hello again!") it usually means your microphone is re-capturing the speaker audio. Combine these techniques: + +1. Enable the built-in echo / noise suppression with + + ```python + realtime_settings={"input_audio_noise_reduction": {}} + ``` + +2. In push-to-talk interfaces, _pause_ mic streaming for ~300 ms after the last assistant audio chunk. See `ASSISTANT_AUDIO_SILENCE_BUFFER_S` in `continuous_realtime_assistant.py`. + +3. Use headphones for the cleanest experience. diff --git a/examples/voice/realtime_assistant.py b/examples/voice/realtime_assistant.py new file mode 100644 index 00000000..0e09b8c4 --- /dev/null +++ b/examples/voice/realtime_assistant.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python +""" +This example shows how to use the RealtimeVoicePipeline with continuous audio streaming +using a real microphone for input and speakers for output. + +Requirements: +1. OpenAI API key with APPROVED access to the gpt-4o-realtime-preview model +2. Python 3.10+ +3. Required packages: openai, numpy, sounddevice +4. A working microphone and speaker setup. + +Important Note: + Access to gpt-4o-realtime-preview requires special approval from OpenAI. + If you receive WebSocket connection closures with code 1000, it's likely that your + API key does not have approved access to the model yet. + + Visit https://platform.openai.com/docs/guides/realtime for more information + on applying for access to the realtime API. 
+ +Usage: + python realtime_assistant.py +""" + +import asyncio +import logging +import os +import time +from typing import Dict, Any +from dataclasses import dataclass + +import numpy as np +import sounddevice as sd # For microphone and speaker I/O + +from agents.voice import ( + RealtimeVoicePipeline, + StreamedAudioInput, + VoicePipelineConfig, + VoiceStreamEvent, + VoiceStreamEventLifecycle, + VoiceStreamEventToolCall, + VoiceStreamEventError, + VoiceStreamEventAudio, +) +from agents.tool import function_tool, Tool +from agents.voice.models.sdk_realtime import SDKRealtimeLLM +from agents.run_context import RunContextWrapper + +# Import the new event types from our SDK +from agents.voice.realtime.model import ( + RealtimeEventResponseDone, + RealtimeEventRateLimitsUpdated, + RealtimeEventInputAudioTranscriptionDelta, + RealtimeEventInputAudioTranscriptionCompleted, +) +import dotenv + +# Load environment variables +dotenv.load_dotenv() + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("realtime_assistant") + + +# Define a dataclass for our application context +@dataclass +class MyAppContext: + """A simple context for the realtime voice assistant example.""" + + user_name: str + interaction_count: int = 0 + + +# Define some sample tools +@function_tool +def get_weather(city: str) -> Dict[str, Any]: + """Gets the current weather for a given city.""" + logger.info(f"Getting weather for {city}") + return {"temperature": 22, "condition": "sunny", "humidity": 60} + + +@function_tool +def get_time(timezone: str = "UTC") -> Dict[str, Any]: + """Gets the current time in the specified timezone.""" + logger.info(f"Getting time for timezone {timezone}") + return {"time": time.strftime("%H:%M:%S", time.gmtime()), "timezone": timezone} + + +# Define a context-aware tool +@function_tool +def greet_user_and_count(context: RunContextWrapper[MyAppContext]) -> str: + """Greets the user by name and counts interactions.""" + logger.info(f"greet_user_and_count called with context: {context}") + # Increment the interaction count + context.context.interaction_count += 1 + + logger.info( + f"Greeting user: {context.context.user_name}, " + f"Interaction count: {context.context.interaction_count}" + ) + + return f"Hello {context.context.user_name}! This is interaction number {context.context.interaction_count}." + + +# Another context-aware tool that reads but doesn't modify the context +@function_tool +def get_user_details(context: RunContextWrapper[MyAppContext]) -> Dict[str, Any]: + """Gets the user's details from the context.""" + logger.info(f"get_user_details called with context: {context}") + + logger.info( + f"Returning user details: name={context.context.user_name}, count={context.context.interaction_count}" + ) + return { + "user_name": context.context.user_name, + "interaction_count": context.context.interaction_count, + } + + +# Get the OpenAI API key from environment variables +api_key = os.environ.get("OPENAI_API_KEY") +if not api_key: + logger.error("OPENAI_API_KEY environment variable is not set.") + logger.error( + "Please set your OpenAI API key with access to gpt-4o-realtime-preview model." 
+ ) + logger.error("You can get API keys from https://platform.openai.com/api-keys") + exit(1) +else: + # Show first and last 4 characters of the key for debugging + masked_key = f"{api_key[:8]}********************" + logger.info(f"Using OpenAI API key: {masked_key}") + + +# Audio settings +INPUT_SAMPLE_RATE = 24000 # OpenAI Realtime API expects 24kHz input for pcm16 +OUTPUT_SAMPLE_RATE = 24000 # OpenAI TTS audio is 24kHz for pcm16 +CHANNELS = 1 +INPUT_DTYPE = "int16" # Microphone input type +OUTPUT_DTYPE = np.int16 # Speaker output type, OpenAI sends int16 PCM +CHUNK_DURATION_S = 0.1 # Send audio in 100ms chunks +INPUT_CHUNK_SIZE = int(INPUT_SAMPLE_RATE * CHUNK_DURATION_S) + +# Buffer time (seconds) after last assistant audio chunk before we resume mic capture +ASSISTANT_AUDIO_SILENCE_BUFFER_S = 0.3 + + +async def main(): + logger.info("Initializing RealtimeVoicePipeline...") + + # Create the SDK-based OpenAI realtime model + model = SDKRealtimeLLM( + model_name="gpt-4o-realtime-preview", + api_key=api_key, + ) + + # Create an audio input and pipeline config with server-side VAD + config = VoicePipelineConfig( + realtime_settings={ + "turn_detection": "server_vad", # Use server-side VAD + "assistant_voice": "alloy", + "system_message": "You are a helpful assistant that responds concisely. You can use the greet_user_and_count tool to greet the user by name and the get_user_details tool to retrieve information about the user.", + # Enable server-side noise / echo reduction + "input_audio_noise_reduction": {}, + } + ) + input_stream = StreamedAudioInput() + + # Create our application context + app_context = MyAppContext(user_name="Anurag", interaction_count=0) + + # Create the realtime pipeline with shared context + pipeline = RealtimeVoicePipeline( + model=model, + tools=[get_weather, get_time, greet_user_and_count, get_user_details], + config=config, + shared_context=app_context, # Pass the context to the pipeline + ) + + # Track events and errors + event_count = 0 + error_occurred = False + should_continue_streaming = True # Controls the main audio streaming loop + + # This example simulates a "Push-to-Talk" interface + # The push_to_talk_active flag controls when audio is being sent + push_to_talk_active = asyncio.Event() # Use an asyncio.Event for push-to-talk state + + # Timestamp of the most recent assistant audio chunk that was played + last_assistant_audio_ts: float = 0.0 + + # Function to handle microphone input in a separate task + async def mic_input_loop(): + nonlocal should_continue_streaming, error_occurred + logger.info("Starting microphone input loop...") + try: + with sd.InputStream( + samplerate=INPUT_SAMPLE_RATE, channels=CHANNELS, dtype=INPUT_DTYPE + ) as mic: + logger.info( + f"Microphone opened. Default input device: {sd.query_devices(kind='input')['name']}" + ) + while should_continue_streaming: + # Wait until push-to-talk is active + await push_to_talk_active.wait() + + # Check if enough data is available to read our desired chunk size + if mic.read_available >= INPUT_CHUNK_SIZE: + data, overflowed = mic.read(INPUT_CHUNK_SIZE) + if overflowed: + logger.warning("Microphone input overflowed!") + + # Only forward to server if enough time has passed since assistant spoke + if ( + time.time() - last_assistant_audio_ts + ) >= ASSISTANT_AUDIO_SILENCE_BUFFER_S: + if data.size > 0: + logger.debug( + f"Forwarding {data.size} samples to server (PTT active)." 
+ ) + await input_stream.queue.put(data.astype(np.int16)) + else: + # Discard or optionally monitor but don't send + logger.debug( + "Discarding mic samples to avoid echo (within buffer window)." + ) + else: + # Not enough data yet, yield control briefly + await asyncio.sleep( + 0.001 + ) # Sleep for a very short duration (1ms) + + except sd.PortAudioError as pae: + logger.error(f"PortAudioError in microphone loop: {pae}") + logger.error( + "This might be due to no microphone being available or permissions issues." + ) + logger.error(f"Available input devices: {sd.query_devices(kind='input')}") + error_occurred = True + except Exception as e: + logger.error(f"Error in microphone input loop: {e}", exc_info=True) + error_occurred = True + finally: + logger.info("Microphone input loop ended.") + + # Run the pipeline + try: + # Initialize and start the speaker output stream + speaker_stream = sd.OutputStream( + samplerate=OUTPUT_SAMPLE_RATE, channels=CHANNELS, dtype=OUTPUT_DTYPE + ) + speaker_stream.start() + logger.info( + f"Speaker output stream started. Default output device: {sd.query_devices(kind='output')['name']}" + ) + + # Start the pipeline + result = await pipeline.run(input_stream) + logger.info("Pipeline started successfully. Listening for events...") + + # Start the microphone input task + mic_task = asyncio.create_task(mic_input_loop()) + + # Simulate push-to-talk actions with a timer + async def toggle_push_to_talk_simulation(): + nonlocal should_continue_streaming + await asyncio.sleep(2) # Initial delay before first interaction + try: + while should_continue_streaming: + logger.info( + "🎤 Press Enter to START simulated push-to-talk, or type 'q' to quit..." + ) + action = await asyncio.to_thread( + input, "" + ) # Non-blocking input for demo + if action.lower() == "q": + should_continue_streaming = False + push_to_talk_active.set() # Unblock mic loop if it was waiting + break + + logger.info("🎤 PUSH-TO-TALK: ON (Recording from microphone...)") + push_to_talk_active.set() # Signal mic_input_loop to send audio + + logger.info("🎤 Press Enter to STOP simulated push-to-talk...") + await asyncio.to_thread(input, "") + logger.info("🎤 PUSH-TO-TALK: OFF (Stopped recording)") + push_to_talk_active.clear() # Signal mic_input_loop to stop sending + + # Optional: Send a commit if manual VAD was used (not in this example) + # if config.realtime_settings.get("turn_detection") == "manual": + # await pipeline.commit_audio_buffer() # Assuming pipeline exposes this + + except asyncio.CancelledError: + logger.info("Push-to-talk simulation cancelled.") + finally: + logger.info("Push-to-talk simulation ended.") + push_to_talk_active.set() # Ensure mic loop can exit if it was waiting + + # Start the push-to-talk simulation + ptt_simulation_task = asyncio.create_task(toggle_push_to_talk_simulation()) + + # Process events from the pipeline + async for event in result.stream(): + if not should_continue_streaming: + break + event_count += 1 + + if isinstance(event, VoiceStreamEventLifecycle): + logger.info(f"Lifecycle event: {event.event}") + if event.event == "session_ended": + logger.info("Real-time session ended.") + should_continue_streaming = False + break + + elif isinstance(event, VoiceStreamEventAudio): + if event.data is not None and speaker_stream: + logger.info(f"Received audio: {len(event.data)} bytes. 
Playing...") + speaker_stream.write(event.data.astype(OUTPUT_DTYPE)) + # Update last audio timestamp for mic gating + last_assistant_audio_ts = time.time() + else: + logger.info( + "Received empty audio data or speaker stream not active." + ) + + elif isinstance(event, VoiceStreamEventToolCall): + logger.info(f"Tool call: {event.tool_name}({event.arguments})") + + elif isinstance(event, VoiceStreamEventError): + logger.error(f"Pipeline Error: {event.error}") + error_occurred = True + if "1000" in str(event.error): + logger.warning("WebSocket was closed normally (code 1000).") + logger.warning( + "This typically happens when your API key lacks access to the gpt-4o-realtime-preview model." + ) + should_continue_streaming = False + break + + # Handle newly defined RealtimeEvent types that are passed through StreamedRealtimeResult + elif isinstance(event, RealtimeEventResponseDone): + logger.info(f"Assistant response for item '{event.item_id}' is done.") + + elif isinstance(event, RealtimeEventRateLimitsUpdated): + logger.info(f"Rate limits updated by server: {event.data}") + + elif isinstance(event, RealtimeEventInputAudioTranscriptionDelta): + logger.info( + f"[TRANSCRIPTION DELTA] item {event.item_id} idx {event.content_index}: {event.delta}" + ) + + elif isinstance(event, RealtimeEventInputAudioTranscriptionCompleted): + logger.info( + f"[TRANSCRIPTION COMPLETED] item {event.item_id}: {event.transcript}" + ) + + else: + logger.info(f"Unknown Event: {event.type} - {event}") + + # Wait for the push-to-talk simulation to complete or be cancelled + if not ptt_simulation_task.done(): + ptt_simulation_task.cancel() + try: + await ptt_simulation_task + except asyncio.CancelledError: + pass + + logger.info(f"Total events processed: {event_count}") + + # Print the final interaction count from the context + logger.info(f"Final interaction count: {app_context.interaction_count}") + + # Provide troubleshooting information if needed + if error_occurred or event_count <= 1: # <=1 because turn_started is an event + logger.error(f"Error occurred: {error_occurred}") + + except KeyboardInterrupt: + logger.info("Keyboard interrupt detected, stopping...") + except Exception as e: + logger.error(f"Main loop error: {e}", exc_info=True) + finally: + logger.info("Cleaning up resources...") + should_continue_streaming = False # Signal all loops to stop + + if mic_task and not mic_task.done(): + push_to_talk_active.set() # Unblock mic_input_loop if waiting + mic_task.cancel() + try: + await mic_task + except asyncio.CancelledError: + pass + logger.info("Microphone task stopped.") + + if speaker_stream: + speaker_stream.stop() + speaker_stream.close() + logger.info("Speaker stream closed.") + + if input_stream and not input_stream.is_closed: + await input_stream.close() + logger.info("Audio input stream closed.") + + logger.info("Application shutdown complete.") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nExiting by user request.") diff --git a/examples/voice/static/main.py b/examples/voice/static/main.py index 1b9e2024..27c52e8a 100644 --- a/examples/voice/static/main.py +++ b/examples/voice/static/main.py @@ -12,7 +12,7 @@ VoicePipeline, ) -from .util import AudioPlayer, record_audio +from util import AudioPlayer, record_audio """ This is a simple example that uses a recorded audio buffer. 
Run it via: diff --git a/pyproject.toml b/pyproject.toml index 22b028ae..9ae06a4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ Homepage = "https://github.com/openai/openai-agents-python" Repository = "https://github.com/openai/openai-agents-python" [project.optional-dependencies] -voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"] +voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=11.0, <16"] viz = ["graphviz>=0.17"] litellm = ["litellm>=1.67.4.post1, <2"] diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index bbfe1885..5a494152 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -95,10 +95,14 @@ async def invoke_mcp_tool( f"Invalid JSON input for tool {tool.name}: {input_json}" ) from e + # NOT GOOD: + # https://github.com/modelcontextprotocol/modelcontextprotocol/discussions/234 + # adding the context as one more field in the input to tool, since dynamic context is not supported by mcp. + json_data["juspay_meta_info"] = context.context if _debug.DONT_LOG_TOOL_DATA: logger.debug(f"Invoking MCP tool {tool.name}") else: - logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}") + logger.debug(f"Invoking Juspay MCP tool {tool.name} with input {json_data}") try: result = await server.call_tool(tool.name, json_data) diff --git a/src/agents/voice/__init__.py b/src/agents/voice/__init__.py index e11ee446..36704fbb 100644 --- a/src/agents/voice/__init__.py +++ b/src/agents/voice/__init__.py @@ -1,4 +1,10 @@ -from .events import VoiceStreamEvent, VoiceStreamEventAudio, VoiceStreamEventLifecycle +from .events import ( + VoiceStreamEvent, + VoiceStreamEventAudio, + VoiceStreamEventLifecycle, + VoiceStreamEventToolCall, + VoiceStreamEventError, +) from .exceptions import STTWebsocketConnectionError from .input import AudioInput, StreamedAudioInput from .model import ( @@ -15,7 +21,21 @@ from .models.openai_tts import OpenAITTSModel from .pipeline import VoicePipeline from .pipeline_config import VoicePipelineConfig +from .pipeline_realtime import RealtimeVoicePipeline from .result import StreamedAudioResult +from .result_realtime import StreamedRealtimeResult +from .realtime.model import ( + RealtimeLLMModel, + RealtimeSession, + RealtimeEvent, + RealtimeEventSessionBegins, + RealtimeEventAudioChunk, + RealtimeEventTextDelta, + RealtimeEventToolCall, + RealtimeEventSessionEnds, + RealtimeEventError, +) +from .models.sdk_realtime import SDKRealtimeLLM from .utils import get_sentence_based_splitter from .workflow import ( SingleAgentVoiceWorkflow, @@ -50,4 +70,18 @@ "StreamedTranscriptionSession", "OpenAISTTTranscriptionSession", "STTWebsocketConnectionError", + "VoiceStreamEventError", + "RealtimeVoicePipeline", + "StreamedRealtimeResult", + "RealtimeLLMModel", + "RealtimeSession", + "RealtimeEvent", + "RealtimeEventSessionBegins", + "RealtimeEventAudioChunk", + "RealtimeEventTextDelta", + "RealtimeEventToolCall", + "RealtimeEventSessionEnds", + "RealtimeEventError", + "VoiceStreamEventToolCall", + "SDKRealtimeLLM", ] diff --git a/src/agents/voice/events.py b/src/agents/voice/events.py index bdcd0815..920d7480 100644 --- a/src/agents/voice/events.py +++ b/src/agents/voice/events.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Literal, Union +from typing import Literal, Union, Any from typing_extensions import TypeAlias @@ -41,7 +41,26 @@ class VoiceStreamEventError: """The type of event.""" +@dataclass 
+class VoiceStreamEventToolCall: + """Streaming event indicating a tool call from a real-time pipeline.""" + + tool_call_id: str + """The unique ID for this tool call instance.""" + + tool_name: str + """The name of the tool to be called.""" + + arguments: dict[str, Any] + """The arguments for the tool, as a dictionary.""" + + type: Literal["voice_stream_event_tool_call"] = "voice_stream_event_tool_call" + + VoiceStreamEvent: TypeAlias = Union[ - VoiceStreamEventAudio, VoiceStreamEventLifecycle, VoiceStreamEventError + VoiceStreamEventAudio, + VoiceStreamEventLifecycle, + VoiceStreamEventError, + VoiceStreamEventToolCall, ] """An event from the `VoicePipeline`, streamed via `StreamedAudioResult.stream()`.""" diff --git a/src/agents/voice/input.py b/src/agents/voice/input.py index 8613d27a..79212bb6 100644 --- a/src/agents/voice/input.py +++ b/src/agents/voice/input.py @@ -5,6 +5,7 @@ import io import wave from dataclasses import dataclass +from typing import List from ..exceptions import UserError from .imports import np, npt @@ -57,7 +58,9 @@ class AudioInput: def to_audio_file(self) -> tuple[str, io.BytesIO, str]: """Returns a tuple of (filename, bytes, content_type)""" - return _buffer_to_audio_file(self.buffer, self.frame_rate, self.sample_width, self.channels) + return _buffer_to_audio_file( + self.buffer, self.frame_rate, self.sample_width, self.channels + ) def to_base64(self) -> str: """Returns the audio data as a base64 encoded string.""" @@ -71,18 +74,86 @@ def to_base64(self) -> str: return base64.b64encode(self.buffer.tobytes()).decode("utf-8") -class StreamedAudioInput: - """Audio input represented as a stream of audio data. You can pass this to the `VoicePipeline` - and then push audio data into the queue using the `add_audio` method. +class StreamedAudioInput(AudioInput): + """An audio input that can be added to over time. + + This class is useful for continuous audio input, such as from a microphone. """ - def __init__(self): - self.queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] = asyncio.Queue() + queue: asyncio.Queue + is_closed: bool + + def __init__(self) -> None: + """Initialize a new StreamedAudioInput.""" + self.queue = asyncio.Queue() + self.is_closed = False - async def add_audio(self, audio: npt.NDArray[np.int16 | np.float32]): - """Adds more audio data to the stream. + async def add_audio( + self, audio: npt.NDArray[npt.np.int16 | npt.np.float32] | None + ) -> None: + """Add audio to the input. Args: - audio: The audio data to add. Must be a numpy array of int16 or float32. + audio: The audio data to add. This can be a numpy array of int16 or float32 values. + NOTE: Passing None is deprecated and will be removed in a future version. + Use close() to signal the end of the stream instead. """ + if audio is None: + # Backwards compatibility: None was previously used as a sentinel + # to signal the end of the stream. Log a deprecation warning and call close() + # for backwards compatibility. + import warnings + + warnings.warn( + "Passing None to add_audio() is deprecated. Use close() to signal the end of the stream.", + DeprecationWarning, + stacklevel=2, + ) + await self.close() + return + + if self.is_closed: + raise ValueError("Cannot add audio to a closed StreamedAudioInput") + await self.queue.put(audio) + + async def close(self) -> None: + """Close the audio input stream. + + This signals that no more audio will be added and allows consumers to clean up resources. + After calling close(), add_audio() will raise an error. 
+ """ + if not self.is_closed: + self.is_closed = True + # Put a sentinel None value for backwards compatibility + # TODO: Remove this in a future version when all consumers are updated + await self.queue.put(None) + + async def read_audio( + self, + ) -> List[npt.NDArray[npt.np.int16 | npt.np.float32]]: + """Read all audio from the input. + + Returns: + A list of numpy arrays containing the audio data. + """ + # Drain the queue to get all audio + result = [] + try: + while True: + item = self.queue.get_nowait() + if item is None: # Skip the sentinel value + continue + result.append(item) + self.queue.task_done() + except asyncio.QueueEmpty: + pass + + # If the queue is empty and no audio was found, wait for one item + if not result: + item = await self.queue.get() + if item is not None: # Skip the sentinel value + result.append(item) + self.queue.task_done() + + return result diff --git a/src/agents/voice/models/__init__.py b/src/agents/voice/models/__init__.py index e69de29b..90124277 100644 --- a/src/agents/voice/models/__init__.py +++ b/src/agents/voice/models/__init__.py @@ -0,0 +1,31 @@ +"""Voice model implementations (STT, TTS and realtime) bundled with the SDK. + +This file merely re-exports the concrete classes so downstream code can do: + +```python +from agents.voice.models import OpenAISTTModel, OpenAITTSModel +``` + +Historically this module existed; it was accidentally deleted during the +realtime-pipeline work. Restoring it keeps the public import surface +stable for existing applications that rely on the old path. +""" + +from __future__ import annotations + +__all__: list[str] = [ + # STT / TTS + "OpenAISTTModel", + "OpenAISTTTranscriptionSession", + "OpenAITTSModel", + # Realtime models + "OpenAIRealtimeLLM", + "SDKRealtimeLLM", +] + +# Speech-to-text / text-to-speech +from .openai_stt import OpenAISTTModel, OpenAISTTTranscriptionSession # noqa: E402 +from .openai_tts import OpenAITTSModel # noqa: E402 + +# Realtime voice LLMs +from .sdk_realtime import SDKRealtimeLLM # noqa: E402 diff --git a/src/agents/voice/models/sdk_realtime.py b/src/agents/voice/models/sdk_realtime.py new file mode 100644 index 00000000..0b647fe3 --- /dev/null +++ b/src/agents/voice/models/sdk_realtime.py @@ -0,0 +1,553 @@ +from __future__ import annotations + +import asyncio +import base64 +import json +import logging +from collections.abc import AsyncIterator, Sequence +from typing import Any, Dict, Optional, cast + +import numpy as np +import numpy.typing as npt +from openai import AsyncOpenAI +from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection + +from ..realtime.model import ( + RealtimeEvent, + RealtimeEventAudioChunk, + RealtimeEventError, + RealtimeEventSessionBegins, + RealtimeEventSessionEnds, + RealtimeEventTextDelta, + RealtimeEventToolCall, + RealtimeLLMModel, + RealtimeSession, + RealtimeEventResponseDone, + RealtimeEventRateLimitsUpdated, + RealtimeEventInputAudioTranscriptionDelta, + RealtimeEventInputAudioTranscriptionCompleted, +) +from ...exceptions import AgentsException, UserError +from ...logger import logger +from ...tool import Tool, FunctionTool + + +class SDKRealtimeSession(RealtimeSession): + """ + SDK-based implementation of RealtimeSession that uses the official OpenAI SDK. 
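+
+    The session wraps an externally managed AsyncRealtimeConnection: it runs a
+    background receiver task that translates SDK events into RealtimeEvent objects,
+    and close() only stops that task; the connection itself is owned by the pipeline.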
+ """ + + _connection: AsyncRealtimeConnection + _event_queue: asyncio.Queue[RealtimeEvent | None] + _tools_by_name: dict[str, Tool] + _session_id: str | None = None + _is_connected: bool = False + _stop_event: asyncio.Event + _receiver_task: asyncio.Task | None = None + _accumulating_tool_args: dict[str, dict[str, str]] + + def __init__( + self, + connection: AsyncRealtimeConnection, + tools: Sequence[Tool], + ): + self._connection = connection + self._event_queue = asyncio.Queue() + self._tools_by_name = {tool.name: tool for tool in tools} + self._is_connected = True + self._stop_event = asyncio.Event() + self._receiver_task = None + self._accumulating_tool_args = {} + + async def start_receiver(self) -> None: + """Start the background receiver loop task.""" + if self._receiver_task is None: + self._receiver_task = asyncio.create_task(self._receiver_loop()) + + async def _receiver_loop(self) -> None: + """Process events from the SDK connection.""" + try: + self._is_connected = True + logger.info("Starting SDK receiver loop...") + + # Process the events from the connection + async for event in self._connection: + try: + logger.debug(f"Received SDK event type: {event.type}") + + if event.type == "session.created": + self._session_id = event.session.id + logger.info(f"Session created: {self._session_id}") + await self._event_queue.put( + RealtimeEventSessionBegins(session_id=self._session_id) + ) + + elif event.type == "session.updated": + logger.info(f"Session updated: {event}") + # Update our session settings if needed + + elif event.type == "response.audio.delta": + # Handle audio delta (base64 encoded) + audio_bytes = base64.b64decode(event.delta) + logger.debug( + f"Received audio delta, size: {len(audio_bytes)} bytes" + ) + await self._event_queue.put( + RealtimeEventAudioChunk(data=audio_bytes) + ) + + elif ( + event.type == "response.text.delta" + or event.type == "response.audio_transcript.delta" + or event.type + == "conversation.item.input_audio_transcription.delta" + ): + # Handle text delta + logger.debug(f"Received text delta: {event.delta}") + await self._event_queue.put( + RealtimeEventTextDelta(delta=event.delta) + ) + + elif ( + event.type + == "conversation.item.input_audio_transcription.delta" + ): + logger.info( + f"Received input audio transcription delta: {event}" + ) + # Incremental transcription for input audio + item_id = getattr(event, "item_id", None) + content_index = getattr(event, "content_index", 0) + delta = getattr(event, "delta", "") + await self._event_queue.put( + RealtimeEventInputAudioTranscriptionDelta( + item_id=item_id, + content_index=content_index, + delta=delta, + ) + ) + + elif ( + event.type + == "conversation.item.input_audio_transcription.completed" + ): + logger.info( + f"Received input audio transcription completed: {event}" + ) + # Completed transcription for input audio + item_id = getattr(event, "item_id", None) + content_index = getattr(event, "content_index", 0) + transcript = getattr(event, "transcript", "") + await self._event_queue.put( + RealtimeEventInputAudioTranscriptionCompleted( + item_id=item_id, + content_index=content_index, + transcript=transcript, + ) + ) + + elif ( + event.type == "response.output_item.added" + and getattr(event.item, "type", None) == "function_call" + ): + item_id = event.item.id + tool_name = event.item.name + server_call_id = getattr(event.item, "call_id", None) + if item_id and tool_name and server_call_id: + self._accumulating_tool_args[item_id] = { + "server_call_id": server_call_id, + "name": 
tool_name, + "args_str": "", + } + logger.info( + f"Starting to accumulate args for tool call: {tool_name} (item_id: {item_id}, server_call_id: {server_call_id})" + ) + else: + logger.warning( + f"Received function_call item without full details: item_id={item_id}, tool_name={tool_name}, server_call_id={server_call_id}" + ) + + elif event.type == "response.function_call_arguments.delta": + item_id = getattr(event, "item_id", None) + delta = getattr(event, "delta", "") + if ( + item_id + and item_id in self._accumulating_tool_args + and delta + ): + self._accumulating_tool_args[item_id]["args_str"] += delta + logger.debug( + f"Accumulating args for item {item_id}: partial_args='{self._accumulating_tool_args[item_id]['args_str']}'" + ) + + elif ( + event.type == "response.output_item.done" + and getattr(event.item, "type", None) == "function_call" + ): + item_id = event.item.id + if item_id and item_id in self._accumulating_tool_args: + tool_name = self._accumulating_tool_args[item_id]["name"] + args_str = self._accumulating_tool_args[item_id]["args_str"] + server_call_id = self._accumulating_tool_args[item_id][ + "server_call_id" + ] + logger.info( + f"Completed accumulating args for tool: {tool_name} (item_id: {item_id}, server_call_id: {server_call_id}), args_str: '{args_str}'" + ) + try: + arguments = json.loads(args_str) if args_str else {} + except json.JSONDecodeError: + logger.error( + f"JSONDecodeError for tool {tool_name} args: {args_str}" + ) + arguments = {} + + await self._event_queue.put( + RealtimeEventToolCall( + tool_call_id=server_call_id, + tool_name=tool_name, + arguments=arguments, + ) + ) + del self._accumulating_tool_args[item_id] + else: + logger.warning( + f"Received output_item.done for function_call but no accumulating args for item_id: {item_id}" + ) + + elif event.type == "tool.calls": + logger.info( + "Received 'tool.calls' event (may be separate from arg streaming)" + ) + for tool_call_sdk in event.tool_calls: + tool_id = ( + tool_call_sdk.id + ) # This is the server's call_id for the tool invocation + function = tool_call_sdk.function + tool_name = function.name + try: + arguments = json.loads( + function.arguments + ) # Expects fully formed JSON string + except json.JSONDecodeError: + arguments = {} + logger.info( + f" Processing from tool.calls: {tool_name} (id: {tool_id})" + ) + await self._event_queue.put( + RealtimeEventToolCall( + tool_call_id=tool_id, + tool_name=tool_name, + arguments=arguments, + ) + ) + + elif event.type == "session.ends": + # Handle session end + reason = getattr(event, "reason", "unknown") + logger.info(f"Session ended: {reason}") + await self._event_queue.put( + RealtimeEventSessionEnds(reason=reason) + ) + break # Exit receiver loop when session ends + + # Handle new specific events + elif event.type == "response.done": + item_id = getattr(event, "item_id", None) + logger.info(f"Response done for item_id: {item_id}") + await self._event_queue.put( + RealtimeEventResponseDone(item_id=item_id) + ) + + elif event.type == "rate_limits.updated": + # The SDK event for rate_limits.updated should have a 'data' field or similar + # For now, let's assume the event object itself contains the relevant data or can be serialized. 
+ # The actual structure might be event.rate_limits or event.data.rate_limits + # We will pass the raw event or its relevant part to our RealtimeEvent + logger.info( + f"Rate limits updated: {event}" + ) # Log the whole event for now + try: + # Attempt to get a serializable representation for the event data + event_data_for_our_event = event.model_dump() + except AttributeError: + event_data_for_our_event = str( + event + ) # Fallback to string representation + await self._event_queue.put( + RealtimeEventRateLimitsUpdated( + data=event_data_for_our_event + ) + ) + + # Log other VAD-related and lifecycle events without putting them on main queue for now + elif event.type in [ + "input_audio_buffer.speech_started", + "input_audio_buffer.speech_stopped", + "input_audio_buffer.committed", + "conversation.item.created", + "response.created", + "response.output_item.added", # Might be useful if we track item lifecycles + "response.content_part.added", + "response.audio.done", # Specific part done, response.done is more holistic + "response.audio_transcript.done", + "response.content_part.done", + "response.output_item.done", + ]: + logger.info( + f"Received informational SDK event: {event.type} - {event}" + ) + + elif event.type == "error": + # Handle error event + error_message = getattr(event, "message", "Unknown error") + error_code = getattr(event, "code", None) + logger.error( + f"Error event: {error_message} (code: {error_code})" + ) + # Log the raw event data for more details + try: + raw_error_data = event.model_dump_json(indent=2) + logger.error(f"Raw error event data:\n{raw_error_data}") + except Exception as dump_err: + logger.error( + f"Could not dump raw error event data: {dump_err}" + ) + + await self._event_queue.put( + RealtimeEventError( + message=error_message, + code=error_code, + ) + ) + + else: + # Unknown event type + logger.warning(f"Unhandled event type: {event.type}") + + except asyncio.CancelledError: + logger.info("Receiver task cancelled") + break + + except Exception as e: + logger.error(f"Error processing event: {e}", exc_info=True) + await self._event_queue.put( + RealtimeEventError(message=f"Event processing error: {str(e)}") + ) + + # Check if we should stop + if self._stop_event.is_set(): + logger.info("Stop event set, exiting receiver loop") + break + + except asyncio.CancelledError: + logger.info("Receiver task cancelled") + + except Exception as e: + logger.error(f"Error in receiver loop: {e}", exc_info=True) + await self._event_queue.put( + RealtimeEventError(message=f"Receiver loop error: {str(e)}") + ) + + finally: + self._is_connected = False + await self._event_queue.put(None) # Signal end of events + logger.info("Receiver loop ended") + + async def send_audio_chunk(self, pcm_audio: npt.NDArray[np.int16]) -> None: + """Send an audio chunk to the realtime connection.""" + if not self._is_connected: + raise AgentsException("Connection is closed") + + try: + # Ensure the audio is the right format (int16) + if pcm_audio.dtype != np.int16: + raise UserError("Audio data must be np.int16 for RealtimeSession") + + # Convert to bytes and base64 encode + audio_bytes = pcm_audio.tobytes() + audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") + + # Send to the connection + await self._connection.input_audio_buffer.append(audio=audio_base64) + logger.debug(f"Sent audio chunk, size: {len(audio_bytes)} bytes") + + except Exception as e: + logger.error(f"Failed to send audio chunk: {e}", exc_info=True) + raise AgentsException(f"Failed to send audio chunk: {e}") 
from e + + async def send_tool_result(self, tool_call_id: str, content: str) -> None: + """Send a tool result to the realtime connection by creating a new conversation item.""" + if not self._is_connected: + raise AgentsException("Connection is closed") + + try: + # Construct the payload for a conversation.item.create event for a tool result + # The exact item.type for tool output needs to be confirmed from OpenAI docs + # Assuming "tool_output" or "tool_result_output" for now. + # Let's try "tool_output" as a common convention. + tool_result_payload = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": tool_call_id, + "output": content, + }, + } + await self._connection.send(tool_result_payload) + # Need to add a response.create event to trigger model output after tool call is completed. Ref: https://community.openai.com/t/realtime-api-tool-calling-problems-no-response-when-a-tool-is-included-in-the-session/966495/27 + await self._connection.send({"type": "response.create"}) + logger.info( + f"Sent tool result via conversation.item.create for call_id: {tool_call_id}" + ) + + except Exception as e: + logger.error(f"Failed to send tool result: {e}", exc_info=True) + raise AgentsException(f"Failed to send tool result: {e}") from e + + async def receive_events(self) -> AsyncIterator[RealtimeEvent]: + """Receive events from the session.""" + if self._receiver_task is None: + await self.start_receiver() + + while True: + event = await self._event_queue.get() + if event is None: # End of stream marker + break + yield event + + async def commit_audio_buffer(self) -> None: + """Commit the audio buffer when using manual turn detection.""" + try: + await self._connection.input_audio_buffer.commit() + logger.info("Committed audio buffer") + except Exception as e: + logger.error(f"Failed to commit audio buffer: {e}", exc_info=True) + raise AgentsException(f"Failed to commit audio buffer: {e}") from e + + async def cancel_response(self) -> None: + """Cancel the current response (for barge-in/interruptions).""" + try: + await self._connection.send({"type": "response.cancel"}) + logger.info("Cancelled response") + except Exception as e: + logger.error(f"Failed to cancel response: {e}", exc_info=True) + raise AgentsException(f"Failed to cancel response: {e}") from e + + async def close(self) -> None: + """Close the session logic, without closing the externally managed connection.""" + logger.info("Closing SDKRealtimeSession (receiver loop and tasks only).") + self._stop_event.set() + self._is_connected = False + + if self._receiver_task and not self._receiver_task.done(): + self._receiver_task.cancel() + try: + await self._receiver_task + except asyncio.CancelledError: + logger.info("SDKRealtimeSession receiver task cancelled during close.") + except Exception as e: + logger.error( + f"Error cancelling SDKRealtimeSession receiver task: {e}", + exc_info=True, + ) + + # Do NOT close self._connection here, it is managed by the pipeline's `async with` block. 
+ logger.info("SDKRealtimeSession internal cleanup complete.") + + +class SDKRealtimeLLM(RealtimeLLMModel): + """RealtimeLLMModel implementation using the official OpenAI SDK.""" + + def __init__( + self, + model_name: str = "gpt-4o-realtime-preview", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + organization: Optional[str] = None, + ): + self._model = model_name + self._api_key = api_key + self._base_url = base_url + self._organization = organization + + @property + def model_name(self) -> str: + """Returns the name of the model.""" + return self._model + + async def create_session( + self, + connection: AsyncRealtimeConnection, + *, + tools: Sequence[Tool] = (), + system_message: str | None = None, + assistant_voice: str | None = None, + turn_detection: str | None = "server_vad", + ) -> RealtimeSession: + """Configure an existing SDK connection as a RealtimeSession.""" + try: + session = SDKRealtimeSession(connection, tools) + # Receiver loop should be started by the session itself if not already + # or by the pipeline orchestrator after session creation. + # For now, let's assume SDKRealtimeSession.start_receiver() is called after this. + + session_config_updates = {} + if turn_detection == "server_vad": + session_config_updates["turn_detection"] = {"type": "server_vad"} + elif turn_detection == "manual": + session_config_updates["turn_detection"] = None + + if tools: + # Manually format tools for the realtime API + realtime_tools = [] + for tool in tools: + if isinstance(tool, FunctionTool): + realtime_tools.append( + { + "type": "function", + "name": tool.name, + "description": tool.description, + "parameters": tool.params_json_schema, + # Note: Realtime API might not support 'strict' yet + } + ) + else: + # Log a warning for non-function tools if any + logger.warning( + f"Skipping non-FunctionTool '{getattr(tool, 'name', 'unknown')}' for realtime session." 
+ ) + + if realtime_tools: + session_config_updates["tools"] = realtime_tools + + if system_message: + session_config_updates["instructions"] = system_message + if assistant_voice: + session_config_updates["voice"] = assistant_voice + + # enable input audio transcription + session_config_updates["input_audio_transcription"] = { + "language": "en", + "model": "gpt-4o-transcribe", + } + + # enable input audio noise reduction + session_config_updates["input_audio_noise_reduction"] = { + "type": "near_field", + } + + if session_config_updates: + await connection.session.update( + session=cast(Any, session_config_updates) + ) + + logger.info(f"SDKRealtimeSession configured for model: {self.model_name}") + await session.start_receiver() + return session + + except Exception as e: + logger.error(f"Failed to configure SDKRealtimeSession: {e}", exc_info=True) + raise AgentsException( + f"Failed to configure SDKRealtimeSession: {str(e)}" + ) from e diff --git a/src/agents/voice/pipeline_config.py b/src/agents/voice/pipeline_config.py index a4871612..2722d8e5 100644 --- a/src/agents/voice/pipeline_config.py +++ b/src/agents/voice/pipeline_config.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any +from typing import Any, Dict from ..tracing.util import gen_group_id from .model import STTModelSettings, TTSModelSettings, VoiceModelProvider @@ -44,3 +44,83 @@ class VoicePipelineConfig: tts_settings: TTSModelSettings = field(default_factory=TTSModelSettings) """The settings to use for the TTS model.""" + + # Settings for the new real-time pipeline + realtime_settings: dict[str, Any] = field(default_factory=dict) + """ + Settings specific to the RealtimeVoicePipeline. Can include things like: + - "system_message": str + - "assistant_voice": str (e.g., "alloy") + - Future client-side VAD thresholds or other model-specific params. + Example: {"system_message": "You are a helpful assistant.", "assistant_voice": "echo"} + """ + + def __init__( + self, + *, + trace_id: str | None = None, + trace_config: Dict[str, Any] | None = None, + upload_audio_data: bool = False, + trace_disable: bool = False, + model_provider: VoiceModelProvider | None = None, + tts_settings: Dict[str, Any] | None = None, + stt_settings: Dict[str, Any] | None = None, + realtime_settings: Dict[str, Any] | None = None, + group_id: str | None = None, + ): + """Initialize the voice pipeline configuration. + + Args: + trace_id: The trace ID to use for the trace. If None, one will be generated. + trace_config: The trace configuration to use for the trace. + upload_audio_data: Whether to upload audio data to the trace. + trace_disable: Whether to disable tracing. + model_provider: The model provider to use for the pipeline. + tts_settings: The settings to use for TTS models. + stt_settings: The settings to use for STT models. + realtime_settings: The settings to use for realtime models. + group_id: The grouping identifier to use for tracing. 
+ """ + self.trace_id = trace_id + self.trace_config = trace_config + self.upload_audio_data = upload_audio_data + self.trace_disable = trace_disable + self.group_id = group_id or gen_group_id() + + # Normalise STT / TTS settings into their dataclass types + if isinstance(stt_settings, STTModelSettings): + self.stt_settings = stt_settings + else: # dict or None + self.stt_settings = STTModelSettings(**(stt_settings or {})) + + if isinstance(tts_settings, TTSModelSettings): + self.tts_settings = tts_settings + else: + self.tts_settings = TTSModelSettings(**(tts_settings or {})) + + # Initialize realtime settings with defaults if not provided + self._realtime_settings = { + # Default to server-side voice activity detection (VAD) + "turn_detection": "server_vad", + # Default voice for audio responses + "assistant_voice": "alloy", + } + + # Update with any provided settings + if realtime_settings: + self._realtime_settings.update(realtime_settings) + + # Backwards-compat: honour trace_disable flag but keep original attribute name + self.tracing_disabled = trace_disable + + # Ensure model_provider is always a valid provider instance + self.model_provider = model_provider or OpenAIVoiceModelProvider() + + @property + def realtime_settings(self) -> Dict[str, Any]: + """Get the realtime configuration settings. + + Returns: + A dictionary of realtime settings. + """ + return self._realtime_settings diff --git a/src/agents/voice/pipeline_realtime.py b/src/agents/voice/pipeline_realtime.py new file mode 100644 index 00000000..87508a25 --- /dev/null +++ b/src/agents/voice/pipeline_realtime.py @@ -0,0 +1,423 @@ +from __future__ import annotations + +import asyncio +import json +from collections.abc import Sequence +from typing import Any, List, Optional, Set, Tuple, Union, cast +import os +import numpy as np + +from ..exceptions import AgentsException, UserError +from .imports import npt +from ..logger import logger +from ..tool import Tool +from .input import StreamedAudioInput +from .pipeline_config import VoicePipelineConfig +from .result_realtime import StreamedRealtimeResult +from .realtime.model import ( + RealtimeEventError as LLMErrorEvent, + RealtimeEventToolCall as LLMToolCallEvent, + RealtimeEventSessionEnds, + RealtimeLLMModel, + RealtimeSession, +) +from .realtime.tool_exec import ToolExecutor + +# Import the new SDK-based implementation +from .models.sdk_realtime import SDKRealtimeLLM, SDKRealtimeSession +from openai import AsyncOpenAI + + +class RealtimeVoicePipeline: + """A voice agent pipeline for real-time, bidirectional audio and tool interaction with an LLM.""" + + def __init__( + self, + *, + model: RealtimeLLMModel | str | None = None, + tools: Sequence[Tool] = (), + config: VoicePipelineConfig | None = None, + shared_context: Any | None = None, + ): + """Create a new real-time voice pipeline. + + Args: + model: The real-time LLM model to use. Can be an instance of RealtimeLLMModel + or a string identifier for a model from the provider. + tools: A sequence of tools available to the LLM. + config: The pipeline configuration. If not provided, a default will be used. + shared_context: An optional context object that will be passed to tools when they are executed. 
+ """ + if isinstance(model, str) or model is None: + self._model_name_to_load: str | None = model + self._model_instance: RealtimeLLMModel | None = None + elif isinstance(model, RealtimeLLMModel): + self._model_instance = model + self._model_name_to_load = None + else: + raise UserError( + f"Invalid type for model: {type(model)}. Expected RealtimeLLMModel or str." + ) + + self._tools = tools + self._config = config or VoicePipelineConfig() + self._shared_context = shared_context + self._tool_executor = ToolExecutor(tools, shared_context=shared_context) + + def _get_model(self) -> RealtimeLLMModel: + """Get the real-time LLM model to use.""" + if self._model_instance is None: + self._model_instance = get_realtime_llm_model( + self._model_name_to_load, self._config + ) + if self._model_instance is None: + raise AgentsException( + f"Failed to load real-time LLM model: {self._model_name_to_load or 'default'}" + ) + return self._model_instance + + async def _pump_audio_to_llm( + self, + audio_input: StreamedAudioInput, + llm_session: RealtimeSession, + ) -> None: + """Coroutine to continuously read from audio_input and send to LLM session. + + This method will continue pumping audio chunks until cancelled or the audio_input + is closed. It does not use a None sentinel to detect the end of audio. + """ + try: + # Start an infinite loop to process audio chunks as they arrive + while True: + try: + # Get the next audio chunk, this will block until a chunk is available + audio_chunk = await audio_input.queue.get() + + # Skip the None sentinel for backward compatibility + if audio_chunk is None: + logger.debug( + "Received None sentinel value (deprecated), ignoring" + ) + audio_input.queue.task_done() + continue + + # Check if we need to convert the audio format + if audio_chunk.dtype == np.float32: + audio_chunk_int16 = (audio_chunk * 32767).astype(np.int16) + elif audio_chunk.dtype == np.int16: + audio_chunk_int16 = audio_chunk + else: + logger.error( + f"Unsupported audio chunk dtype: {audio_chunk.dtype}" + ) + raise ValueError( + f"Unsupported audio chunk dtype: {audio_chunk.dtype}" + ) + + # Send the audio chunk to the LLM session + await llm_session.send_audio_chunk(audio_chunk_int16) + + # Mark the queue task as done + audio_input.queue.task_done() + + except asyncio.CancelledError: + # If we're cancelled, break out of the loop + logger.info("Audio pump task cancelled") + break + + except Exception as e: + # Log any errors but continue the loop to process the next chunk + logger.error(f"Error sending audio chunk: {e}", exc_info=True) + audio_input.queue.task_done() + + # Check if the input is closed + if audio_input.is_closed: + logger.info("Audio input closed, stopping audio pump") + break + + except asyncio.CancelledError: + logger.info("Audio pump task cancelled") + except Exception as e: + logger.error(f"Error in audio pump task: {e}", exc_info=True) + finally: + # Ensure we mark the queue as done for any pending tasks + try: + # Try to get any pending items from the queue and mark them as done + while True: + try: + audio_input.queue.get_nowait() + audio_input.queue.task_done() + except asyncio.QueueEmpty: + break + except Exception: + pass + + async def _handle_tool_call( + self, + event: LLMToolCallEvent, + llm_session: RealtimeSession, + result: StreamedRealtimeResult, + audio_input_queue: asyncio.Queue, + ) -> None: + """Execute a tool call and send the result back to the LLM.""" + try: + logger.info( + f"Handling tool call: {event.tool_name} (ID: {event.tool_call_id})" + ) + # Add 
the tool call event to the result stream + await result.push_llm_event(event) + + # Execute the tool and get the result + tool_output_content = await self._tool_executor.execute(event) + + # Send the result back to the LLM + await llm_session.send_tool_result(event.tool_call_id, tool_output_content) + logger.info( + f"Tool call {event.tool_name} (ID: {event.tool_call_id}) result sent." + ) + + except asyncio.CancelledError: + logger.info(f"Tool call handler for {event.tool_name} cancelled.") + except Exception as e: + logger.error( + f"Error handling tool call {event.tool_name} (ID: {event.tool_call_id}): {e}", + exc_info=True, + ) + # Try to send an error result back to the LLM + error_content = json.dumps({"error": str(e), "tool_name": event.tool_name}) + try: + await llm_session.send_tool_result(event.tool_call_id, error_content) + except Exception as send_err: + logger.error( + f"Failed to send error result for tool call {event.tool_call_id}: {send_err}" + ) + # Also push a general error to the result stream + await result.push_llm_event( + LLMErrorEvent( + message=f"Tool execution error for {event.tool_name}: {str(e)}" + ) + ) + + async def _consume_llm_events( + self, + llm_session: RealtimeSession, + result: StreamedRealtimeResult, + audio_input_queue: asyncio.Queue, + ) -> None: + """Continuously receive events from LLM and process them.""" + tool_call_tasks: set[asyncio.Task] = set() + try: + async for event in llm_session.receive_events(): + if isinstance(event, LLMToolCallEvent): + task = asyncio.create_task( + self._handle_tool_call( + event, llm_session, result, audio_input_queue + ) + ) + tool_call_tasks.add(task) + task.add_done_callback(tool_call_tasks.discard) + else: + # Push other events directly to the result stream + await result.push_llm_event(event) + + # If it's an error or session end event, break out of the loop + if isinstance(event, LLMErrorEvent) or isinstance( + event, RealtimeEventSessionEnds + ): + break + except asyncio.CancelledError: + logger.info("LLM event consumer task cancelled") + except Exception as e: + logger.error(f"Error in LLM event consumer task: {e}", exc_info=True) + await result.push_llm_event( + LLMErrorEvent(message=f"LLM event consumer error: {str(e)}") + ) + finally: + # Wait for any outstanding tool calls to complete + if tool_call_tasks: + logger.info( + f"Waiting for {len(tool_call_tasks)} outstanding tool call(s) to complete..." + ) + await asyncio.gather(*tool_call_tasks, return_exceptions=True) + logger.info("All outstanding tool calls completed") + + # Signal completion to the result stream + await result.signal_completion() + + async def run(self, audio_input: StreamedAudioInput) -> StreamedRealtimeResult: + """Run the real-time voice pipeline. + + Args: + audio_input: A StreamedAudioInput instance from which user audio is read. + The pipeline will continue to process audio from this input until + the pipeline is stopped or the audio_input is closed. + + Returns: + A StreamedRealtimeResult instance to stream events from the pipeline. 
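+
+        Note:
+            The returned result is backed by a background pipeline task; exiting
+            `result.stream()` early cancels that task, which in turn tears down the
+            audio pump, the event consumer, and the underlying realtime session.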
+ """ + model = self._get_model() + result = StreamedRealtimeResult(config=self._config) + + # Ensure the model instance is SDKRealtimeLLM for this orchestrator logic + if not isinstance(model, SDKRealtimeLLM): + raise UserError( + f"RealtimeVoicePipeline currently requires an SDKRealtimeLLM instance, got {type(model)}" + ) + + main_pipeline_task: asyncio.Task | None = None + + async def _pipeline_orchestrator(): + audio_pump_task: asyncio.Task | None = None + event_consumer_task: asyncio.Task | None = None + tool_call_tasks: set[asyncio.Task] = set() # Track tool call tasks + + # Store audio input queue for potential use in _handle_tool_call (for nudge) + audio_input_queue = audio_input.queue + + # Create the OpenAI SDK client + # API key, base_url, organization should be part of SDKRealtimeLLM's config + # and accessed here if needed, or SDKRealtimeLLM should expose a way to get client_kwargs + client_kwargs = {} + if ( + model._api_key + ): # Accessing protected member, consider a getter or passing config + client_kwargs["api_key"] = model._api_key + if model._base_url: # Accessing protected member + # The SDK's `AsyncOpenAI` client does not directly take `base_url` for realtime like websockets + # Instead, the `client.beta.realtime.connect()` might take it if supported, + # or it's part of the main client init for other APIs. + # For realtime, the endpoint is usually fixed or configured via env for the SDK. + # Let's assume for now the SDK handles this internally or via `model._base_url` if it were for the ws URI. + # The SDKRealtimeLLM.create_session no longer uses its own base_url for connect. + pass + if model._organization: # Accessing protected member + client_kwargs["organization"] = model._organization + + client = AsyncOpenAI(**client_kwargs) + + try: + logger.info( + f"Attempting to connect to OpenAI Realtime API with model: {model.model_name}" + ) + async with client.beta.realtime.connect( + model=model.model_name + ) as connection: + logger.info("Successfully connected to OpenAI Realtime API.") + + turn_detection_setting = self._config.realtime_settings.get( + "turn_detection", "server_vad" + ) + system_message_setting = self._config.realtime_settings.get( + "system_message" + ) + assistant_voice_setting = self._config.realtime_settings.get( + "assistant_voice" + ) + + # Create and configure the session using the active connection + llm_session = await model.create_session( + connection=connection, # Pass the active connection + tools=self._tools, + system_message=system_message_setting, + assistant_voice=assistant_voice_setting, + turn_detection=turn_detection_setting, + ) + + # Start receiver on the session now that it's fully configured + if isinstance(llm_session, SDKRealtimeSession): + await llm_session.start_receiver() + else: + # This case should ideally not happen if _get_model ensures SDKRealtimeLLM + logger.error( + "LLM session is not an SDKRealtimeSession, cannot start receiver." + ) + raise AgentsException( + "Invalid LLM session type for SDK pipeline." 
+ ) + + logger.info( + f"Realtime LLM session configured: {llm_session._session_id if llm_session else 'N/A'}" + ) + + audio_pump_task = asyncio.create_task( + self._pump_audio_to_llm(audio_input, llm_session) + ) + event_consumer_task = asyncio.create_task( + self._consume_llm_events(llm_session, result, audio_input_queue) + ) + + done, pending = await asyncio.wait( + [audio_pump_task, event_consumer_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + for task in pending: + task.cancel() + for task in done: + if task.exception(): + logger.error( + f"Exception in pipeline task: {task.exception()}", + exc_info=task.exception(), + ) + if pending: + await asyncio.gather(*pending, return_exceptions=True) + + except AgentsException as e: # Catch our specific exceptions + logger.error( + f"Agent-specific error during pipeline run: {e}", exc_info=True + ) + await result.push_llm_event(LLMErrorEvent(message=str(e))) + except Exception as e: + logger.error( + f"Unexpected error during RealtimeVoicePipeline run: {e}", + exc_info=True, + ) + await result.push_llm_event( + LLMErrorEvent(message=f"Pipeline run error: {str(e)}") + ) + finally: + active_tasks = [] + if audio_pump_task and not audio_pump_task.done(): + audio_pump_task.cancel() + active_tasks.append(audio_pump_task) + if event_consumer_task and not event_consumer_task.done(): + event_consumer_task.cancel() + active_tasks.append(event_consumer_task) + + if active_tasks: + await asyncio.gather(*active_tasks, return_exceptions=True) + + if llm_session: + # llm_session.close() now only handles internal tasks, not the SDK connection + await llm_session.close() + + # The `async with` block for `client.beta.realtime.connect` will handle closing the SDK connection. + logger.info("RealtimeVoicePipeline orchestrator finished.") + await result.signal_completion() + + main_pipeline_task = asyncio.create_task(_pipeline_orchestrator()) + result.set_pipeline_task(main_pipeline_task) + return result + + async def stop(self) -> None: + """Stop the pipeline and clean up resources.""" + # This method does nothing since the pipeline is managed + # by the context manager of the client.beta.realtime.connect() + # When the pipeline task is cancelled, it will clean up properly + pass + + +def get_realtime_llm_model( + model_name: str | None, config: VoicePipelineConfig +) -> RealtimeLLMModel: + # For now, this always returns SDKRealtimeLLM, ignoring provider logic in config for simplicity + # A more robust implementation would check config.model_provider + + # Retrieve API key and other necessary params from a secure config or env + # This is a placeholder; actual key management should be handled carefully. + api_key = os.environ.get("OPENAI_API_KEY") + # base_url and organization can also be retrieved if needed by SDKRealtimeLLM constructor + + return SDKRealtimeLLM( + model_name=model_name or SDKRealtimeLLM.DEFAULT_MODEL_NAME, api_key=api_key + ) diff --git a/src/agents/voice/realtime/__init__.py b/src/agents/voice/realtime/__init__.py new file mode 100644 index 00000000..c26a043d --- /dev/null +++ b/src/agents/voice/realtime/__init__.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 OpenAI (http://openai.com) +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions of this file are derived from prior work that is Copyright (c) 2024 OpenAI +# (http://openai.com) and licensed under the Apache License, Version 2.0. Other portions of this +# file are original work produced by an internal team at OpenAI and are licensed differently. 
+# Please see the LICENSE file directly in this repository for the full terms applicable to this file. + +"""Real-time voice interaction components.""" + +from .model import ( + RealtimeLLMModel, + RealtimeSession, + RealtimeEvent, + RealtimeEventSessionBegins, + RealtimeEventAudioChunk, + RealtimeEventTextDelta, + RealtimeEventToolCall, + RealtimeEventSessionEnds, + RealtimeEventError, +) +from .tool_exec import ToolExecutor + +__all__ = [ + "RealtimeLLMModel", + "RealtimeSession", + "RealtimeEvent", + "RealtimeEventSessionBegins", + "RealtimeEventAudioChunk", + "RealtimeEventTextDelta", + "RealtimeEventToolCall", + "RealtimeEventSessionEnds", + "RealtimeEventError", + "ToolExecutor", +] diff --git a/src/agents/voice/realtime/model.py b/src/agents/voice/realtime/model.py new file mode 100644 index 00000000..4a621b87 --- /dev/null +++ b/src/agents/voice/realtime/model.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import abc +from collections.abc import AsyncIterator, Sequence +from dataclasses import dataclass +from typing import Any, Literal + +from ..imports import npt +from ...tool import Tool + + +@dataclass +class RealtimeEventSessionBegins: + session_id: str + type: Literal["session_begins"] = "session_begins" + + +@dataclass +class RealtimeEventAudioChunk: + # Assuming audio is received as bytes (e.g., base64 decoded PCM) + # and will be converted to np.ndarray[np.int16] by the concrete model implementation + # or by the StreamedRealtimeResult + data: bytes + type: Literal["audio_chunk"] = "audio_chunk" + + +@dataclass +class RealtimeEventTextDelta: + delta: str + type: Literal["text_delta"] = "text_delta" + + +@dataclass +class RealtimeEventToolCall: + tool_call_id: str + tool_name: str + arguments: dict[str, Any] # Parsed JSON arguments + type: Literal["tool_call"] = "tool_call" + + +@dataclass +class RealtimeEventSessionEnds: + reason: str | None = None # Optional reason for session ending + type: Literal["session_ends"] = "session_ends" + + +@dataclass +class RealtimeEventError: + message: str + code: int | None = None # Optional error code + type: Literal["error"] = "error" + + +@dataclass +class RealtimeEventResponseDone: + item_id: str | None = None # The item ID of the response that is done + type: Literal["response_done"] = "response_done" + + +@dataclass +class RealtimeEventRateLimitsUpdated: + data: Any # The raw data from the rate_limits.updated event + type: Literal["rate_limits_updated"] = "rate_limits_updated" + + +@dataclass +class RealtimeEventInputAudioTranscriptionDelta: + item_id: str + content_index: int + delta: str + type: Literal["input_audio_transcription_delta"] = "input_audio_transcription_delta" + + +@dataclass +class RealtimeEventInputAudioTranscriptionCompleted: + item_id: str + content_index: int + transcript: str + type: Literal["input_audio_transcription_completed"] = ( + "input_audio_transcription_completed" + ) + + +RealtimeEvent = ( + RealtimeEventSessionBegins + | RealtimeEventAudioChunk + | RealtimeEventTextDelta + | RealtimeEventToolCall + | RealtimeEventSessionEnds + | RealtimeEventError + | RealtimeEventResponseDone + | RealtimeEventRateLimitsUpdated + | RealtimeEventInputAudioTranscriptionDelta + | RealtimeEventInputAudioTranscriptionCompleted +) + + +class RealtimeSession(abc.ABC): + """Represents an active real-time LLM session.""" + + @abc.abstractmethod + async def send_audio_chunk(self, pcm_audio: npt.NDArray[npt.np.int16]) -> None: + """Sends a chunk of PCM audio to the real-time LLM. 
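+
+        The audio is expected to be 16-bit PCM samples; callers that capture
+        float32 audio should convert it to int16 first (the RealtimeVoicePipeline
+        performs this conversion before calling this method).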
+ + Args: + pcm_audio: A numpy array of int16 audio data. + """ + pass + + @abc.abstractmethod + async def send_tool_result(self, tool_call_id: str, content: str) -> None: + """Sends the result of a tool execution back to the LLM. + + Args: + tool_call_id: The ID of the tool call this result corresponds to. + content: The string content of the tool's output (often JSON). + """ + pass + + @abc.abstractmethod + def receive_events(self) -> AsyncIterator[RealtimeEvent]: + """Receives and yields events from the real-time LLM session. + + Returns: + An async iterator of RealtimeEvent instances. + """ + # Ensure it's an async iterator + if False: # pragma: no cover + yield + + @abc.abstractmethod + async def close(self) -> None: + """Closes the real-time session and any underlying connections.""" + pass + + +class RealtimeLLMModel(abc.ABC): + """Abstract base class for real-time Language Model providers.""" + + @property + @abc.abstractmethod + def model_name(self) -> str: + """The name of the real-time LLM model (e.g., 'gpt-4o-realtime-preview').""" + pass + + @abc.abstractmethod + async def create_session( + self, + *, + tools: Sequence[Tool] = (), + system_message: str | None = None, + assistant_voice: str | None = None, + # Potentially other config like language, output_format, etc. + # For now, keeping it minimal as per OpenAI docs for gpt4o-realtime + ) -> RealtimeSession: + """Creates a new real-time LLM session. + + Args: + tools: A sequence of Tool instances available during the session. + system_message: An optional system message to guide the assistant. + assistant_voice: The voice to be used for the assistant's speech output. + (e.g., "alloy", "echo", etc. - specific to the model) + + Returns: + An instance of RealtimeSession. + """ + pass diff --git a/src/agents/voice/realtime/tool_exec.py b/src/agents/voice/realtime/tool_exec.py new file mode 100644 index 00000000..352203b7 --- /dev/null +++ b/src/agents/voice/realtime/tool_exec.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import json +import inspect +from collections.abc import Sequence +from typing import Any, Dict # Removed Set, get_type_hints, get_origin, get_args, Annotated + +from ...exceptions import AgentsException, UserError +from ...logger import logger +from ...run_context import RunContextWrapper +from ...tool import ( + FunctionTool, + Tool, +) +from .model import RealtimeEventToolCall + + +class ToolExecutor: + """Executes tools based on RealtimeEventToolCall events.""" + + def __init__(self, tools: Sequence[Tool], shared_context: Any | None = None): + self._tool_map: Dict[str, FunctionTool] = {} + self._shared_context = shared_context + # self._context_aware_tools: Set[str] = set() # Removed + + for tool in tools: + if isinstance(tool, FunctionTool): + self._tool_map[tool.name] = tool + # Removed context-awareness detection logic + else: # Tool is not a FunctionTool + logger.warning( + f"Tool '{tool.name}' is not a FunctionTool and will be ignored by ToolExecutor." + ) + + # logger.info(f"Final list of context-aware tools: {self._context_aware_tools}") # Removed + + async def execute(self, tool_call_event: RealtimeEventToolCall) -> str: + """Executes the specified tool and returns its string output. + + Args: + tool_call_event: The RealtimeEventToolCall describing the tool to execute. + + Returns: + A string representation of the tool's output (typically JSON). 
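+
+            If the tool cannot be found or its execution fails, a JSON error
+            payload is returned instead of raising (with "error", "tool_name",
+            and, for execution failures, "error_type" keys).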
+ """ + tool_name = tool_call_event.tool_name + tool = self._tool_map.get(tool_name) + + if not tool: + err_msg = f"Tool '{tool_name}' not found in ToolExecutor." + logger.error(err_msg) + return json.dumps({"error": err_msg, "tool_name": tool_name}) + + try: + arguments_json = json.dumps(tool_call_event.arguments) + except TypeError as e: # pragma: no cover + err_msg = f"Failed to serialize arguments for tool '{tool_name}': {e}" + logger.error(f"{err_msg} Arguments: {tool_call_event.arguments}") + return json.dumps({"error": err_msg, "tool_name": tool_name}) + + current_context_wrapper = RunContextWrapper(context=self._shared_context) + logger.info( + f"Executing tool: {tool_name} with args: {arguments_json}, providing RunContextWrapper." + ) + + try: + # Always pass RunContextWrapper, consistent with _run_impl.py + # FunctionTool itself is responsible for handling context for the user's function + tool_output = await tool.on_invoke_tool( + current_context_wrapper, arguments_json + ) + + if not isinstance(tool_output, str): + if isinstance(tool_output, (dict, list)): + tool_output_str = json.dumps(tool_output) + else: + tool_output_str = str(tool_output) + else: + tool_output_str = tool_output + + logger.info( + f"Tool {tool_name} executed successfully. Output length: {len(tool_output_str)}" + ) + return tool_output_str + except UserError as ue: # Specific error handling + logger.error(f"User error executing tool '{tool_name}': {ue}") + return json.dumps({"error": str(ue), "tool_name": tool_name, "error_type": "UserError"}) + except AgentsException as ae: # Specific error handling + logger.error(f"Agents framework error executing tool '{tool_name}': {ae}", exc_info=True) + return json.dumps({"error": str(ae), "tool_name": tool_name, "error_type": "AgentsException"}) + except Exception as e: # pragma: no cover + logger.error(f"Error executing tool '{tool_name}': {e}", exc_info=True) + return json.dumps({"error": str(e), "tool_name": tool_name, "error_type": "UnhandledException"}) diff --git a/src/agents/voice/result_realtime.py b/src/agents/voice/result_realtime.py new file mode 100644 index 00000000..fdc1589c --- /dev/null +++ b/src/agents/voice/result_realtime.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from typing import Any + +from ..exceptions import UserError, AgentsException +from .imports import np, npt +from ..logger import logger # Assuming logger is available +from .events import ( + VoiceStreamEvent, + VoiceStreamEventAudio, + VoiceStreamEventError, + VoiceStreamEventLifecycle, + VoiceStreamEventToolCall, +) +from .pipeline_config import ( + VoicePipelineConfig, +) # Assuming this might be needed for settings + + +class StreamedRealtimeResult: + """The output of a `RealtimeVoicePipeline`. 
Streams events directly from the real-time LLM session.""" + + def __init__( + self, + config: ( + VoicePipelineConfig | None + ) = None, # Optional config for future use or consistency + ): + self._event_queue: asyncio.Queue[VoiceStreamEvent | None] = ( + asyncio.Queue() + ) # None is sentinel for done + self._processing_task: asyncio.Task[Any] | None = ( + None # Task managing the flow from RealtimeSession to queue + ) + self._is_done = False + self._config = config or VoicePipelineConfig() # Use default if none provided + + async def _add_event(self, event: VoiceStreamEvent) -> None: + """Internal method to add an event to the outgoing queue.""" + if not self._is_done: + await self._event_queue.put(event) + + async def _add_realtime_llm_event( + self, llm_event: Any + ) -> None: # llm_event is RealtimeEvent from realtime.model + """Internal method to transform and add an event from the RealtimeLLM's session.""" + # This method will be called by the RealtimeVoicePipeline + if self._is_done: + return + + # Import RealtimeEvent types here to avoid circular dependency at module level if not careful + # However, as they are in a different module (realtime.model), it should be fine. + # For clarity, could also pass a converter function. + from .realtime.model import ( + RealtimeEventAudioChunk, + RealtimeEventError as LLMErrorEvent, + RealtimeEventSessionBegins, + RealtimeEventSessionEnds, + RealtimeEventTextDelta, + RealtimeEventToolCall as LLMToolCallEvent, + ) + + try: + if isinstance(llm_event, RealtimeEventAudioChunk): + # Convert bytes to np.ndarray[np.int16] + # Assuming audio is PCM 16-bit mono. Frame rate isn't carried here, assumed by consumer. + audio_np = np.frombuffer(llm_event.data, dtype=np.int16) + # VoiceStreamEventAudio expects npt.NDArray[np.int16 | np.float32] + # For consistency with existing STT/TTS, let's keep it as int16 for now. + # If float32 is needed, conversion would be: (audio_np.astype(np.float32) / 32767.0) + await self._event_queue.put(VoiceStreamEventAudio(data=audio_np)) + + elif isinstance(llm_event, RealtimeEventTextDelta): + # Currently, VoiceStreamEvent doesn't have a dedicated text delta event. + # For now, we can log it or decide if a new VoiceStreamEventText is needed. + # The primary output is voice and tool calls for this pipeline. + logger.debug(f"Realtime Text Delta: {llm_event.delta}") + # If we want to stream text alongside audio, a new event type would be added to voice.events + # e.g., VoiceStreamEventText(text=llm_event.delta) + # await self._event_queue.put(VoiceStreamEventText(text=llm_event.delta)) + pass # Ignoring for now as per initial plan focused on audio and tools + + elif isinstance(llm_event, LLMToolCallEvent): + await self._event_queue.put( + VoiceStreamEventToolCall( + tool_call_id=llm_event.tool_call_id, + tool_name=llm_event.tool_name, + arguments=llm_event.arguments, + ) + ) + elif isinstance(llm_event, RealtimeEventSessionBegins): + # This might translate to a 'turn_started' or a new specific event if needed. + # For now, let's signal turn_started as it's a session start. + await self._event_queue.put( + VoiceStreamEventLifecycle(event="turn_started") + ) + + elif isinstance(llm_event, RealtimeEventSessionEnds): + # Signal turn_ended and then session_ended. 
+ await self._event_queue.put( + VoiceStreamEventLifecycle(event="turn_ended") + ) + await self._event_queue.put( + VoiceStreamEventLifecycle(event="session_ended") + ) + await self._done() # Mark as done internally + + elif isinstance(llm_event, LLMErrorEvent): + await self._event_queue.put( + VoiceStreamEventError( + error=AgentsException( + f"Realtime LLM Error: {llm_event.message} (Code: {llm_event.code})" + ) + ) + ) + await self._done() # Mark as done on error + + # Other RealtimeEvent types like SessionStatus could be handled here if needed. + + except Exception as e: # pragma: no cover + logger.error( + f"Error processing LLM event in StreamedRealtimeResult: {e}", + exc_info=True, + ) + await self._event_queue.put(VoiceStreamEventError(error=e)) + await self._done() + + async def _done(self) -> None: + """Signals that no more events will be added to the queue.""" + if not self._is_done: + self._is_done = True + await self._event_queue.put(None) # Sentinel to stop iteration + + def _set_processing_task(self, task: asyncio.Task[Any]) -> None: + """Sets the task that manages pulling events from the LLM and putting them into the queue.""" + self._processing_task = task + + async def stream(self) -> AsyncIterator[VoiceStreamEvent]: + """Streams events from the real-time voice pipeline. + + Yields: + VoiceStreamEvent: An event from the pipeline (audio, lifecycle, tool call, error). + """ + while True: + try: + event = await self._event_queue.get() + if event is None: # Sentinel indicating no more events + break + yield event + except asyncio.CancelledError: # pragma: no cover + logger.info("StreamedRealtimeResult stream cancelled.") + break + except Exception as e: # pragma: no cover + # This should ideally not be reached if errors are put on queue as VoiceStreamEventError + logger.error( + f"Unexpected error during StreamedRealtimeResult event streaming: {e}", + exc_info=True, + ) + # Yield a final error event if possible + try: + yield VoiceStreamEventError(error=e) + except Exception: # pragma: no cover + pass # Cannot yield anymore + break + + # Ensure the processing task is awaited if it exists, to propagate its exceptions + if self._processing_task: + try: + if not self._processing_task.done(): # pragma: no cover + # This might happen if stream() is exited early by the consumer + # Ensure task is cancelled if not done + self._processing_task.cancel() + await self._processing_task + except asyncio.CancelledError: # pragma: no cover + logger.info("StreamedRealtimeResult processing task was cancelled.") + except Exception as e: # pragma: no cover + # Errors from processing_task should have been put on the queue. + # This is a fallback. + logger.error( + f"Exception from StreamedRealtimeResult processing task: {e}", + exc_info=True, + ) + # If the queue is still accessible and not broken, try to put a final error + if not self._is_done: + try: + await self._event_queue.put(VoiceStreamEventError(error=e)) + await self._event_queue.put(None) # Sentinel + except Exception: # pragma: no cover + pass # Queue might be broken + + # Expose _add_realtime_llm_event and _done to be called by RealtimeVoicePipeline + # These are effectively the producer API for this result object. + # Renaming them for clarity when used by the pipeline. 
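+    # Typical producer-side flow, as driven by RealtimeVoicePipeline:
+    #   result.set_pipeline_task(orchestrator_task)
+    #   await result.push_llm_event(event)    # for each event from the realtime session
+    #   await result.signal_completion()      # once the session ends or errors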
+ async def push_llm_event(self, llm_event: Any) -> None: + await self._add_realtime_llm_event(llm_event) + + async def signal_completion(self) -> None: + await self._done() + + def set_pipeline_task(self, task: asyncio.Task[Any]) -> None: + self._set_processing_task(task) diff --git a/uv.lock b/uv.lock index 87dd3cd2..8b8ddeb0 100644 --- a/uv.lock +++ b/uv.lock @@ -1540,7 +1540,7 @@ requires-dist = [ { name = "requests", specifier = ">=2.0,<3" }, { name = "types-requests", specifier = ">=2.0,<3" }, { name = "typing-extensions", specifier = ">=4.12.2,<5" }, - { name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<16" }, + { name = "websockets", marker = "extra == 'voice'", specifier = ">=11.0,<16" }, ] provides-extras = ["voice", "viz", "litellm"]