Skip to content

Add keep_last_n_items filter to handoff_filters module #660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def stream_step_result_to_queue(
elif isinstance(item, HandoffCallItem):
event = RunItemStreamEvent(item=item, name="handoff_requested")
elif isinstance(item, HandoffOutputItem):
event = RunItemStreamEvent(item=item, name="handoff_occured")
event = RunItemStreamEvent(item=item, name="handoff_occurred")
elif isinstance(item, ToolCallItem):
event = RunItemStreamEvent(item=item, name="tool_called")
elif isinstance(item, ToolCallOutputItem):
Expand Down
2 changes: 1 addition & 1 deletion src/agents/agent_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def json_schema(self) -> dict[str, Any]:
@abc.abstractmethod
def is_strict_json_schema(self) -> bool:
    """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
    features, but guarantees valid JSON. See here for details:
    https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
    """
    pass
Expand Down
43 changes: 43 additions & 0 deletions src/agents/extensions/handoff_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,46 @@ def _remove_tool_types_from_input(
continue
filtered_items.append(item)
return tuple(filtered_items)


def keep_last_n_items(
    handoff_input_data: HandoffInputData,
    n: int,
    keep_tool_messages: bool = True,
) -> HandoffInputData:
    """Keep only the last ``n`` items of the input history.

    Only ``input_history`` is trimmed, and only when it is a tuple; any other value
    (for example a plain string) is passed through unchanged. ``pre_handoff_items`` and
    ``new_items`` are never trimmed, they are only converted to tuples.

    Args:
        handoff_input_data: The input data to filter.
        n: Number of items to keep from the end of the history. Must be a positive
            integer. If ``n`` is 1, only the last item is kept. If ``n`` is greater than
            the number of items, all items are kept.
        keep_tool_messages: If False, the data is first passed through
            ``remove_all_tools`` so that trimming applies to the remaining items.

    Returns:
        A new ``HandoffInputData`` holding the trimmed history.

    Raises:
        ValueError: If ``n`` is not an integer (booleans are rejected too) or is
            less than or equal to 0.
    """
    # bool is a subclass of int, so it needs an explicit check; otherwise True would
    # silently be accepted and behave like n=1.
    if isinstance(n, bool) or not isinstance(n, int):
        raise ValueError(f"n must be an integer, got {type(n).__name__}")
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n}")

    data = handoff_input_data if keep_tool_messages else remove_all_tools(handoff_input_data)

    history = data.input_history
    if isinstance(history, tuple):
        # A negative slice start that reaches past the beginning simply yields the
        # whole tuple, so n larger than the history length needs no special casing.
        history = history[-n:]

    return HandoffInputData(
        input_history=history,
        pre_handoff_items=tuple(data.pre_handoff_items),
        new_items=tuple(data.new_items),
    )
