From 41bc4010335a4e0096b46a7d070348ff71d37b75 Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Fri, 25 Apr 2025 12:38:37 -0300 Subject: [PATCH 01/13] feat(bedrock_agent): create bedrock agents functions data class --- .../utilities/data_classes/__init__.py | 2 + .../bedrock_agent_function_event.py | 109 ++++++++++++++++++ tests/events/bedrockAgentFunctionEvent.json | 33 ++++++ .../test_bedrock_agent_function_event.py | 94 +++++++++++++++ 4 files changed, 238 insertions(+) create mode 100644 aws_lambda_powertools/utilities/data_classes/bedrock_agent_function_event.py create mode 100644 tests/events/bedrockAgentFunctionEvent.json create mode 100644 tests/unit/data_classes/required_dependencies/test_bedrock_agent_function_event.py diff --git a/aws_lambda_powertools/utilities/data_classes/__init__.py b/aws_lambda_powertools/utilities/data_classes/__init__.py index 7c1b67e6fa0..da0ef655fea 100644 --- a/aws_lambda_powertools/utilities/data_classes/__init__.py +++ b/aws_lambda_powertools/utilities/data_classes/__init__.py @@ -9,6 +9,7 @@ from .appsync_resolver_events_event import AppSyncResolverEventsEvent from .aws_config_rule_event import AWSConfigRuleEvent from .bedrock_agent_event import BedrockAgentEvent +from .bedrock_agent_function_event import BedrockAgentFunctionEvent from .cloud_watch_alarm_event import ( CloudWatchAlarmConfiguration, CloudWatchAlarmData, @@ -59,6 +60,7 @@ "AppSyncResolverEventsEvent", "ALBEvent", "BedrockAgentEvent", + "BedrockAgentFunctionEvent", "CloudWatchAlarmData", "CloudWatchAlarmEvent", "CloudWatchAlarmMetric", diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_function_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_function_event.py new file mode 100644 index 00000000000..69c48824ccc --- /dev/null +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_function_event.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from typing import Any + +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper + + +class BedrockAgentInfo(DictWrapper): + @property + def name(self) -> str: + return self["name"] + + @property + def id(self) -> str: # noqa: A003 + return self["id"] + + @property + def alias(self) -> str: + return self["alias"] + + @property + def version(self) -> str: + return self["version"] + + +class BedrockAgentFunctionParameter(DictWrapper): + @property + def name(self) -> str: + return self["name"] + + @property + def type(self) -> str: # noqa: A003 + return self["type"] + + @property + def value(self) -> str: + return self["value"] + + +class BedrockAgentFunctionEvent(DictWrapper): + """ + Bedrock Agent Function input event + + Documentation: + https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html + """ + + @classmethod + def validate_required_fields(cls, data: dict[str, Any]) -> None: + required_fields = { + "messageVersion": str, + "agent": dict, + "inputText": str, + "sessionId": str, + "actionGroup": str, + "function": str, + } + + for field, field_type in required_fields.items(): + if field not in data: + raise ValueError(f"Missing required field: {field}") + if not isinstance(data[field], field_type): + raise TypeError(f"Field {field} must be of type {field_type}") + + # Validate agent structure + required_agent_fields = {"name", "id", "alias", "version"} + if not all(field in data["agent"] for field in required_agent_fields): + raise ValueError("Agent object missing required fields") + + def __init__(self, data: dict[str, Any]) -> None: + super().__init__(data) + self.validate_required_fields(data) + + @property + def message_version(self) -> str: + return self["messageVersion"] + + @property + def input_text(self) -> str: + return self["inputText"] + + @property + def session_id(self) -> str: + return self["sessionId"] + + @property + def action_group(self) -> str: + return self["actionGroup"] + + @property + def function(self) -> str: + return self["function"] + + @property + def parameters(self) -> list[BedrockAgentFunctionParameter]: + parameters = self.get("parameters") or [] + return [BedrockAgentFunctionParameter(x) for x in parameters] + + @property + def agent(self) -> BedrockAgentInfo: + return BedrockAgentInfo(self["agent"]) + + @property + def session_attributes(self) -> dict[str, str]: + return self.get("sessionAttributes", {}) or {} + + @property + def prompt_session_attributes(self) -> dict[str, str]: + return self.get("promptSessionAttributes", {}) or {} diff --git a/tests/events/bedrockAgentFunctionEvent.json b/tests/events/bedrockAgentFunctionEvent.json new file mode 100644 index 00000000000..043b4226595 --- /dev/null +++ b/tests/events/bedrockAgentFunctionEvent.json @@ -0,0 +1,33 @@ +{ + "messageVersion": "1.0", + "agent": { + "alias": "PROD", + "name": "hr-assistant-function-def", + "version": "1", + "id": "1234abcd" + }, + "sessionId": "123456789123458", + "sessionAttributes": { + "employeeId": "EMP123", + "department": "Engineering" + }, + "promptSessionAttributes": { + "lastInteraction": "2024-02-01T15:30:00Z", + "requestType": "vacation" + }, + "inputText": "I want to request vacation from March 15 to March 20", + "actionGroup": "VacationsActionGroup", + "function": "submitVacationRequest", + "parameters": [ + { + "name": "startDate", + "type": "string", + "value": "2024-03-15" + }, + { + "name": "endDate", + "type": "string", + "value": "2024-03-20" + } + ] +} \ No newline at end of file diff --git a/tests/unit/data_classes/required_dependencies/test_bedrock_agent_function_event.py b/tests/unit/data_classes/required_dependencies/test_bedrock_agent_function_event.py new file mode 100644 index 00000000000..cd41fdd2e4b --- /dev/null +++ b/tests/unit/data_classes/required_dependencies/test_bedrock_agent_function_event.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import pytest + +from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent +from tests.functional.utils import load_event + + +def test_bedrock_agent_function_event(): + raw_event = load_event("bedrockAgentFunctionEvent.json") + parsed_event = BedrockAgentFunctionEvent(raw_event) + + # Test basic event properties + assert parsed_event.message_version == raw_event["messageVersion"] + assert parsed_event.session_id == raw_event["sessionId"] + assert parsed_event.input_text == raw_event["inputText"] + assert parsed_event.action_group == raw_event["actionGroup"] + assert parsed_event.function == raw_event["function"] + + # Test agent information + agent = parsed_event.agent + raw_agent = raw_event["agent"] + assert agent.alias == raw_agent["alias"] + assert agent.name == raw_agent["name"] + assert agent.version == raw_agent["version"] + assert agent.id == raw_agent["id"] + + # Test session attributes + assert parsed_event.session_attributes == raw_event["sessionAttributes"] + assert parsed_event.prompt_session_attributes == raw_event["promptSessionAttributes"] + + # Test parameters + parameters = parsed_event.parameters + raw_parameters = raw_event["parameters"] + assert len(parameters) == len(raw_parameters) + + for param, raw_param in zip(parameters, raw_parameters): + assert param.name == raw_param["name"] + assert param.type == raw_param["type"] + assert param.value == raw_param["value"] + + +def test_bedrock_agent_function_event_minimal(): + """Test with minimal required fields""" + minimal_event = { + "messageVersion": "1.0", + "agent": { + "alias": "PROD", + "name": "hr-assistant-function-def", + "version": "1", + "id": "1234abcd-56ef-78gh-90ij-klmn12345678", + }, + "sessionId": "87654321-abcd-efgh-ijkl-mnop12345678", + "inputText": "I want to request vacation", + "actionGroup": "VacationsActionGroup", + "function": "submitVacationRequest", + } + + parsed_event = BedrockAgentFunctionEvent(minimal_event) + + assert parsed_event.session_attributes == {} + assert parsed_event.prompt_session_attributes == {} + assert parsed_event.parameters == [] + + +def test_bedrock_agent_function_event_validation(): + """Test validation of required fields""" + # Test missing required field + with pytest.raises(ValueError, match="Missing required field: messageVersion"): + BedrockAgentFunctionEvent({}) + + # Test invalid field type + invalid_event = { + "messageVersion": 1, # should be string + "agent": {"alias": "PROD", "name": "hr-assistant", "version": "1", "id": "1234"}, + "inputText": "", + "sessionId": "", + "actionGroup": "", + "function": "", + } + with pytest.raises(TypeError, match="Field messageVersion must be of type "): + BedrockAgentFunctionEvent(invalid_event) + + # Test missing agent fields + invalid_agent_event = { + "messageVersion": "1.0", + "agent": {"name": "test"}, # missing required agent fields + "inputText": "", + "sessionId": "", + "actionGroup": "", + "function": "", + } + with pytest.raises(ValueError, match="Agent object missing required fields"): + BedrockAgentFunctionEvent(invalid_agent_event) From bed8f3f3df28941a9e24c9dc3da4498cfd0e4fd2 Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Fri, 25 Apr 2025 16:21:35 -0300 Subject: [PATCH 02/13] create resolver --- .../event_handler/__init__.py | 2 + .../event_handler/bedrock_agent_function.py | 102 ++++++++++++++++ .../test_bedrock_agent_functions.py | 109 ++++++++++++++++++ 3 files changed, 213 insertions(+) create mode 100644 aws_lambda_powertools/event_handler/bedrock_agent_function.py create mode 100644 tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index 8bcf2d6636c..db5830d0288 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -12,6 +12,7 @@ ) from aws_lambda_powertools.event_handler.appsync import AppSyncResolver from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver +from aws_lambda_powertools.event_handler.bedrock_agent_function import BedrockAgentFunctionResolver from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver from aws_lambda_powertools.event_handler.lambda_function_url import ( LambdaFunctionUrlResolver, @@ -26,6 +27,7 @@ "ALBResolver", "ApiGatewayResolver", "BedrockAgentResolver", + "BedrockAgentFunctionResolver", "CORSConfig", "LambdaFunctionUrlResolver", "Response", diff --git a/aws_lambda_powertools/event_handler/bedrock_agent_function.py b/aws_lambda_powertools/event_handler/bedrock_agent_function.py new file mode 100644 index 00000000000..29c95b8c38e --- /dev/null +++ b/aws_lambda_powertools/event_handler/bedrock_agent_function.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable + +from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent + + +class BedrockAgentFunctionResolver: + """Bedrock Agent Function resolver that handles function definitions + + Examples + -------- + Simple example with a custom lambda handler + + ```python + from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver + + app = BedrockAgentFunctionResolver() + + @app.tool(description="Gets the current UTC time") + def get_current_time(): + from datetime import datetime + return datetime.utcnow().isoformat() + + def lambda_handler(event, context): + return app.resolve(event, context) + ``` + """ + def __init__(self) -> None: + self._tools: dict[str, dict[str, Any]] = {} + self.current_event: BedrockAgentFunctionEvent | None = None + + def tool(self, description: str | None = None) -> Callable: + """Decorator to register a tool function""" + def decorator(func: Callable) -> Callable: + if not description: + raise ValueError("Tool description is required") + + function_name = func.__name__ + if function_name in self._tools: + raise ValueError(f"Tool '{function_name}' already registered") + + self._tools[function_name] = { + "function": func, + "description": description, + } + return func + return decorator + + def resolve(self, event: dict[str, Any], context: Any) -> dict[str, Any]: + """Resolves the function call from Bedrock Agent event""" + try: + self.current_event = BedrockAgentFunctionEvent(event) + return self._resolve() + except KeyError as e: + raise ValueError(f"Missing required field: {str(e)}") + + def _resolve(self) -> dict[str, Any]: + """Internal resolution logic""" + function_name = self.current_event.function + action_group = self.current_event.action_group + + if function_name not in self._tools: + return self._create_response( + action_group=action_group, + function_name=function_name, + result=f"Function not found: {function_name}" + ) + + try: + result = self._tools[function_name]["function"]() + return self._create_response( + action_group=action_group, + function_name=function_name, + result=result + ) + except Exception as e: + return self._create_response( + action_group=action_group, + function_name=function_name, + result=f"Error: {str(e)}" + ) + + def _create_response(self, action_group: str, function_name: str, result: Any) -> dict[str, Any]: + """Create response in Bedrock Agent format""" + return { + "messageVersion": "1.0", + "response": { + "actionGroup": action_group, + "function": function_name, + "functionResponse": { + "responseBody": { + "TEXT": { + "body": str(result) + } + } + } + } + } \ No newline at end of file diff --git a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py new file mode 100644 index 00000000000..13308d1c24a --- /dev/null +++ b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import pytest +from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver +from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent +from tests.functional.utils import load_event + + +def test_bedrock_agent_function(): + # GIVEN a Bedrock Agent Function resolver + app = BedrockAgentFunctionResolver() + + @app.tool(description="Gets the current time") + def get_current_time(): + assert isinstance(app.current_event, BedrockAgentFunctionEvent) + return "2024-02-01T12:00:00Z" + + # WHEN calling the event handler + raw_event = load_event("bedrockAgentFunctionEvent.json") + raw_event["function"] = "get_current_time" # ensure function name matches + result = app.resolve(raw_event, {}) + + # THEN process event correctly + assert result["messageVersion"] == "1.0" + assert result["response"]["actionGroup"] == raw_event["actionGroup"] + assert result["response"]["function"] == "get_current_time" + assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "2024-02-01T12:00:00Z" + + +def test_bedrock_agent_function_with_error(): + # GIVEN a Bedrock Agent Function resolver + app = BedrockAgentFunctionResolver() + + @app.tool(description="Function that raises error") + def error_function(): + raise ValueError("Something went wrong") + + # WHEN calling the event handler with a function that raises an error + raw_event = load_event("bedrockAgentFunctionEvent.json") + raw_event["function"] = "error_function" + result = app.resolve(raw_event, {}) + + # THEN process the error correctly + assert result["messageVersion"] == "1.0" + assert result["response"]["actionGroup"] == raw_event["actionGroup"] + assert result["response"]["function"] == "error_function" + assert "Error: Something went wrong" in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] + + +def test_bedrock_agent_function_not_found(): + # GIVEN a Bedrock Agent Function resolver + app = BedrockAgentFunctionResolver() + + @app.tool(description="Test function") + def test_function(): + return "test" + + # WHEN calling the event handler with a non-existent function + raw_event = load_event("bedrockAgentFunctionEvent.json") + raw_event["function"] = "nonexistent_function" + result = app.resolve(raw_event, {}) + + # THEN return function not found response + assert result["messageVersion"] == "1.0" + assert result["response"]["actionGroup"] == raw_event["actionGroup"] + assert result["response"]["function"] == "nonexistent_function" + assert "Function not found" in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] + + +def test_bedrock_agent_function_missing_description(): + # GIVEN a Bedrock Agent Function resolver + app = BedrockAgentFunctionResolver() + + # WHEN registering a tool without description + # THEN raise ValueError + with pytest.raises(ValueError, match="Tool description is required"): + @app.tool() + def test_function(): + return "test" + + +def test_bedrock_agent_function_duplicate_registration(): + # GIVEN a Bedrock Agent Function resolver + app = BedrockAgentFunctionResolver() + + # WHEN registering the same function twice + @app.tool(description="First registration") + def test_function(): + return "test" + + # THEN raise ValueError on second registration + with pytest.raises(ValueError, match="Tool 'test_function' already registered"): + @app.tool(description="Second registration") + def test_function(): # noqa: F811 + return "test" + + +def test_bedrock_agent_function_invalid_event(): + # GIVEN a Bedrock Agent Function resolver + app = BedrockAgentFunctionResolver() + + @app.tool(description="Test function") + def test_function(): + return "test" + + # WHEN calling with invalid event + # THEN raise ValueError + with pytest.raises(ValueError, match="Missing required field"): + app.resolve({}, {}) \ No newline at end of file From a3765f0abc7b05aa8c512bf6e625b288ef4b4012 Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Fri, 25 Apr 2025 16:26:57 -0300 Subject: [PATCH 03/13] mypy --- .../event_handler/bedrock_agent_function.py | 28 ++++++++----------- .../test_bedrock_agent_functions.py | 5 +++- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent_function.py b/aws_lambda_powertools/event_handler/bedrock_agent_function.py index 29c95b8c38e..8849dbe01b6 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent_function.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent_function.py @@ -29,12 +29,14 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ + def __init__(self) -> None: self._tools: dict[str, dict[str, Any]] = {} self.current_event: BedrockAgentFunctionEvent | None = None def tool(self, description: str | None = None) -> Callable: """Decorator to register a tool function""" + def decorator(func: Callable) -> Callable: if not description: raise ValueError("Tool description is required") @@ -48,6 +50,7 @@ def decorator(func: Callable) -> Callable: "description": description, } return func + return decorator def resolve(self, event: dict[str, Any], context: Any) -> dict[str, Any]: @@ -60,6 +63,9 @@ def resolve(self, event: dict[str, Any], context: Any) -> dict[str, Any]: def _resolve(self) -> dict[str, Any]: """Internal resolution logic""" + if self.current_event is None: + raise ValueError("No event to process") + function_name = self.current_event.function action_group = self.current_event.action_group @@ -67,21 +73,17 @@ def _resolve(self) -> dict[str, Any]: return self._create_response( action_group=action_group, function_name=function_name, - result=f"Function not found: {function_name}" + result=f"Function not found: {function_name}", ) try: result = self._tools[function_name]["function"]() - return self._create_response( - action_group=action_group, - function_name=function_name, - result=result - ) + return self._create_response(action_group=action_group, function_name=function_name, result=result) except Exception as e: return self._create_response( action_group=action_group, function_name=function_name, - result=f"Error: {str(e)}" + result=f"Error: {str(e)}", ) def _create_response(self, action_group: str, function_name: str, result: Any) -> dict[str, Any]: @@ -91,12 +93,6 @@ def _create_response(self, action_group: str, function_name: str, result: Any) - "response": { "actionGroup": action_group, "function": function_name, - "functionResponse": { - "responseBody": { - "TEXT": { - "body": str(result) - } - } - } - } - } \ No newline at end of file + "functionResponse": {"responseBody": {"TEXT": {"body": str(result)}}}, + }, + } diff --git a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py index 13308d1c24a..937fe298d35 100644 --- a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py +++ b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations import pytest + from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent from tests.functional.utils import load_event @@ -74,6 +75,7 @@ def test_bedrock_agent_function_missing_description(): # WHEN registering a tool without description # THEN raise ValueError with pytest.raises(ValueError, match="Tool description is required"): + @app.tool() def test_function(): return "test" @@ -90,6 +92,7 @@ def test_function(): # THEN raise ValueError on second registration with pytest.raises(ValueError, match="Tool 'test_function' already registered"): + @app.tool(description="Second registration") def test_function(): # noqa: F811 return "test" @@ -106,4 +109,4 @@ def test_function(): # WHEN calling with invalid event # THEN raise ValueError with pytest.raises(ValueError, match="Missing required field"): - app.resolve({}, {}) \ No newline at end of file + app.resolve({}, {}) From 44d80f8a32821e252be15ea300f327bb76beb1dc Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Fri, 25 Apr 2025 17:14:35 -0300 Subject: [PATCH 04/13] add response --- .../event_handler/bedrock_agent_function.py | 92 ++++++++++++++----- .../test_bedrock_agent_functions.py | 74 +++++++++------ 2 files changed, 113 insertions(+), 53 deletions(-) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent_function.py b/aws_lambda_powertools/event_handler/bedrock_agent_function.py index 8849dbe01b6..e750199592a 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent_function.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent_function.py @@ -2,19 +2,64 @@ from typing import TYPE_CHECKING, Any +from typing_extensions import override + +from aws_lambda_powertools.event_handler.api_gateway import Response, ResponseBuilder + if TYPE_CHECKING: from collections.abc import Callable +from enum import Enum + from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent +class ResponseState(Enum): + FAILURE = "FAILURE" + REPROMPT = "REPROMPT" + + +class BedrockFunctionsResponseBuilder(ResponseBuilder): + """ + Bedrock Functions Response Builder. This builds the response dict to be returned by Lambda + when using Bedrock Agent Functions. + + Since the payload format is different from the standard API Gateway Proxy event, + we override the build method. + """ + + @override + def build(self, event: BedrockAgentFunctionEvent, *args) -> dict[str, Any]: + """Build the full response dict to be returned by the lambda""" + self._route(event, None) + + body = self.response.body + if self.response.is_json() and not isinstance(self.response.body, str): + body = self.serializer(body) + + response: dict[str, Any] = { + "messageVersion": "1.0", + "response": { + "actionGroup": event.action_group, + "function": event.function, + "functionResponse": {"responseBody": {"TEXT": {"body": str(body)}}}, + }, + } + + # Add responseState if it's an error + if self.response.status_code >= 400: + response["response"]["functionResponse"]["responseState"] = ( + ResponseState.REPROMPT.value if self.response.status_code == 400 else ResponseState.FAILURE.value + ) + + return response + + class BedrockAgentFunctionResolver: """Bedrock Agent Function resolver that handles function definitions Examples -------- - Simple example with a custom lambda handler - ```python from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver @@ -33,6 +78,7 @@ def lambda_handler(event, context): def __init__(self) -> None: self._tools: dict[str, dict[str, Any]] = {} self.current_event: BedrockAgentFunctionEvent | None = None + self._response_builder_class = BedrockFunctionsResponseBuilder def tool(self, description: str | None = None) -> Callable: """Decorator to register a tool function""" @@ -67,32 +113,28 @@ def _resolve(self) -> dict[str, Any]: raise ValueError("No event to process") function_name = self.current_event.function - action_group = self.current_event.action_group if function_name not in self._tools: - return self._create_response( - action_group=action_group, - function_name=function_name, - result=f"Function not found: {function_name}", - ) + return self._response_builder_class( + Response( + status_code=400, # Using 400 to trigger REPROMPT + body=f"Function not found: {function_name}", + ), + ).build(self.current_event) try: result = self._tools[function_name]["function"]() - return self._create_response(action_group=action_group, function_name=function_name, result=result) + # Always wrap the result in a Response object + if not isinstance(result, Response): + result = Response( + status_code=200, # Success + body=result, + ) + return self._response_builder_class(result).build(self.current_event) except Exception as e: - return self._create_response( - action_group=action_group, - function_name=function_name, - result=f"Error: {str(e)}", - ) - - def _create_response(self, action_group: str, function_name: str, result: Any) -> dict[str, Any]: - """Create response in Bedrock Agent format""" - return { - "messageVersion": "1.0", - "response": { - "actionGroup": action_group, - "function": function_name, - "functionResponse": {"responseBody": {"TEXT": {"body": str(result)}}}, - }, - } + return self._response_builder_class( + Response( + status_code=500, # Using 500 to trigger FAILURE + body=f"Error: {str(e)}", + ), + ).build(self.current_event) diff --git a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py index 937fe298d35..c409d504231 100644 --- a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py +++ b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py @@ -1,31 +1,34 @@ from __future__ import annotations +import json + import pytest -from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver +from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver, Response, content_types from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent from tests.functional.utils import load_event -def test_bedrock_agent_function(): +def test_bedrock_agent_function_with_string_response(): # GIVEN a Bedrock Agent Function resolver app = BedrockAgentFunctionResolver() - @app.tool(description="Gets the current time") - def get_current_time(): + @app.tool(description="Returns a string") + def test_function(): assert isinstance(app.current_event, BedrockAgentFunctionEvent) - return "2024-02-01T12:00:00Z" + return "Hello from string" # WHEN calling the event handler raw_event = load_event("bedrockAgentFunctionEvent.json") - raw_event["function"] = "get_current_time" # ensure function name matches + raw_event["function"] = "test_function" result = app.resolve(raw_event, {}) - # THEN process event correctly + # THEN process event correctly with string response assert result["messageVersion"] == "1.0" assert result["response"]["actionGroup"] == raw_event["actionGroup"] - assert result["response"]["function"] == "get_current_time" - assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "2024-02-01T12:00:00Z" + assert result["response"]["function"] == "test_function" + assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "Hello from string" + assert "responseState" not in result["response"]["functionResponse"] # Success has no state def test_bedrock_agent_function_with_error(): @@ -46,29 +49,53 @@ def error_function(): assert result["response"]["actionGroup"] == raw_event["actionGroup"] assert result["response"]["function"] == "error_function" assert "Error: Something went wrong" in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] + assert result["response"]["functionResponse"]["responseState"] == "FAILURE" def test_bedrock_agent_function_not_found(): # GIVEN a Bedrock Agent Function resolver app = BedrockAgentFunctionResolver() - @app.tool(description="Test function") - def test_function(): - return "test" - # WHEN calling the event handler with a non-existent function raw_event = load_event("bedrockAgentFunctionEvent.json") raw_event["function"] = "nonexistent_function" result = app.resolve(raw_event, {}) - # THEN return function not found response + # THEN return function not found response with REPROMPT state assert result["messageVersion"] == "1.0" assert result["response"]["actionGroup"] == raw_event["actionGroup"] assert result["response"]["function"] == "nonexistent_function" assert "Function not found" in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] + assert result["response"]["functionResponse"]["responseState"] == "REPROMPT" -def test_bedrock_agent_function_missing_description(): +def test_bedrock_agent_function_with_response_object(): + # GIVEN a Bedrock Agent Function resolver + app = BedrockAgentFunctionResolver() + + @app.tool(description="Returns a Response object") + def test_function(): + return Response( + status_code=200, + content_type=content_types.APPLICATION_JSON, + body={"message": "Hello from Response"}, + ) + + # WHEN calling the event handler + raw_event = load_event("bedrockAgentFunctionEvent.json") + raw_event["function"] = "test_function" + result = app.resolve(raw_event, {}) + + # THEN process event correctly with Response object + assert result["messageVersion"] == "1.0" + assert result["response"]["actionGroup"] == raw_event["actionGroup"] + assert result["response"]["function"] == "test_function" + response_body = result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] + assert json.loads(response_body) == {"message": "Hello from Response"} + assert "responseState" not in result["response"]["functionResponse"] # Success has no state + + +def test_bedrock_agent_function_registration(): # GIVEN a Bedrock Agent Function resolver app = BedrockAgentFunctionResolver() @@ -80,21 +107,16 @@ def test_bedrock_agent_function_missing_description(): def test_function(): return "test" - -def test_bedrock_agent_function_duplicate_registration(): - # GIVEN a Bedrock Agent Function resolver - app = BedrockAgentFunctionResolver() - # WHEN registering the same function twice + # THEN raise ValueError @app.tool(description="First registration") - def test_function(): + def duplicate_function(): return "test" - # THEN raise ValueError on second registration - with pytest.raises(ValueError, match="Tool 'test_function' already registered"): + with pytest.raises(ValueError, match="Tool 'duplicate_function' already registered"): @app.tool(description="Second registration") - def test_function(): # noqa: F811 + def duplicate_function(): # noqa: F811 return "test" @@ -102,10 +124,6 @@ def test_bedrock_agent_function_invalid_event(): # GIVEN a Bedrock Agent Function resolver app = BedrockAgentFunctionResolver() - @app.tool(description="Test function") - def test_function(): - return "test" - # WHEN calling with invalid event # THEN raise ValueError with pytest.raises(ValueError, match="Missing required field"): From abbc1004f019425aafab2a77f75f52539f630e7e Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Mon, 28 Apr 2025 16:35:24 -0300 Subject: [PATCH 05/13] add name param to tool --- .../event_handler/bedrock_agent_function.py | 18 ++++++++++++++--- .../test_bedrock_agent_functions.py | 20 +++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent_function.py b/aws_lambda_powertools/event_handler/bedrock_agent_function.py index e750199592a..21c5c824ffd 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent_function.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent_function.py @@ -80,14 +80,26 @@ def __init__(self) -> None: self.current_event: BedrockAgentFunctionEvent | None = None self._response_builder_class = BedrockFunctionsResponseBuilder - def tool(self, description: str | None = None) -> Callable: - """Decorator to register a tool function""" + def tool( + self, + description: str | None = None, + name: str | None = None, + ) -> Callable: + """Decorator to register a tool function + + Parameters + ---------- + description : str | None + Description of what the tool does + name : str | None + Custom name for the tool. If not provided, uses the function name + """ def decorator(func: Callable) -> Callable: if not description: raise ValueError("Tool description is required") - function_name = func.__name__ + function_name = name or func.__name__ if function_name in self._tools: raise ValueError(f"Tool '{function_name}' already registered") diff --git a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py index c409d504231..71f1c852913 100644 --- a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py +++ b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py @@ -128,3 +128,23 @@ def test_bedrock_agent_function_invalid_event(): # THEN raise ValueError with pytest.raises(ValueError, match="Missing required field"): app.resolve({}, {}) + + +def test_bedrock_agent_function_with_custom_name(): + # GIVEN a Bedrock Agent Function resolver + app = BedrockAgentFunctionResolver() + + @app.tool(name="customName", description="Function with custom name") + def test_function(): + return "Hello from custom named function" + + # WHEN calling the event handler + raw_event = load_event("bedrockAgentFunctionEvent.json") + raw_event["function"] = "customName" # Use custom name instead of function name + result = app.resolve(raw_event, {}) + + # THEN process event correctly + assert result["messageVersion"] == "1.0" + assert result["response"]["actionGroup"] == raw_event["actionGroup"] + assert result["response"]["function"] == "customName" + assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "Hello from custom named function" From e42ceffa7e96cbd8fe9f625729b061f82541142c Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Mon, 28 Apr 2025 22:48:37 -0300 Subject: [PATCH 06/13] add response optional fields --- .../event_handler/__init__.py | 3 +- .../event_handler/bedrock_agent_function.py | 108 +++++++++++++----- tests/events/bedrockAgentFunctionEvent.json | 3 +- .../test_bedrock_agent_functions.py | 104 ++++++++++++----- 4 files changed, 159 insertions(+), 59 deletions(-) diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index db5830d0288..ea7921b9412 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -12,7 +12,7 @@ ) from aws_lambda_powertools.event_handler.appsync import AppSyncResolver from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver -from aws_lambda_powertools.event_handler.bedrock_agent_function import BedrockAgentFunctionResolver +from aws_lambda_powertools.event_handler.bedrock_agent_function import BedrockAgentFunctionResolver, BedrockResponse from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver from aws_lambda_powertools.event_handler.lambda_function_url import ( LambdaFunctionUrlResolver, @@ -31,6 +31,7 @@ "CORSConfig", "LambdaFunctionUrlResolver", "Response", + "BedrockResponse", "VPCLatticeResolver", "VPCLatticeV2Resolver", ] diff --git a/aws_lambda_powertools/event_handler/bedrock_agent_function.py b/aws_lambda_powertools/event_handler/bedrock_agent_function.py index 21c5c824ffd..cc5632f0b12 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent_function.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent_function.py @@ -2,10 +2,6 @@ from typing import TYPE_CHECKING, Any -from typing_extensions import override - -from aws_lambda_powertools.event_handler.api_gateway import Response, ResponseBuilder - if TYPE_CHECKING: from collections.abc import Callable @@ -19,7 +15,49 @@ class ResponseState(Enum): REPROMPT = "REPROMPT" -class BedrockFunctionsResponseBuilder(ResponseBuilder): +class BedrockResponse: + """Response class for Bedrock Agent Functions + + Parameters + ---------- + body : Any, optional + Response body + session_attributes : dict[str, str] | None + Session attributes to include in the response + prompt_session_attributes : dict[str, str] | None + Prompt session attributes to include in the response + status_code : int + Status code to determine responseState (400 for REPROMPT, >=500 for FAILURE) + + Examples + -------- + ```python + @app.tool(description="Function that uses session attributes") + def test_function(): + return BedrockResponse( + body="Hello", + session_attributes={"userId": "123"}, + prompt_session_attributes={"lastAction": "login"} + ) + ``` + """ + + def __init__( + self, + body: Any = None, + session_attributes: dict[str, str] | None = None, + prompt_session_attributes: dict[str, str] | None = None, + knowledge_bases: list[dict[str, Any]] | None = None, + status_code: int = 200, + ) -> None: + self.body = body + self.session_attributes = session_attributes + self.prompt_session_attributes = prompt_session_attributes + self.knowledge_bases = knowledge_bases + self.status_code = status_code + + +class BedrockFunctionsResponseBuilder: """ Bedrock Functions Response Builder. This builds the response dict to be returned by Lambda when using Bedrock Agent Functions. @@ -28,30 +66,50 @@ class BedrockFunctionsResponseBuilder(ResponseBuilder): we override the build method. """ - @override - def build(self, event: BedrockAgentFunctionEvent, *args) -> dict[str, Any]: - """Build the full response dict to be returned by the lambda""" - self._route(event, None) + def __init__(self, result: BedrockResponse | Any, status_code: int = 200) -> None: + self.result = result + self.status_code = status_code if not isinstance(result, BedrockResponse) else result.status_code - body = self.response.body - if self.response.is_json() and not isinstance(self.response.body, str): - body = self.serializer(body) + def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]: + """Build the full response dict to be returned by the lambda""" + if isinstance(self.result, BedrockResponse): + body = self.result.body + session_attributes = self.result.session_attributes + prompt_session_attributes = self.result.prompt_session_attributes + knowledge_bases = self.result.knowledge_bases + else: + body = self.result + session_attributes = None + prompt_session_attributes = None + knowledge_bases = None response: dict[str, Any] = { "messageVersion": "1.0", "response": { "actionGroup": event.action_group, "function": event.function, - "functionResponse": {"responseBody": {"TEXT": {"body": str(body)}}}, + "functionResponse": {"responseBody": {"TEXT": {"body": str(body if body is not None else "")}}}, }, } # Add responseState if it's an error - if self.response.status_code >= 400: + if self.status_code >= 400: response["response"]["functionResponse"]["responseState"] = ( - ResponseState.REPROMPT.value if self.response.status_code == 400 else ResponseState.FAILURE.value + ResponseState.REPROMPT.value if self.status_code == 400 else ResponseState.FAILURE.value ) + # Add session attributes if provided in response or maintain from input + response.update( + { + "sessionAttributes": session_attributes or event.session_attributes or {}, + "promptSessionAttributes": prompt_session_attributes or event.prompt_session_attributes or {}, + }, + ) + + # Add knowledge bases configuration if provided + if knowledge_bases: + response["knowledgeBasesConfiguration"] = knowledge_bases + return response @@ -127,26 +185,20 @@ def _resolve(self) -> dict[str, Any]: function_name = self.current_event.function if function_name not in self._tools: - return self._response_builder_class( - Response( - status_code=400, # Using 400 to trigger REPROMPT + return BedrockFunctionsResponseBuilder( + BedrockResponse( body=f"Function not found: {function_name}", + status_code=400, # Using 400 to trigger REPROMPT ), ).build(self.current_event) try: result = self._tools[function_name]["function"]() - # Always wrap the result in a Response object - if not isinstance(result, Response): - result = Response( - status_code=200, # Success - body=result, - ) - return self._response_builder_class(result).build(self.current_event) + return BedrockFunctionsResponseBuilder(result).build(self.current_event) except Exception as e: - return self._response_builder_class( - Response( - status_code=500, # Using 500 to trigger FAILURE + return BedrockFunctionsResponseBuilder( + BedrockResponse( body=f"Error: {str(e)}", + status_code=500, # Using 500 to trigger FAILURE ), ).build(self.current_event) diff --git a/tests/events/bedrockAgentFunctionEvent.json b/tests/events/bedrockAgentFunctionEvent.json index 043b4226595..e849c3e6b73 100644 --- a/tests/events/bedrockAgentFunctionEvent.json +++ b/tests/events/bedrockAgentFunctionEvent.json @@ -8,8 +8,7 @@ }, "sessionId": "123456789123458", "sessionAttributes": { - "employeeId": "EMP123", - "department": "Engineering" + "employeeId": "EMP123" }, "promptSessionAttributes": { "lastInteraction": "2024-02-01T15:30:00Z", diff --git a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py index 71f1c852913..a853c2a1e8a 100644 --- a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py +++ b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py @@ -1,10 +1,10 @@ from __future__ import annotations -import json +from typing import Any import pytest -from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver, Response, content_types +from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver, BedrockResponse from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent from tests.functional.utils import load_event @@ -69,32 +69,6 @@ def test_bedrock_agent_function_not_found(): assert result["response"]["functionResponse"]["responseState"] == "REPROMPT" -def test_bedrock_agent_function_with_response_object(): - # GIVEN a Bedrock Agent Function resolver - app = BedrockAgentFunctionResolver() - - @app.tool(description="Returns a Response object") - def test_function(): - return Response( - status_code=200, - content_type=content_types.APPLICATION_JSON, - body={"message": "Hello from Response"}, - ) - - # WHEN calling the event handler - raw_event = load_event("bedrockAgentFunctionEvent.json") - raw_event["function"] = "test_function" - result = app.resolve(raw_event, {}) - - # THEN process event correctly with Response object - assert result["messageVersion"] == "1.0" - assert result["response"]["actionGroup"] == raw_event["actionGroup"] - assert result["response"]["function"] == "test_function" - response_body = result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] - assert json.loads(response_body) == {"message": "Hello from Response"} - assert "responseState" not in result["response"]["functionResponse"] # Success has no state - - def test_bedrock_agent_function_registration(): # GIVEN a Bedrock Agent Function resolver app = BedrockAgentFunctionResolver() @@ -148,3 +122,77 @@ def test_function(): assert result["response"]["actionGroup"] == raw_event["actionGroup"] assert result["response"]["function"] == "customName" assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "Hello from custom named function" + + +def test_bedrock_agent_function_with_session_attributes(): + # GIVEN a Bedrock Agent Function resolver + app = BedrockAgentFunctionResolver() + + @app.tool(description="Function that uses session attributes") + def test_function() -> dict[str, Any]: + return BedrockResponse( + body="Hello", + session_attributes={"userId": "123"}, + prompt_session_attributes={"lastAction": "login"}, + ) + + # WHEN calling the event handler + raw_event = load_event("bedrockAgentFunctionEvent.json") + raw_event["function"] = "test_function" + raw_event["parameters"] = [] + result = app.resolve(raw_event, {}) + + # THEN include session attributes in response + assert result["messageVersion"] == "1.0" + assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "Hello" + assert result["sessionAttributes"] == {"userId": "123"} + assert result["promptSessionAttributes"] == {"lastAction": "login"} + + +def test_bedrock_agent_function_with_error_response(): + # GIVEN a Bedrock Agent Function resolver + app = BedrockAgentFunctionResolver() + + @app.tool(description="Function that returns error") + def test_function() -> dict[str, Any]: + return BedrockResponse( + body="Invalid input", + status_code=400, # This will trigger REPROMPT + session_attributes={"error": "true"}, + ) + + # WHEN calling the event handler + raw_event = load_event("bedrockAgentFunctionEvent.json") + raw_event["function"] = "test_function" + raw_event["parameters"] = [] + result = app.resolve(raw_event, {}) + + # THEN include error state and session attributes + assert result["response"]["functionResponse"]["responseState"] == "REPROMPT" + assert result["sessionAttributes"] == {"error": "true"} + + +def test_bedrock_agent_function_with_knowledge_bases(): + # GIVEN a Bedrock Agent Function resolver + app = BedrockAgentFunctionResolver() + + @app.tool(description="Returns response with knowledge bases config") + def test_function() -> dict[Any]: + return BedrockResponse( + knowledge_bases=[ + { + "knowledgeBaseId": "kb1", + "retrievalConfiguration": {"vectorSearchConfiguration": {"numberOfResults": 5}}, + }, + ], + ) + + # WHEN calling the event handler + raw_event = load_event("bedrockAgentFunctionEvent.json") + raw_event["function"] = "test_function" + result = app.resolve(raw_event, {}) + + # THEN include knowledge bases in response + assert "knowledgeBasesConfiguration" in result + assert len(result["knowledgeBasesConfiguration"]) == 1 + assert result["knowledgeBasesConfiguration"][0]["knowledgeBaseId"] == "kb1" From 86c7ab72f416fdc9df880f03e13e587e27c64e6f Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Wed, 30 Apr 2025 16:17:20 -0300 Subject: [PATCH 07/13] bedrockfunctionresponse and response state --- .../event_handler/__init__.py | 7 +- .../event_handler/bedrock_agent_function.py | 44 ++--- .../bedrock_agent_function_event.py | 28 --- .../test_bedrock_agent_functions.py | 160 +++++------------- .../test_bedrock_agent_function_event.py | 33 ---- 5 files changed, 59 insertions(+), 213 deletions(-) diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index ea7921b9412..c05539e50eb 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -12,7 +12,10 @@ ) from aws_lambda_powertools.event_handler.appsync import AppSyncResolver from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver -from aws_lambda_powertools.event_handler.bedrock_agent_function import BedrockAgentFunctionResolver, BedrockResponse +from aws_lambda_powertools.event_handler.bedrock_agent_function import ( + BedrockAgentFunctionResolver, + BedrockFunctionResponse, +) from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver from aws_lambda_powertools.event_handler.lambda_function_url import ( LambdaFunctionUrlResolver, @@ -31,7 +34,7 @@ "CORSConfig", "LambdaFunctionUrlResolver", "Response", - "BedrockResponse", + "BedrockFunctionResponse", "VPCLatticeResolver", "VPCLatticeV2Resolver", ] diff --git a/aws_lambda_powertools/event_handler/bedrock_agent_function.py b/aws_lambda_powertools/event_handler/bedrock_agent_function.py index cc5632f0b12..52e7e495d03 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent_function.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent_function.py @@ -5,17 +5,10 @@ if TYPE_CHECKING: from collections.abc import Callable -from enum import Enum - from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent -class ResponseState(Enum): - FAILURE = "FAILURE" - REPROMPT = "REPROMPT" - - -class BedrockResponse: +class BedrockFunctionResponse: """Response class for Bedrock Agent Functions Parameters @@ -26,15 +19,15 @@ class BedrockResponse: Session attributes to include in the response prompt_session_attributes : dict[str, str] | None Prompt session attributes to include in the response - status_code : int - Status code to determine responseState (400 for REPROMPT, >=500 for FAILURE) + response_state : str | None + Response state ("FAILURE" or "REPROMPT") Examples -------- ```python @app.tool(description="Function that uses session attributes") def test_function(): - return BedrockResponse( + return BedrockFunctionResponse( body="Hello", session_attributes={"userId": "123"}, prompt_session_attributes={"lastAction": "login"} @@ -48,40 +41,39 @@ def __init__( session_attributes: dict[str, str] | None = None, prompt_session_attributes: dict[str, str] | None = None, knowledge_bases: list[dict[str, Any]] | None = None, - status_code: int = 200, + response_state: str | None = None, ) -> None: self.body = body self.session_attributes = session_attributes self.prompt_session_attributes = prompt_session_attributes self.knowledge_bases = knowledge_bases - self.status_code = status_code + self.response_state = response_state class BedrockFunctionsResponseBuilder: """ Bedrock Functions Response Builder. This builds the response dict to be returned by Lambda when using Bedrock Agent Functions. - - Since the payload format is different from the standard API Gateway Proxy event, - we override the build method. """ - def __init__(self, result: BedrockResponse | Any, status_code: int = 200) -> None: + def __init__(self, result: BedrockFunctionResponse | Any) -> None: self.result = result - self.status_code = status_code if not isinstance(result, BedrockResponse) else result.status_code def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]: """Build the full response dict to be returned by the lambda""" - if isinstance(self.result, BedrockResponse): + if isinstance(self.result, BedrockFunctionResponse): body = self.result.body session_attributes = self.result.session_attributes prompt_session_attributes = self.result.prompt_session_attributes knowledge_bases = self.result.knowledge_bases + response_state = self.result.response_state + else: body = self.result session_attributes = None prompt_session_attributes = None knowledge_bases = None + response_state = None response: dict[str, Any] = { "messageVersion": "1.0", @@ -92,11 +84,9 @@ def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]: }, } - # Add responseState if it's an error - if self.status_code >= 400: - response["response"]["functionResponse"]["responseState"] = ( - ResponseState.REPROMPT.value if self.status_code == 400 else ResponseState.FAILURE.value - ) + # Add responseState if provided + if response_state: + response["response"]["functionResponse"]["responseState"] = response_state # Add session attributes if provided in response or maintain from input response.update( @@ -186,9 +176,8 @@ def _resolve(self) -> dict[str, Any]: if function_name not in self._tools: return BedrockFunctionsResponseBuilder( - BedrockResponse( + BedrockFunctionResponse( body=f"Function not found: {function_name}", - status_code=400, # Using 400 to trigger REPROMPT ), ).build(self.current_event) @@ -197,8 +186,7 @@ def _resolve(self) -> dict[str, Any]: return BedrockFunctionsResponseBuilder(result).build(self.current_event) except Exception as e: return BedrockFunctionsResponseBuilder( - BedrockResponse( + BedrockFunctionResponse( body=f"Error: {str(e)}", - status_code=500, # Using 500 to trigger FAILURE ), ).build(self.current_event) diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_function_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_function_event.py index 69c48824ccc..ab479c59381 100644 --- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_function_event.py +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_function_event.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any - from aws_lambda_powertools.utilities.data_classes.common import DictWrapper @@ -45,32 +43,6 @@ class BedrockAgentFunctionEvent(DictWrapper): https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html """ - @classmethod - def validate_required_fields(cls, data: dict[str, Any]) -> None: - required_fields = { - "messageVersion": str, - "agent": dict, - "inputText": str, - "sessionId": str, - "actionGroup": str, - "function": str, - } - - for field, field_type in required_fields.items(): - if field not in data: - raise ValueError(f"Missing required field: {field}") - if not isinstance(data[field], field_type): - raise TypeError(f"Field {field} must be of type {field_type}") - - # Validate agent structure - required_agent_fields = {"name", "id", "alias", "version"} - if not all(field in data["agent"] for field in required_agent_fields): - raise ValueError("Agent object missing required fields") - - def __init__(self, data: dict[str, Any]) -> None: - super().__init__(data) - self.validate_required_fields(data) - @property def message_version(self) -> str: return self["messageVersion"] diff --git a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py index a853c2a1e8a..80b614b4886 100644 --- a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py +++ b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py @@ -1,10 +1,8 @@ from __future__ import annotations -from typing import Any - import pytest -from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver, BedrockResponse +from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver, BedrockFunctionResponse from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent from tests.functional.utils import load_event @@ -28,171 +26,89 @@ def test_function(): assert result["response"]["actionGroup"] == raw_event["actionGroup"] assert result["response"]["function"] == "test_function" assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "Hello from string" - assert "responseState" not in result["response"]["functionResponse"] # Success has no state + assert "responseState" not in result["response"]["functionResponse"] -def test_bedrock_agent_function_with_error(): +def test_bedrock_agent_function_error_handling(): # GIVEN a Bedrock Agent Function resolver app = BedrockAgentFunctionResolver() - @app.tool(description="Function that raises error") + @app.tool(description="Function with error handling") def error_function(): + return BedrockFunctionResponse( + body="Invalid input", + response_state="REPROMPT", + session_attributes={"error": "true"} + ) + + @app.tool(description="Function that raises error") + def exception_function(): raise ValueError("Something went wrong") - # WHEN calling the event handler with a function that raises an error + # WHEN calling with explicit error response raw_event = load_event("bedrockAgentFunctionEvent.json") raw_event["function"] = "error_function" result = app.resolve(raw_event, {}) - # THEN process the error correctly - assert result["messageVersion"] == "1.0" - assert result["response"]["actionGroup"] == raw_event["actionGroup"] - assert result["response"]["function"] == "error_function" - assert "Error: Something went wrong" in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] - assert result["response"]["functionResponse"]["responseState"] == "FAILURE" - - -def test_bedrock_agent_function_not_found(): - # GIVEN a Bedrock Agent Function resolver - app = BedrockAgentFunctionResolver() - - # WHEN calling the event handler with a non-existent function - raw_event = load_event("bedrockAgentFunctionEvent.json") - raw_event["function"] = "nonexistent_function" - result = app.resolve(raw_event, {}) - - # THEN return function not found response with REPROMPT state - assert result["messageVersion"] == "1.0" - assert result["response"]["actionGroup"] == raw_event["actionGroup"] - assert result["response"]["function"] == "nonexistent_function" - assert "Function not found" in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] + # THEN include REPROMPT state and session attributes assert result["response"]["functionResponse"]["responseState"] == "REPROMPT" + assert result["sessionAttributes"] == {"error": "true"} def test_bedrock_agent_function_registration(): # GIVEN a Bedrock Agent Function resolver app = BedrockAgentFunctionResolver() - # WHEN registering a tool without description - # THEN raise ValueError + # WHEN registering without description or with duplicate name with pytest.raises(ValueError, match="Tool description is required"): - @app.tool() def test_function(): return "test" - # WHEN registering the same function twice - # THEN raise ValueError - @app.tool(description="First registration") - def duplicate_function(): + @app.tool(name="custom", description="First registration") + def first_function(): return "test" - with pytest.raises(ValueError, match="Tool 'duplicate_function' already registered"): - - @app.tool(description="Second registration") - def duplicate_function(): # noqa: F811 + with pytest.raises(ValueError, match="Tool 'custom' already registered"): + @app.tool(name="custom", description="Second registration") + def second_function(): return "test" -def test_bedrock_agent_function_invalid_event(): +def test_bedrock_agent_function_with_optional_fields(): # GIVEN a Bedrock Agent Function resolver app = BedrockAgentFunctionResolver() - # WHEN calling with invalid event - # THEN raise ValueError - with pytest.raises(ValueError, match="Missing required field"): - app.resolve({}, {}) - - -def test_bedrock_agent_function_with_custom_name(): - # GIVEN a Bedrock Agent Function resolver - app = BedrockAgentFunctionResolver() - - @app.tool(name="customName", description="Function with custom name") + @app.tool(description="Function with all optional fields") def test_function(): - return "Hello from custom named function" - - # WHEN calling the event handler - raw_event = load_event("bedrockAgentFunctionEvent.json") - raw_event["function"] = "customName" # Use custom name instead of function name - result = app.resolve(raw_event, {}) - - # THEN process event correctly - assert result["messageVersion"] == "1.0" - assert result["response"]["actionGroup"] == raw_event["actionGroup"] - assert result["response"]["function"] == "customName" - assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "Hello from custom named function" - - -def test_bedrock_agent_function_with_session_attributes(): - # GIVEN a Bedrock Agent Function resolver - app = BedrockAgentFunctionResolver() - - @app.tool(description="Function that uses session attributes") - def test_function() -> dict[str, Any]: - return BedrockResponse( + return BedrockFunctionResponse( body="Hello", session_attributes={"userId": "123"}, - prompt_session_attributes={"lastAction": "login"}, + prompt_session_attributes={"context": "test"}, + knowledge_bases=[{ + "knowledgeBaseId": "kb1", + "retrievalConfiguration": { + "vectorSearchConfiguration": {"numberOfResults": 5} + } + }] ) # WHEN calling the event handler raw_event = load_event("bedrockAgentFunctionEvent.json") raw_event["function"] = "test_function" - raw_event["parameters"] = [] result = app.resolve(raw_event, {}) - # THEN include session attributes in response - assert result["messageVersion"] == "1.0" + # THEN include all optional fields in response assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "Hello" assert result["sessionAttributes"] == {"userId": "123"} - assert result["promptSessionAttributes"] == {"lastAction": "login"} - - -def test_bedrock_agent_function_with_error_response(): - # GIVEN a Bedrock Agent Function resolver - app = BedrockAgentFunctionResolver() - - @app.tool(description="Function that returns error") - def test_function() -> dict[str, Any]: - return BedrockResponse( - body="Invalid input", - status_code=400, # This will trigger REPROMPT - session_attributes={"error": "true"}, - ) - - # WHEN calling the event handler - raw_event = load_event("bedrockAgentFunctionEvent.json") - raw_event["function"] = "test_function" - raw_event["parameters"] = [] - result = app.resolve(raw_event, {}) - - # THEN include error state and session attributes - assert result["response"]["functionResponse"]["responseState"] == "REPROMPT" - assert result["sessionAttributes"] == {"error": "true"} + assert result["promptSessionAttributes"] == {"context": "test"} + assert result["knowledgeBasesConfiguration"][0]["knowledgeBaseId"] == "kb1" -def test_bedrock_agent_function_with_knowledge_bases(): +def test_bedrock_agent_function_invalid_event(): # GIVEN a Bedrock Agent Function resolver app = BedrockAgentFunctionResolver() - @app.tool(description="Returns response with knowledge bases config") - def test_function() -> dict[Any]: - return BedrockResponse( - knowledge_bases=[ - { - "knowledgeBaseId": "kb1", - "retrievalConfiguration": {"vectorSearchConfiguration": {"numberOfResults": 5}}, - }, - ], - ) - - # WHEN calling the event handler - raw_event = load_event("bedrockAgentFunctionEvent.json") - raw_event["function"] = "test_function" - result = app.resolve(raw_event, {}) - - # THEN include knowledge bases in response - assert "knowledgeBasesConfiguration" in result - assert len(result["knowledgeBasesConfiguration"]) == 1 - assert result["knowledgeBasesConfiguration"][0]["knowledgeBaseId"] == "kb1" + # WHEN calling with invalid event + with pytest.raises(ValueError, match="Missing required field"): + app.resolve({}, {}) \ No newline at end of file diff --git a/tests/unit/data_classes/required_dependencies/test_bedrock_agent_function_event.py b/tests/unit/data_classes/required_dependencies/test_bedrock_agent_function_event.py index cd41fdd2e4b..e055c894604 100644 --- a/tests/unit/data_classes/required_dependencies/test_bedrock_agent_function_event.py +++ b/tests/unit/data_classes/required_dependencies/test_bedrock_agent_function_event.py @@ -1,7 +1,5 @@ from __future__ import annotations -import pytest - from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent from tests.functional.utils import load_event @@ -61,34 +59,3 @@ def test_bedrock_agent_function_event_minimal(): assert parsed_event.session_attributes == {} assert parsed_event.prompt_session_attributes == {} assert parsed_event.parameters == [] - - -def test_bedrock_agent_function_event_validation(): - """Test validation of required fields""" - # Test missing required field - with pytest.raises(ValueError, match="Missing required field: messageVersion"): - BedrockAgentFunctionEvent({}) - - # Test invalid field type - invalid_event = { - "messageVersion": 1, # should be string - "agent": {"alias": "PROD", "name": "hr-assistant", "version": "1", "id": "1234"}, - "inputText": "", - "sessionId": "", - "actionGroup": "", - "function": "", - } - with pytest.raises(TypeError, match="Field messageVersion must be of type "): - BedrockAgentFunctionEvent(invalid_event) - - # Test missing agent fields - invalid_agent_event = { - "messageVersion": "1.0", - "agent": {"name": "test"}, # missing required agent fields - "inputText": "", - "sessionId": "", - "actionGroup": "", - "function": "", - } - with pytest.raises(ValueError, match="Agent object missing required fields"): - BedrockAgentFunctionEvent(invalid_agent_event) From 34948d7f43e5bce3b02308240d4b8273a8a79be9 Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Thu, 1 May 2025 15:54:13 -0300 Subject: [PATCH 08/13] remove body message --- .../event_handler/bedrock_agent_function.py | 7 ------- .../test_bedrock_agent_functions.py | 18 ++++++++++-------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent_function.py b/aws_lambda_powertools/event_handler/bedrock_agent_function.py index 52e7e495d03..9ae04e90102 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent_function.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent_function.py @@ -174,13 +174,6 @@ def _resolve(self) -> dict[str, Any]: function_name = self.current_event.function - if function_name not in self._tools: - return BedrockFunctionsResponseBuilder( - BedrockFunctionResponse( - body=f"Function not found: {function_name}", - ), - ).build(self.current_event) - try: result = self._tools[function_name]["function"]() return BedrockFunctionsResponseBuilder(result).build(self.current_event) diff --git a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py index 80b614b4886..6983610cb59 100644 --- a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py +++ b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py @@ -38,7 +38,7 @@ def error_function(): return BedrockFunctionResponse( body="Invalid input", response_state="REPROMPT", - session_attributes={"error": "true"} + session_attributes={"error": "true"}, ) @app.tool(description="Function that raises error") @@ -61,6 +61,7 @@ def test_bedrock_agent_function_registration(): # WHEN registering without description or with duplicate name with pytest.raises(ValueError, match="Tool description is required"): + @app.tool() def test_function(): return "test" @@ -70,6 +71,7 @@ def first_function(): return "test" with pytest.raises(ValueError, match="Tool 'custom' already registered"): + @app.tool(name="custom", description="Second registration") def second_function(): return "test" @@ -85,12 +87,12 @@ def test_function(): body="Hello", session_attributes={"userId": "123"}, prompt_session_attributes={"context": "test"}, - knowledge_bases=[{ - "knowledgeBaseId": "kb1", - "retrievalConfiguration": { - "vectorSearchConfiguration": {"numberOfResults": 5} - } - }] + knowledge_bases=[ + { + "knowledgeBaseId": "kb1", + "retrievalConfiguration": {"vectorSearchConfiguration": {"numberOfResults": 5}}, + }, + ], ) # WHEN calling the event handler @@ -111,4 +113,4 @@ def test_bedrock_agent_function_invalid_event(): # WHEN calling with invalid event with pytest.raises(ValueError, match="Missing required field"): - app.resolve({}, {}) \ No newline at end of file + app.resolve({}, {}) From 24978cb76a2d078b61d8210e3bd93cf594f0524a Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Thu, 1 May 2025 16:18:27 -0300 Subject: [PATCH 09/13] add parser --- .../parser/envelopes/bedrock_agent.py | 26 +++++++++++++- .../utilities/parser/models/__init__.py | 2 ++ .../utilities/parser/models/bedrock_agent.py | 18 ++++++++++ .../parser/_pydantic/test_bedrock_agent.py | 34 ++++++++++++++++++- 4 files changed, 78 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/utilities/parser/envelopes/bedrock_agent.py b/aws_lambda_powertools/utilities/parser/envelopes/bedrock_agent.py index 3d234999116..392c17cc425 100644 --- a/aws_lambda_powertools/utilities/parser/envelopes/bedrock_agent.py +++ b/aws_lambda_powertools/utilities/parser/envelopes/bedrock_agent.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any from aws_lambda_powertools.utilities.parser.envelopes.base import BaseEnvelope -from aws_lambda_powertools.utilities.parser.models import BedrockAgentEventModel +from aws_lambda_powertools.utilities.parser.models import BedrockAgentEventModel, BedrockAgentFunctionEventModel if TYPE_CHECKING: from aws_lambda_powertools.utilities.parser.types import Model @@ -34,3 +34,27 @@ def parse(self, data: dict[str, Any] | Any | None, model: type[Model]) -> Model parsed_envelope: BedrockAgentEventModel = BedrockAgentEventModel.model_validate(data) logger.debug(f"Parsing event payload in `input_text` with {model}") return self._parse(data=parsed_envelope.input_text, model=model) + + +class BedrockAgentFunctionEnvelope(BaseEnvelope): + """Bedrock Agent Function envelope to extract data within input_text key""" + + def parse(self, data: dict[str, Any] | Any | None, model: type[Model]) -> Model | None: + """Parses data found with model provided + + Parameters + ---------- + data : dict + Lambda event to be parsed + model : type[Model] + Data model provided to parse after extracting data using envelope + + Returns + ------- + Model | None + Parsed detail payload with model provided + """ + logger.debug(f"Parsing incoming data with Bedrock Agent Function model {BedrockAgentFunctionEventModel}") + parsed_envelope: BedrockAgentFunctionEventModel = BedrockAgentFunctionEventModel.model_validate(data) + logger.debug(f"Parsing event payload in `input_text` with {model}") + return self._parse(data=parsed_envelope.input_text, model=model) diff --git a/aws_lambda_powertools/utilities/parser/models/__init__.py b/aws_lambda_powertools/utilities/parser/models/__init__.py index 7ea8da2dc22..ad8e3d7a92f 100644 --- a/aws_lambda_powertools/utilities/parser/models/__init__.py +++ b/aws_lambda_powertools/utilities/parser/models/__init__.py @@ -32,6 +32,7 @@ ) from .bedrock_agent import ( BedrockAgentEventModel, + BedrockAgentFunctionEventModel, BedrockAgentModel, BedrockAgentPropertyModel, BedrockAgentRequestBodyModel, @@ -208,6 +209,7 @@ "BedrockAgentEventModel", "BedrockAgentRequestBodyModel", "BedrockAgentRequestMediaModel", + "BedrockAgentFunctionEventModel", "S3BatchOperationJobModel", "S3BatchOperationModel", "S3BatchOperationTaskModel", diff --git a/aws_lambda_powertools/utilities/parser/models/bedrock_agent.py b/aws_lambda_powertools/utilities/parser/models/bedrock_agent.py index 62465162167..1aa5ae07a34 100644 --- a/aws_lambda_powertools/utilities/parser/models/bedrock_agent.py +++ b/aws_lambda_powertools/utilities/parser/models/bedrock_agent.py @@ -36,3 +36,21 @@ class BedrockAgentEventModel(BaseModel): agent: BedrockAgentModel parameters: Optional[List[BedrockAgentPropertyModel]] = None request_body: Optional[BedrockAgentRequestBodyModel] = Field(None, alias="requestBody") + + +class BedrockAgentFunctionEventModel(BaseModel): + """Bedrock Agent Function event model + + Documentation: + https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html + """ + + message_version: str = Field(..., alias="messageVersion") + agent: BedrockAgentModel + input_text: str = Field(..., alias="inputText") + session_id: str = Field(..., alias="sessionId") + action_group: str = Field(..., alias="actionGroup") + function: str + parameters: Optional[List[BedrockAgentPropertyModel]] = None + session_attributes: Dict[str, str] = Field({}, alias="sessionAttributes") + prompt_session_attributes: Dict[str, str] = Field({}, alias="promptSessionAttributes") diff --git a/tests/unit/parser/_pydantic/test_bedrock_agent.py b/tests/unit/parser/_pydantic/test_bedrock_agent.py index 207318952cc..472dfa26eff 100644 --- a/tests/unit/parser/_pydantic/test_bedrock_agent.py +++ b/tests/unit/parser/_pydantic/test_bedrock_agent.py @@ -1,5 +1,5 @@ from aws_lambda_powertools.utilities.parser import envelopes, parse -from aws_lambda_powertools.utilities.parser.models import BedrockAgentEventModel +from aws_lambda_powertools.utilities.parser.models import BedrockAgentEventModel, BedrockAgentFunctionEventModel from tests.functional.utils import load_event from tests.unit.parser._pydantic.schemas import MyBedrockAgentBusiness @@ -76,3 +76,35 @@ def test_bedrock_agent_event_with_post(): assert properties[1].name == raw_properties[1]["name"] assert properties[1].type_ == raw_properties[1]["type"] assert properties[1].value == raw_properties[1]["value"] + + +def test_bedrock_agent_function_event(): + raw_event = load_event("bedrockAgentFunctionEvent.json") + model = BedrockAgentFunctionEventModel(**raw_event) + + assert model.message_version == raw_event["messageVersion"] + assert model.session_id == raw_event["sessionId"] + assert model.input_text == raw_event["inputText"] + assert model.action_group == raw_event["actionGroup"] + assert model.function == raw_event["function"] + assert model.session_attributes == {"employeeId": "EMP123"} + assert model.prompt_session_attributes == {"lastInteraction": "2024-02-01T15:30:00Z", "requestType": "vacation"} + + agent = model.agent + raw_agent = raw_event["agent"] + assert agent.alias == raw_agent["alias"] + assert agent.name == raw_agent["name"] + assert agent.version == raw_agent["version"] + assert agent.id_ == raw_agent["id"] + + parameters = model.parameters + assert parameters is not None + assert len(parameters) == 2 + + assert parameters[0].name == "startDate" + assert parameters[0].type_ == "string" + assert parameters[0].value == "2024-03-15" + + assert parameters[1].name == "endDate" + assert parameters[1].type_ == "string" + assert parameters[1].value == "2024-03-20" From 45f85f6cb858888dedddceccdedae3fa0b691ba7 Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Mon, 5 May 2025 18:14:03 -0300 Subject: [PATCH 10/13] add test for required fields --- .../test_bedrock_agent_functions.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py index 6983610cb59..30061a3b3ed 100644 --- a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py +++ b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py @@ -114,3 +114,21 @@ def test_bedrock_agent_function_invalid_event(): # WHEN calling with invalid event with pytest.raises(ValueError, match="Missing required field"): app.resolve({}, {}) + + +def test_resolve_raises_value_error_on_missing_required_field(): + """Test that resolve() raises ValueError when a required field is missing from the event""" + # GIVEN a Bedrock Agent Function resolver and an incomplete event + resolver = BedrockAgentFunctionResolver() + incomplete_event = { + "messageVersion": "1.0", + "agent": {"alias": "PROD", "name": "hr-assistant-function-def", "version": "1", "id": "1234abcd"}, + "sessionId": "123456789123458", + } + + # WHEN calling resolve with the incomplete event + # THEN a ValueError is raised with information about the missing field + with pytest.raises(ValueError) as excinfo: + resolver.resolve(incomplete_event, {}) + + assert "Missing required field:" in str(excinfo.value) From 84bb6b02ed8f4d8c181fe1d2196b1f5f66d3fda7 Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Mon, 5 May 2025 18:37:58 -0300 Subject: [PATCH 11/13] add more tests for parser and resolver --- .../utilities/parser/envelopes/__init__.py | 3 ++- .../test_bedrock_agent_functions.py | 20 +++++++++++++++++++ .../parser/_pydantic/test_bedrock_agent.py | 13 ++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/utilities/parser/envelopes/__init__.py b/aws_lambda_powertools/utilities/parser/envelopes/__init__.py index e1ac8cdbf5e..0bf4b7a5535 100644 --- a/aws_lambda_powertools/utilities/parser/envelopes/__init__.py +++ b/aws_lambda_powertools/utilities/parser/envelopes/__init__.py @@ -2,7 +2,7 @@ from .apigw_websocket import ApiGatewayWebSocketEnvelope from .apigwv2 import ApiGatewayV2Envelope from .base import BaseEnvelope -from .bedrock_agent import BedrockAgentEnvelope +from .bedrock_agent import BedrockAgentEnvelope, BedrockAgentFunctionEnvelope from .cloudwatch import CloudWatchLogsEnvelope from .dynamodb import DynamoDBStreamEnvelope from .event_bridge import EventBridgeEnvelope @@ -20,6 +20,7 @@ "ApiGatewayV2Envelope", "ApiGatewayWebSocketEnvelope", "BedrockAgentEnvelope", + "BedrockAgentFunctionEnvelope", "CloudWatchLogsEnvelope", "DynamoDBStreamEnvelope", "EventBridgeEnvelope", diff --git a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py index 30061a3b3ed..151d79dcda7 100644 --- a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py +++ b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py @@ -132,3 +132,23 @@ def test_resolve_raises_value_error_on_missing_required_field(): resolver.resolve(incomplete_event, {}) assert "Missing required field:" in str(excinfo.value) + + +def test_resolve_with_no_registered_function(): + # GIVEN a Bedrock Agent Function resolver + app = BedrockAgentFunctionResolver() + + # AND a valid event but with a non-existent function + raw_event = { + "messageVersion": "1.0", + "agent": {"name": "TestAgent", "id": "test-id", "alias": "test", "version": "1"}, + "actionGroup": "test_group", + "function": "non_existent_function", + "parameters": [], + } + + # WHEN calling resolve with a non-existent function + result = app.resolve(raw_event, {}) + + # THEN the response should contain an error message + assert "Error: 'non_existent_function'" in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] diff --git a/tests/unit/parser/_pydantic/test_bedrock_agent.py b/tests/unit/parser/_pydantic/test_bedrock_agent.py index 472dfa26eff..e66a202a53f 100644 --- a/tests/unit/parser/_pydantic/test_bedrock_agent.py +++ b/tests/unit/parser/_pydantic/test_bedrock_agent.py @@ -108,3 +108,16 @@ def test_bedrock_agent_function_event(): assert parameters[1].name == "endDate" assert parameters[1].type_ == "string" assert parameters[1].value == "2024-03-20" + + +def test_bedrock_agent_function_event_with_envelope(): + raw_event = load_event("bedrockAgentFunctionEvent.json") + raw_event["inputText"] = '{"username": "Jane", "name": "Doe"}' + parsed_event: MyBedrockAgentBusiness = parse( + event=raw_event, + model=MyBedrockAgentBusiness, + envelope=envelopes.BedrockAgentFunctionEnvelope, + ) + + assert parsed_event.username == "Jane" + assert parsed_event.name == "Doe" From d4633046cac920ec8e7e7ddd41d38cb9eab4e1de Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Mon, 5 May 2025 20:29:07 -0300 Subject: [PATCH 12/13] add validation response state --- .../event_handler/__init__.py | 1 - .../event_handler/bedrock_agent_function.py | 3 +++ .../test_bedrock_agent_functions.py | 21 +++++++++++++++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index 33ca5e7d0b0..f374590428d 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -16,7 +16,6 @@ BedrockAgentFunctionResolver, BedrockFunctionResponse, ) - from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver from aws_lambda_powertools.event_handler.lambda_function_url import ( LambdaFunctionUrlResolver, diff --git a/aws_lambda_powertools/event_handler/bedrock_agent_function.py b/aws_lambda_powertools/event_handler/bedrock_agent_function.py index 9ae04e90102..7538c1600ae 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent_function.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent_function.py @@ -43,6 +43,9 @@ def __init__( knowledge_bases: list[dict[str, Any]] | None = None, response_state: str | None = None, ) -> None: + if response_state is not None and response_state not in ["FAILURE", "REPROMPT"]: + raise ValueError("responseState must be None, 'FAILURE' or 'REPROMPT'") + self.body = body self.session_attributes = session_attributes self.prompt_session_attributes = prompt_session_attributes diff --git a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py index 151d79dcda7..c608db172bd 100644 --- a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py +++ b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py @@ -152,3 +152,24 @@ def test_resolve_with_no_registered_function(): # THEN the response should contain an error message assert "Error: 'non_existent_function'" in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] + + +def test_bedrock_function_response_state_validation(): + # GIVEN invalid and valid response states + valid_states = [None, "FAILURE", "REPROMPT"] + invalid_state = "INVALID" + + # WHEN creating responses with valid states + # THEN no error should be raised + for state in valid_states: + try: + BedrockFunctionResponse(body="test", response_state=state) + except ValueError: + pytest.fail(f"Unexpected ValueError for response_state={state}") + + # WHEN creating a response with invalid state + # THEN ValueError should be raised with correct message + with pytest.raises(ValueError) as exc_info: + BedrockFunctionResponse(body="test", response_state=invalid_state) + + assert str(exc_info.value) == "responseState must be None, 'FAILURE' or 'REPROMPT'" From 54a7edf272bdc903533c8105acb7f930d77951e6 Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Thu, 8 May 2025 20:15:08 -0300 Subject: [PATCH 13/13] params injection --- .../event_handler/bedrock_agent_function.py | 38 ++++++++++--- .../test_bedrock_agent_functions.py | 56 +++++++++++++++---- 2 files changed, 73 insertions(+), 21 deletions(-) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent_function.py b/aws_lambda_powertools/event_handler/bedrock_agent_function.py index 7538c1600ae..20c16f48f5d 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent_function.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent_function.py @@ -1,6 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +import inspect +import warnings +from typing import TYPE_CHECKING, Any, Literal + +from aws_lambda_powertools.warnings import PowertoolsUserWarning if TYPE_CHECKING: from collections.abc import Callable @@ -19,7 +23,7 @@ class BedrockFunctionResponse: Session attributes to include in the response prompt_session_attributes : dict[str, str] | None Prompt session attributes to include in the response - response_state : str | None + response_state : Literal["FAILURE", "REPROMPT"] | None Response state ("FAILURE" or "REPROMPT") Examples @@ -41,10 +45,10 @@ def __init__( session_attributes: dict[str, str] | None = None, prompt_session_attributes: dict[str, str] | None = None, knowledge_bases: list[dict[str, Any]] | None = None, - response_state: str | None = None, + response_state: Literal["FAILURE", "REPROMPT"] | None = None, ) -> None: if response_state is not None and response_state not in ["FAILURE", "REPROMPT"]: - raise ValueError("responseState must be None, 'FAILURE' or 'REPROMPT'") + raise ValueError("responseState must be 'FAILURE' or 'REPROMPT'") self.body = body self.session_attributes = session_attributes @@ -78,6 +82,8 @@ def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]: knowledge_bases = None response_state = None + # Per AWS Bedrock documentation, currently only "TEXT" is supported as the responseBody content type + # https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html response: dict[str, Any] = { "messageVersion": "1.0", "response": { @@ -147,12 +153,13 @@ def tool( """ def decorator(func: Callable) -> Callable: - if not description: - raise ValueError("Tool description is required") - function_name = name or func.__name__ if function_name in self._tools: - raise ValueError(f"Tool '{function_name}' already registered") + warnings.warn( + f"Tool '{function_name}' already registered. Overwriting with new definition.", + PowertoolsUserWarning, + stacklevel=2, + ) self._tools[function_name] = { "function": func, @@ -178,7 +185,20 @@ def _resolve(self) -> dict[str, Any]: function_name = self.current_event.function try: - result = self._tools[function_name]["function"]() + parameters = {} + if hasattr(self.current_event, "parameters"): + for param in self.current_event.parameters: + parameters[param.name] = param.value + + func = self._tools[function_name]["function"] + sig = inspect.signature(func) + + valid_params = {} + for name, value in parameters.items(): + if name in sig.parameters: + valid_params[name] = value + + result = func(**valid_params) return BedrockFunctionsResponseBuilder(result).build(self.current_event) except Exception as e: return BedrockFunctionsResponseBuilder( diff --git a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py index c608db172bd..9cfebc51d7e 100644 --- a/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py +++ b/tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py @@ -4,6 +4,7 @@ from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver, BedrockFunctionResponse from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent +from aws_lambda_powertools.warnings import PowertoolsUserWarning from tests.functional.utils import load_event @@ -59,22 +60,25 @@ def test_bedrock_agent_function_registration(): # GIVEN a Bedrock Agent Function resolver app = BedrockAgentFunctionResolver() - # WHEN registering without description or with duplicate name - with pytest.raises(ValueError, match="Tool description is required"): - - @app.tool() - def test_function(): - return "test" - + # WHEN registering with duplicate name @app.tool(name="custom", description="First registration") def first_function(): - return "test" + return "first test" - with pytest.raises(ValueError, match="Tool 'custom' already registered"): + # THEN a warning should be issued when registering a duplicate + with pytest.warns(PowertoolsUserWarning, match="Tool 'custom' already registered"): @app.tool(name="custom", description="Second registration") def second_function(): - return "test" + return "second test" + + # AND the most recent function should be registered + raw_event = load_event("bedrockAgentFunctionEvent.json") + raw_event["function"] = "custom" + result = app.resolve(raw_event, {}) + + # The second function should be used + assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "second test" def test_bedrock_agent_function_with_optional_fields(): @@ -156,7 +160,7 @@ def test_resolve_with_no_registered_function(): def test_bedrock_function_response_state_validation(): # GIVEN invalid and valid response states - valid_states = [None, "FAILURE", "REPROMPT"] + valid_states = ["FAILURE", "REPROMPT"] invalid_state = "INVALID" # WHEN creating responses with valid states @@ -172,4 +176,32 @@ def test_bedrock_function_response_state_validation(): with pytest.raises(ValueError) as exc_info: BedrockFunctionResponse(body="test", response_state=invalid_state) - assert str(exc_info.value) == "responseState must be None, 'FAILURE' or 'REPROMPT'" + assert str(exc_info.value) == "responseState must be 'FAILURE' or 'REPROMPT'" + + +def test_bedrock_agent_function_with_parameters(): + # GIVEN a Bedrock Agent Function resolver + app = BedrockAgentFunctionResolver() + + # Track received parameters + received_params = {} + + @app.tool(description="Function that accepts parameters") + def vacation_request(startDate, endDate): + # Store received parameters for assertion + received_params["startDate"] = startDate + received_params["endDate"] = endDate + return f"Vacation request from {startDate} to {endDate} submitted" + + # WHEN calling the event handler with parameters + raw_event = load_event("bedrockAgentFunctionEvent.json") + raw_event["function"] = "vacation_request" + result = app.resolve(raw_event, {}) + + # THEN parameters should be correctly passed to the function + assert received_params["startDate"] == "2024-03-15" + assert received_params["endDate"] == "2024-03-20" + assert ( + "Vacation request from 2024-03-15 to 2024-03-20 submitted" + in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] + )