2 changes: 1 addition & 1 deletion src/agents/stream_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class RunItemStreamEvent:
name: Literal[
"message_output_created",
"handoff_requested",
"handoff_occured",
"handoff_occurred",
"tool_called",
"tool_output",
"reasoning_item_created",
Expand Down
118 changes: 117 additions & 1 deletion tests/test_extension_filters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
from openai.types.responses import ResponseOutputMessage, ResponseOutputText

from agents import Agent, HandoffInputData
from agents.extensions.handoff_filters import remove_all_tools
from agents.extensions.handoff_filters import remove_all_tools, keep_last_n_items
from agents.items import (
HandoffOutputItem,
MessageOutputItem,
Expand Down Expand Up @@ -186,3 +187,118 @@ def test_removes_handoffs_from_history():
assert len(filtered_data.input_history) == 1
assert len(filtered_data.pre_handoff_items) == 1
assert len(filtered_data.new_items) == 1


def test_keep_last_n_items_basic():
    """keep_last_n_items trims input_history to its final n entries and nothing else."""
    full_history = (
        _get_message_input_item("Message 1"),
        _get_message_input_item("Message 2"),
        _get_message_input_item("Message 3"),
        _get_message_input_item("Message 4"),
        _get_message_input_item("Message 5"),
    )
    pre_handoff = (_get_message_output_run_item("Pre handoff"),)
    new = (_get_message_output_run_item("New item"),)
    original = HandoffInputData(
        input_history=full_history,
        pre_handoff_items=pre_handoff,
        new_items=new,
    )

    # Ask for just the two most recent history entries.
    result = keep_last_n_items(original, 2)
    kept = result.input_history

    assert len(kept) == 2
    assert kept[-1] == _get_message_input_item("Message 5")
    assert kept[-2] == _get_message_input_item("Message 4")

    # The other two collections are not subject to trimming.
    assert len(result.pre_handoff_items) == 1
    assert len(result.new_items) == 1


def test_keep_last_n_items_with_tool_messages():
    """With keep_tool_messages=False, tool items are dropped before the last n are taken."""
    mixed_history = (
        _get_message_input_item("Message 1"),
        _get_function_result_input_item("Function result"),
        _get_message_input_item("Message 2"),
        _get_handoff_input_item("Handoff"),
        _get_message_input_item("Message 3"),
    )
    pre_handoff = (_get_message_output_run_item("Pre handoff"),)
    new = (_get_message_output_run_item("New item"),)
    original = HandoffInputData(
        input_history=mixed_history,
        pre_handoff_items=pre_handoff,
        new_items=new,
    )

    # Tool messages are stripped first, then the tail of length 2 is kept.
    result = keep_last_n_items(original, 2, keep_tool_messages=False)
    kept = result.input_history

    # What remains should be the two most recent plain messages.
    assert len(kept) == 2
    assert kept[-1] == _get_message_input_item("Message 3")
    assert kept[-2] == _get_message_input_item("Message 2")


def test_keep_last_n_items_all():
    """Asking for more items than the history holds returns the history untouched."""
    short_history = (
        _get_message_input_item("Message 1"),
        _get_message_input_item("Message 2"),
    )
    pre_handoff = (_get_message_output_run_item("Pre handoff"),)
    new = (_get_message_output_run_item("New item"),)
    original = HandoffInputData(
        input_history=short_history,
        pre_handoff_items=pre_handoff,
        new_items=new,
    )

    # n exceeds the number of available history entries.
    result = keep_last_n_items(original, 10)

    # Nothing should have been dropped.
    assert len(result.input_history) == 2
    assert result.input_history == original.input_history


def test_keep_last_n_items_with_string_history():
    """A string input_history is passed through without being sliced."""
    pre_handoff = (_get_message_output_run_item("Pre handoff"),)
    new = (_get_message_output_run_item("New item"),)
    original = HandoffInputData(
        input_history="This is a string history",
        pre_handoff_items=pre_handoff,
        new_items=new,
    )

    result = keep_last_n_items(original, 3)

    # The string must come back exactly as it went in.
    assert result.input_history == "This is a string history"


def test_keep_last_n_items_invalid_input():
    """Invalid values of n are rejected with a ValueError carrying a matching message."""
    single_message = (_get_message_input_item("Message 1"),)
    original = HandoffInputData(
        input_history=single_message,
        pre_handoff_items=(),
        new_items=(),
    )

    # Each entry pairs a bad n with the message fragment the error must match.
    rejected_cases = (
        (0, "n must be a positive integer"),
        (-5, "n must be a positive integer"),
        ("3", "n must be an integer"),
    )
    for bad_n, expected_message in rejected_cases:
        with pytest.raises(ValueError, match=expected_message):
            keep_last_n_items(original, bad_n)


def test_keep_last_n_items_empty_history():
    """An empty input_history stays empty regardless of n."""
    pre_handoff = (_get_message_output_run_item("Pre handoff"),)
    new = (_get_message_output_run_item("New item"),)
    original = HandoffInputData(
        input_history=(),
        pre_handoff_items=pre_handoff,
        new_items=new,
    )

    result = keep_last_n_items(original, 3)

    # There was nothing to keep, so nothing should appear.
    assert len(result.input_history) == 0