diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py
index b716087d38b..2163d7d762e 100644
--- a/aws_lambda_powertools/event_handler/api_gateway.py
+++ b/aws_lambda_powertools/event_handler/api_gateway.py
@@ -9,19 +9,7 @@
from enum import Enum
from functools import partial
from http import HTTPStatus
-from typing import (
- Any,
- Callable,
- Dict,
- List,
- Match,
- Optional,
- Pattern,
- Set,
- Tuple,
- Type,
- Union,
-)
+from typing import Any, Callable, Dict, List, Match, Optional, Pattern, Set, Tuple, Type, Union
from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
@@ -218,13 +206,129 @@ def __init__(
cors: bool,
compress: bool,
cache_control: Optional[str],
+ middlewares: Optional[List[Callable[..., Response]]],
):
+ """
+
+ Parameters
+ ----------
+
+ method: str
+ The HTTP method, example "GET"
+ rule: Pattern
+ The route rule, example "/my/path"
+ func: Callable
+ The route handler function
+ cors: bool
+ Whether or not to enable CORS for this route
+ compress: bool
+ Whether or not to enable gzip compression for this route
+ cache_control: Optional[str]
+ The cache control header value, example "max-age=3600"
+ middlewares: Optional[List[Callable[..., Response]]]
+ The list of route middlewares to be called in order.
+ """
self.method = method.upper()
self.rule = rule
self.func = func
+ self._middleware_stack = func
self.cors = cors
self.compress = compress
self.cache_control = cache_control
+ self.middlewares = middlewares or []
+
+ # _middleware_stack_built is used to ensure the middleware stack is only built once.
+ self._middleware_stack_built = False
+
+ def __call__(
+ self,
+ router_middlewares: List[Callable],
+ app: "ApiGatewayResolver",
+ route_arguments: Dict[str, str],
+ ) -> Union[Dict, Tuple, Response]:
+ """Calling the Router class instance will trigger the following actions:
+ 1. If Route Middleware stack has not been built, build it
+ 2. Call the Route Middleware stack wrapping the original function
+ handler with the app and route arguments.
+
+ Parameters
+ ----------
+ router_middlewares: List[Callable]
+ The list of Router Middlewares (assigned to ALL routes)
+ app: "ApiGatewayResolver"
+ The ApiGatewayResolver instance to pass into the middleware stack
+ route_arguments: Dict[str, str]
+ The route arguments to pass to the app function (extracted from the Api Gateway
+ Lambda Message structure from AWS)
+
+ Returns
+ -------
+ Union[Dict, Tuple, Response]
+ API Response object in ALL cases, except when the original API route
+ handler is called which may also return a Dict, Tuple, or Response.
+ """
+
+ # Save CPU cycles by building middleware stack once
+ if not self._middleware_stack_built:
+ self._build_middleware_stack(router_middlewares=router_middlewares)
+
+ # If debug is turned on then output the middleware stack to the console
+ if app._debug:
+ print(f"\nProcessing Route:::{self.func.__name__} ({app.context['_path']})")
+ # Collect ALL middleware for debug printing - include internal _registered_api_adapter
+ all_middlewares = router_middlewares + self.middlewares + [_registered_api_adapter]
+ print("\nMiddleware Stack:")
+ print("=================")
+ print("\n".join(getattr(item, "__name__", "Unknown") for item in all_middlewares))
+ print("=================")
+
+ # Add Route Arguments to app context
+ app.append_context(_route_args=route_arguments)
+
+ # Call the Middleware Wrapped _call_stack function handler with the app
+ return self._middleware_stack(app)
+
+ def _build_middleware_stack(self, router_middlewares: List[Callable[..., Any]]) -> None:
+ """
+ Builds the middleware stack for the handler by wrapping each
+ handler in an instance of MiddlewareWrapper which is used to contain the state
+ of each middleware step.
+
+ Middleware is represented by a standard Python Callable construct. Any Middleware
+ handler wanting to short-circuit the middlware call chain can raise an exception
+ to force the Python call stack created by the handler call-chain to naturally un-wind.
+
+ This becomes a simple concept for developers to understand and reason with - no additional
+ gymanstics other than plain old try ... except.
+
+ Notes
+ -----
+ The Route Middleware stack is processed in reverse order. This is so the stack of
+ middleware handlers is applied in the order of being added to the handler.
+ """
+ all_middlewares = router_middlewares + self.middlewares
+ logger.debug(f"Building middleware stack: {all_middlewares}")
+
+ # IMPORTANT:
+ # this must be the last middleware in the stack (tech debt for backward
+ # compatibility purposes)
+ #
+ # This adapter will:
+ # 1. Call the registered API passing only the expected route arguments extracted from the path
+ # and not the middleware.
+ # 2. Adapt the response type of the route handler (Union[Dict, Tuple, Response])
+ # and normalise into a Response object so middleware will always have a constant signature
+ all_middlewares.append(_registered_api_adapter)
+
+ # Wrap the original route handler function in the middleware handlers
+ # using the MiddlewareWrapper class callable construct in reverse order to
+ # ensure middleware is applied in the order the user defined.
+ #
+ # Start with the route function and wrap from last to the first Middleware handler.
+ for handler in reversed(all_middlewares):
+ self._middleware_stack = MiddlewareFrame(current_middleware=handler, next_middleware=self._middleware_stack)
+
+ self._middleware_stack_built = True
class ResponseBuilder:
@@ -268,7 +372,11 @@ def _has_compression_enabled(
bool
True if compression is enabled and the "gzip" encoding is accepted, False otherwise.
"""
- encoding: str = event.get_header_value(name="accept-encoding", default_value="", case_sensitive=False) # type: ignore[assignment] # noqa: E501
+ encoding: str = event.get_header_value(
+ name="accept-encoding",
+ default_value="",
+ case_sensitive=False,
+ ) # noqa: E501
if "gzip" in encoding:
if response_compression is not None:
return response_compression # e.g., Response(compress=False/True))
@@ -322,6 +430,8 @@ class BaseRouter(ABC):
current_event: BaseProxyEvent
lambda_context: LambdaContext
context: dict
+ _router_middlewares: List[Callable] = []
+ processed_stack_frames: List[str] = []
@abstractmethod
def route(
@@ -331,10 +441,59 @@ def route(
cors: Optional[bool] = None,
compress: bool = False,
cache_control: Optional[str] = None,
+ middlewares: Optional[List[Callable[..., Any]]] = None,
):
raise NotImplementedError()
- def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
+ def use(self, middlewares: List[Callable[..., Response]]) -> None:
+ """
+ Add one or more global middlewares that run before/after route specific middleware.
+
+ NOTE: Middlewares are called in insertion order.
+
+ Parameters
+ ----------
+ middlewares: List[Callable[..., Response]]
+ List of global middlewares to be used
+
+ Examples
+ --------
+
+ Add middlewares to be used for every request processed by the Router.
+
+ ```python
+ from aws_lambda_powertools import Logger
+ from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
+ from aws_lambda_powertools.event_handler.middlewares import NextMiddleware
+
+ logger = Logger()
+ app = APIGatewayRestResolver()
+
+ def log_request_response(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
+ logger.info("Incoming request", path=app.current_event.path, request=app.current_event.raw_event)
+
+ result = next_middleware(app)
+ logger.info("Response received", response=result.__dict__)
+
+ return result
+
+ app.use(middlewares=[log_request_response])
+
+
+ def lambda_handler(event, context):
+ return app.resolve(event, context)
+ ```
+ """
+ self._router_middlewares = self._router_middlewares + middlewares
+
+ def get(
+ self,
+ rule: str,
+ cors: Optional[bool] = None,
+ compress: bool = False,
+ cache_control: Optional[str] = None,
+ middlewares: Optional[List[Callable[..., Any]]] = None,
+ ):
"""Get route decorator with GET `method`
Examples
@@ -357,9 +516,16 @@ def lambda_handler(event, context):
return app.resolve(event, context)
```
"""
- return self.route(rule, "GET", cors, compress, cache_control)
+ return self.route(rule, "GET", cors, compress, cache_control, middlewares)
- def post(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
+ def post(
+ self,
+ rule: str,
+ cors: Optional[bool] = None,
+ compress: bool = False,
+ cache_control: Optional[str] = None,
+ middlewares: Optional[List[Callable[..., Any]]] = None,
+ ):
"""Post route decorator with POST `method`
Examples
@@ -383,9 +549,16 @@ def lambda_handler(event, context):
return app.resolve(event, context)
```
"""
- return self.route(rule, "POST", cors, compress, cache_control)
+ return self.route(rule, "POST", cors, compress, cache_control, middlewares)
- def put(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
+ def put(
+ self,
+ rule: str,
+ cors: Optional[bool] = None,
+ compress: bool = False,
+ cache_control: Optional[str] = None,
+ middlewares: Optional[List[Callable[..., Any]]] = None,
+ ):
"""Put route decorator with PUT `method`
Examples
@@ -409,7 +582,7 @@ def lambda_handler(event, context):
return app.resolve(event, context)
```
"""
- return self.route(rule, "PUT", cors, compress, cache_control)
+ return self.route(rule, "PUT", cors, compress, cache_control, middlewares)
def delete(
self,
@@ -417,6 +590,7 @@ def delete(
cors: Optional[bool] = None,
compress: bool = False,
cache_control: Optional[str] = None,
+ middlewares: Optional[List[Callable[..., Any]]] = None,
):
"""Delete route decorator with DELETE `method`
@@ -440,7 +614,7 @@ def lambda_handler(event, context):
return app.resolve(event, context)
```
"""
- return self.route(rule, "DELETE", cors, compress, cache_control)
+ return self.route(rule, "DELETE", cors, compress, cache_control, middlewares)
def patch(
self,
@@ -448,6 +622,7 @@ def patch(
cors: Optional[bool] = None,
compress: bool = False,
cache_control: Optional[str] = None,
+ middlewares: Optional[List[Callable]] = None,
):
"""Patch route decorator with PATCH `method`
@@ -474,7 +649,19 @@ def lambda_handler(event, context):
return app.resolve(event, context)
```
"""
- return self.route(rule, "PATCH", cors, compress, cache_control)
+ return self.route(rule, "PATCH", cors, compress, cache_control, middlewares)
+
+ def _push_processed_stack_frame(self, frame: str):
+ """
+ Add Current Middleware to the Middleware Stack Frames
+ The stack frames will be used when exceptions are thrown and Powertools
+ debug is enabled by developers.
+ """
+ self.processed_stack_frames.append(frame)
+
+ def _reset_processed_stack(self):
+ """Reset the Processed Stack Frames"""
+ self.processed_stack_frames.clear()
def append_context(self, **additional_context):
"""Append key=value data as routing context"""
@@ -485,6 +672,109 @@ def clear_context(self):
self.context.clear()
+class MiddlewareFrame:
+ """
+ creates a Middle Stack Wrapper instance to be used as a "Frame" in the overall stack of
+ middleware functions. Each instance contains the current middleware and the next
+ middleware function to be called in the stack.
+
+ In this way the middleware stack is constructed in a recursive fashion, with each middleware
+ calling the next as a simple function call. The actual Python call-stack will contain
+ each MiddlewareStackWrapper "Frame", meaning any Middleware function can cause the
+ entire Middleware call chain to be exited early (short-circuited) by raising an exception
+ or by simply returning early with a custom Response. The decision to short-circuit the middleware
+ chain is at the user's discretion but instantly available due to the Wrapped nature of the
+ callable constructs in the Middleware stack and each Middleware function having complete control over
+ whether the "Next" handler in the stack is called or not.
+
+ Parameters
+ ----------
+ current_middleware : Callable
+ The current middleware function to be called as a request is processed.
+ next_middleware : Callable
+ The next middleware in the middleware stack.
+ """
+
+ def __init__(
+ self,
+ current_middleware: Callable[..., Any],
+ next_middleware: Callable[..., Any],
+ ) -> None:
+ self.current_middleware: Callable[..., Any] = current_middleware
+ self.next_middleware: Callable[..., Any] = next_middleware
+ self._next_middleware_name = next_middleware.__name__
+
+ @property
+ def __name__(self) -> str: # noqa: A003
+ """Current middleware name
+
+ It ensures backward compatibility with view functions being callable. This
+ improves debugging since we need both current and next middlewares/callable names.
+ """
+ return self.current_middleware.__name__
+
+ def __str__(self) -> str:
+ """Identify current middleware identity and call chain for debugging purposes."""
+ middleware_name = self.__name__
+ return f"[{middleware_name}] next call chain is {middleware_name} -> {self._next_middleware_name}"
+
+ def __call__(self, app: "ApiGatewayResolver") -> Union[Dict, Tuple, Response]:
+ """
+ Call the middleware Frame to process the request.
+
+ Parameters
+ ----------
+ app: BaseRouter
+ The router instance
+
+ Returns
+ -------
+ Union[Dict, Tuple, Response]
+ (tech-debt for backward compatibility). The response type should be a
+ Response object in all cases excepting when the original API route handler
+ is called which will return one of 3 outputs.
+
+ """
+ # Do debug printing and push processed stack frame AFTER calling middleware
+ # else the stack frame text of `current calling next` is confusing.
+ logger.debug("MiddlewareFrame: %s", self)
+ app._push_processed_stack_frame(str(self))
+
+ return self.current_middleware(app, self.next_middleware)
+
+
+def _registered_api_adapter(
+ app: "ApiGatewayResolver",
+ next_middleware: Callable[..., Any],
+) -> Union[Dict, Tuple, Response]:
+ """
+ Calls the registered API using the "_route_args" from the Resolver context to ensure the last call
+ in the chain will match the API route function signature and ensure that Powertools passes the API
+ route handler the expected arguments.
+
+ **IMPORTANT: This internal middleware ensures the actual API route is called with the correct call signature
+ and it MUST be the final frame in the middleware stack. This can only be removed when the API Route
+ function accepts `app: BaseRouter` as the first argument - which is the breaking change.
+
+ Parameters
+ ----------
+ app: ApiGatewayResolver
+ The API Gateway resolver
+ next_middleware: Callable[..., Any]
+ The function to handle the API
+
+ Returns
+ -------
+ Response
+ The API Response Object
+
+ """
+ route_args: Dict = app.context.get("_route_args", {})
+ logger.debug(f"Calling API Route Handler: {route_args}")
+
+ return app._to_response(next_middleware(**route_args))
+
+
class ApiGatewayResolver(BaseRouter):
"""API Gateway and ALB proxy resolver
@@ -550,6 +840,7 @@ def __init__(
self._debug = self._has_debug(debug)
self._strip_prefixes = strip_prefixes
self.context: Dict = {} # early init as customers might add context before event resolution
+ self.processed_stack_frames = []
# Allow for a custom serializer or a concise json serialization
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
@@ -561,6 +852,7 @@ def route(
cors: Optional[bool] = None,
compress: bool = False,
cache_control: Optional[str] = None,
+ middlewares: Optional[List[Callable[..., Any]]] = None,
):
"""Route decorator includes parameter `method`"""
@@ -573,7 +865,15 @@ def register_resolver(func: Callable):
cors_enabled = cors
for item in methods:
- _route = Route(item, self._compile_regex(rule), func, cors_enabled, compress, cache_control)
+ _route = Route(
+ item,
+ self._compile_regex(rule),
+ func,
+ cors_enabled,
+ compress,
+ cache_control,
+ middlewares,
+ )
# The more specific route wins.
# We store dynamic (/studies/{studyid}) and static routes (/studies/fetch) separately.
@@ -594,6 +894,7 @@ def register_resolver(func: Callable):
if cors_enabled:
logger.debug(f"Registering method {item.upper()} to Allow Methods in CORS")
self._cors_methods.add(item.upper())
+
return func
return register_resolver
@@ -628,7 +929,16 @@ def resolve(self, event, context) -> Dict[str, Any]:
BaseRouter.lambda_context = context
response = self._resolve().build(self.current_event, self._cors)
+
+ # Debug print Processed Middlewares
+ if self._debug:
+ print("\nProcessed Middlewares:")
+ print("======================")
+ print("\n".join(self.processed_stack_frames))
+ print("======================")
+
self.clear_context()
+
return response
def __call__(self, event, context) -> Any:
@@ -703,6 +1013,9 @@ def _resolve(self) -> ResponseBuilder:
match_results: Optional[Match] = route.rule.match(path)
if match_results:
logger.debug("Found a registered route. Calling function")
+ # Add matched Route reference into the Resolver context
+ self.append_context(_route=route, _path=path)
+
return self._call_route(route, match_results.groupdict()) # pass fn args
logger.debug(f"No match found for path {path} and method {method}")
@@ -765,15 +1078,25 @@ def _not_found(self, method: str) -> ResponseBuilder:
),
)
- def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
+ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> ResponseBuilder:
"""Actually call the matching route with any provided keyword arguments."""
try:
- return ResponseBuilder(self._to_response(route.func(**args)), route)
+ # Reset Processed stack for Middleware (for debugging purposes)
+ self._reset_processed_stack()
+
+ return ResponseBuilder(
+ self._to_response(
+ route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments),
+ ),
+ route,
+ )
except Exception as exc:
+ # If exception is handled then return the response builder to reduce noise
response_builder = self._call_exception_handler(exc, route)
if response_builder:
return response_builder
+ logger.exception(exc)
if self._debug:
# If the user has turned on debug mode,
# we'll let the original exception propagate so
@@ -874,8 +1197,12 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None
# Add reference to parent ApiGatewayResolver to support use cases where people subclass it to add custom logic
router.api_resolver = self
- # Merge app and router context
+ logger.debug("Merging App context with Router context")
self.context.update(**router.context)
+
+ logger.debug("Appending Router middlewares into App middlewares.")
+ self._router_middlewares = self._router_middlewares + router._router_middlewares
+
# use pointer to allow context clearance after event is processed e.g., resolve(evt, ctx)
router.context = self.context
@@ -887,7 +1214,15 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None
rule = prefix if rule == "/" else f"{prefix}{rule}"
new_route = (rule, *route[1:])
- self.route(*new_route)(func)
+ # Middlewares are stored by route separately - must grab them to include
+ middlewares = router._routes_with_middleware.get(new_route)
+
+ # Need to use "type: ignore" here since mypy does not like a named parameter after
+ # tuple expansion since may cause duplicate named parameters in the function signature.
+ # In this case this is not possible since the tuple expansion is from a hashable source
+ # and the `middlewares` List is a non-hashable structure so will never be included.
+ # Still need to ignore for mypy checks or will cause failures (false-positive)
+ self.route(*new_route, middlewares=middlewares)(func) # type: ignore
class Router(BaseRouter):
@@ -895,6 +1230,7 @@ class Router(BaseRouter):
def __init__(self):
self._routes: Dict[tuple, Callable] = {}
+ self._routes_with_middleware: Dict[tuple, List[Callable]] = {}
self.api_resolver: Optional[BaseRouter] = None
self.context = {} # early init as customers might add context before event resolution
@@ -905,11 +1241,26 @@ def route(
cors: Optional[bool] = None,
compress: bool = False,
cache_control: Optional[str] = None,
+ middlewares: Optional[List[Callable[..., Any]]] = None,
):
def register_route(func: Callable):
# Convert methods to tuple. It needs to be hashable as its part of the self._routes dict key
methods = (method,) if isinstance(method, str) else tuple(method)
- self._routes[(rule, methods, cors, compress, cache_control)] = func
+
+ route_key = (rule, methods, cors, compress, cache_control)
+
+ # Collate Middleware for routes
+ if middlewares is not None:
+ for handler in middlewares:
+ if self._routes_with_middleware.get(route_key) is None:
+ self._routes_with_middleware[route_key] = [handler]
+ else:
+ self._routes_with_middleware[route_key].append(handler)
+ else:
+ self._routes_with_middleware[route_key] = []
+
+ self._routes[route_key] = func
+
return func
return register_route
@@ -936,9 +1287,10 @@ def route(
cors: Optional[bool] = None,
compress: bool = False,
cache_control: Optional[str] = None,
+ middlewares: Optional[List[Callable[..., Any]]] = None,
):
# NOTE: see #1552 for more context.
- return super().route(rule.rstrip("/"), method, cors, compress, cache_control)
+ return super().route(rule.rstrip("/"), method, cors, compress, cache_control, middlewares)
# Override _compile_regex to exclude trailing slashes for route resolution
@staticmethod
diff --git a/aws_lambda_powertools/event_handler/middlewares/__init__.py b/aws_lambda_powertools/event_handler/middlewares/__init__.py
new file mode 100644
index 00000000000..068ce9c04b7
--- /dev/null
+++ b/aws_lambda_powertools/event_handler/middlewares/__init__.py
@@ -0,0 +1,3 @@
+from aws_lambda_powertools.event_handler.middlewares.base import BaseMiddlewareHandler, NextMiddleware
+
+__all__ = ["BaseMiddlewareHandler", "NextMiddleware"]
diff --git a/aws_lambda_powertools/event_handler/middlewares/base.py b/aws_lambda_powertools/event_handler/middlewares/base.py
new file mode 100644
index 00000000000..32a4486bb31
--- /dev/null
+++ b/aws_lambda_powertools/event_handler/middlewares/base.py
@@ -0,0 +1,122 @@
+from abc import ABC, abstractmethod
+from typing import Generic
+
+from typing_extensions import Protocol
+
+from aws_lambda_powertools.event_handler.api_gateway import Response
+from aws_lambda_powertools.event_handler.types import EventHandlerInstance
+
+
+class NextMiddleware(Protocol):
+ def __call__(self, app: EventHandlerInstance) -> Response:
+ """Protocol for callback regardless of next_middleware(app), get_response(app) etc"""
+ ...
+
+ def __name__(self) -> str: # noqa A003
+ """Protocol for name of the Middleware"""
+ ...
+
+
+class BaseMiddlewareHandler(Generic[EventHandlerInstance], ABC):
+ """Base implementation for Middlewares to run code before and after in a chain.
+
+
+ This is the middleware handler function where middleware logic is implemented.
+ The next middleware handler is represented by `next_middleware`, returning a Response object.
+
+ Examples
+ --------
+
+ **Correlation ID Middleware**
+
+ ```python
+ import requests
+
+ from aws_lambda_powertools import Logger
+ from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
+ from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
+
+ app = APIGatewayRestResolver()
+ logger = Logger()
+
+
+ class CorrelationIdMiddleware(BaseMiddlewareHandler):
+ def __init__(self, header: str):
+ super().__init__()
+ self.header = header
+
+ def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
+ # BEFORE logic
+ request_id = app.current_event.request_context.request_id
+ correlation_id = app.current_event.get_header_value(
+ name=self.header,
+ default_value=request_id,
+ )
+
+ # Call next middleware or route handler ('/todos')
+ response = next_middleware(app)
+
+ # AFTER logic
+ response.headers[self.header] = correlation_id
+
+ return response
+
+
+ @app.get("/todos", middlewares=[CorrelationIdMiddleware(header="x-correlation-id")])
+ def get_todos():
+ todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos")
+ todos.raise_for_status()
+
+ # for brevity, we'll limit to the first 10 only
+ return {"todos": todos.json()[:10]}
+
+
+ @logger.inject_lambda_context
+ def lambda_handler(event, context):
+ return app.resolve(event, context)
+
+ ```
+
+ """
+
+ @abstractmethod
+ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
+ """
+ The Middleware Handler
+
+ Parameters
+ ----------
+ app: EventHandlerInstance
+ An instance of an Event Handler that implements ApiGatewayResolver
+ next_middleware: NextMiddleware
+ The next middleware handler in the chain
+
+ Returns
+ -------
+ Response
+ The response from the next middleware handler in the chain
+
+ """
+ raise NotImplementedError()
+
+ @property
+ def __name__(self) -> str: # noqa A003
+ return str(self.__class__.__name__)
+
+ def __call__(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
+ """
+ The Middleware handler function.
+
+ Parameters
+ ----------
+ app: ApiGatewayResolver
+ An instance of an Event Handler that implements ApiGatewayResolver
+ next_middleware: NextMiddleware
+ The next middleware handler in the chain
+
+ Returns
+ -------
+ Response
+ The response from the next middleware handler in the chain
+ """
+ return self.handler(app, next_middleware)
diff --git a/aws_lambda_powertools/event_handler/middlewares/schema_validation.py b/aws_lambda_powertools/event_handler/middlewares/schema_validation.py
new file mode 100644
index 00000000000..66be47a48f3
--- /dev/null
+++ b/aws_lambda_powertools/event_handler/middlewares/schema_validation.py
@@ -0,0 +1,124 @@
+import logging
+from typing import Dict, Optional
+
+from aws_lambda_powertools.event_handler.api_gateway import Response
+from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError
+from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
+from aws_lambda_powertools.event_handler.types import EventHandlerInstance
+from aws_lambda_powertools.utilities.validation import validate
+from aws_lambda_powertools.utilities.validation.exceptions import InvalidSchemaFormatError, SchemaValidationError
+
+logger = logging.getLogger(__name__)
+
+
+class SchemaValidationMiddleware(BaseMiddlewareHandler):
+ """Middleware to validate API request and response against JSON Schema using the [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/).
+
+ Examples
+ --------
+ **Validating incoming event**
+
+ ```python
+ import requests
+
+ from aws_lambda_powertools import Logger
+ from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
+ from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
+ from aws_lambda_powertools.event_handler.middlewares.schema_validation import SchemaValidationMiddleware
+
+ app = APIGatewayRestResolver()
+ logger = Logger()
+ json_schema_validation = SchemaValidationMiddleware(inbound_schema=INCOMING_JSON_SCHEMA)
+
+
+ @app.get("/todos", middlewares=[json_schema_validation])
+ def get_todos():
+ todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos")
+ todos.raise_for_status()
+
+ # for brevity, we'll limit to the first 10 only
+ return {"todos": todos.json()[:10]}
+
+
+ @logger.inject_lambda_context
+ def lambda_handler(event, context):
+ return app.resolve(event, context)
+ ```
+ """
+
+ def __init__(
+ self,
+ inbound_schema: Dict,
+ inbound_formats: Optional[Dict] = None,
+ outbound_schema: Optional[Dict] = None,
+ outbound_formats: Optional[Dict] = None,
+ ):
+ """See [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/) docs for examples on all parameters.
+
+ Parameters
+ ----------
+ inbound_schema : Dict
+ JSON Schema to validate incoming event
+ inbound_formats : Optional[Dict], optional
+ Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
+ JSON Schema to validate outbound event, by default None
+ outbound_formats : Optional[Dict], optional
+ Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
+ """ # noqa: E501
+ super().__init__()
+ self.inbound_schema = inbound_schema
+ self.inbound_formats = inbound_formats
+ self.outbound_schema = outbound_schema
+ self.outbound_formats = outbound_formats
+
+ def bad_response(self, error: SchemaValidationError) -> Response:
+ message: str = f"Bad Response: {error.message}"
+ logger.debug(message)
+ raise BadRequestError(message)
+
+ def bad_request(self, error: SchemaValidationError) -> Response:
+ message: str = f"Bad Request: {error.message}"
+ logger.debug(message)
+ raise BadRequestError(message)
+
+ def bad_config(self, error: InvalidSchemaFormatError) -> Response:
+ logger.debug(f"Invalid Schema Format: {error}")
+ raise InternalServerError("Internal Server Error")
+
+ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
+ """Validates incoming JSON payload (body) against JSON Schema provided.
+
+ Parameters
+ ----------
+ app : EventHandlerInstance
+ An instance of an Event Handler
+ next_middleware : NextMiddleware
+ Callable to get response from the next middleware or route handler in the chain
+
+ Returns
+ -------
+ Response
+ It can return three types of response objects
+
+ - Original response: Propagates HTTP response returned from the next middleware if validation succeeds
+ - HTTP 400: Payload or response failed JSON Schema validation
+ - HTTP 500: JSON Schema provided has incorrect format
+ """
+ try:
+ validate(event=app.current_event.json_body, schema=self.inbound_schema, formats=self.inbound_formats)
+ except SchemaValidationError as error:
+ return self.bad_request(error)
+ except InvalidSchemaFormatError as error:
+ return self.bad_config(error)
+
+ result = next_middleware(app)
+
+ if self.outbound_formats is not None:
+ try:
+ validate(event=result.body, schema=self.inbound_schema, formats=self.inbound_formats)
+ except SchemaValidationError as error:
+ return self.bad_response(error)
+ except InvalidSchemaFormatError as error:
+ return self.bad_config(error)
+
+ return result
diff --git a/aws_lambda_powertools/event_handler/types.py b/aws_lambda_powertools/event_handler/types.py
new file mode 100644
index 00000000000..11cd146a57a
--- /dev/null
+++ b/aws_lambda_powertools/event_handler/types.py
@@ -0,0 +1,5 @@
+from typing import TypeVar
+
+from aws_lambda_powertools.event_handler import ApiGatewayResolver
+
+EventHandlerInstance = TypeVar("EventHandlerInstance", bound=ApiGatewayResolver)
diff --git a/aws_lambda_powertools/shared/types.py b/aws_lambda_powertools/shared/types.py
index e4e10192e55..b29c04cbe6b 100644
--- a/aws_lambda_powertools/shared/types.py
+++ b/aws_lambda_powertools/shared/types.py
@@ -1,5 +1,14 @@
+import sys
from typing import Any, Callable, Dict, List, TypeVar, Union
AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001
# JSON primitives only, mypy doesn't support recursive tho
JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]
+
+
+if sys.version_info >= (3, 8):
+ from typing import Protocol
+else:
+ from typing_extensions import Protocol
+
+__all__ = ["Protocol"]
diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py
index 7a3fc8ab404..fa7c5296042 100644
--- a/aws_lambda_powertools/utilities/data_classes/common.py
+++ b/aws_lambda_powertools/utilities/data_classes/common.py
@@ -1,7 +1,7 @@
import base64
import json
from collections.abc import Mapping
-from typing import Any, Callable, Dict, Iterator, List, Optional
+from typing import Any, Callable, Dict, Iterator, List, Optional, overload
from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
@@ -156,7 +156,24 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None)
default_value=default_value,
)
- # Maintenance: missing @overload to ensure return type is a str when default_value is set
+ @overload
+ def get_header_value(
+ self,
+ name: str,
+ default_value: str,
+ case_sensitive: Optional[bool] = False,
+ ) -> str:
+ ...
+
+ @overload
+ def get_header_value(
+ self,
+ name: str,
+ default_value: Optional[str] = None,
+ case_sensitive: Optional[bool] = False,
+ ) -> Optional[str]:
+ ...
+
def get_header_value(
self,
name: str,
diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
index ffa9cb263ab..35194f1f3f0 100644
--- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
+++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, overload
from aws_lambda_powertools.shared.headers_serializer import (
BaseHeadersSerializer,
@@ -91,6 +91,24 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None)
default_value=default_value,
)
+ @overload
+ def get_header_value(
+ self,
+ name: str,
+ default_value: str,
+ case_sensitive: Optional[bool] = False,
+ ) -> str:
+ ...
+
+ @overload
+ def get_header_value(
+ self,
+ name: str,
+ default_value: Optional[str] = None,
+ case_sensitive: Optional[bool] = False,
+ ) -> Optional[str]:
+ ...
+
def get_header_value(
self,
name: str,
diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md
index dcfa38f6f9a..dd249ec6650 100644
--- a/docs/core/event_handler/api_gateway.md
+++ b/docs/core/event_handler/api_gateway.md
@@ -10,6 +10,7 @@ Event handler for Amazon API Gateway REST and HTTP APIs, Application Loader Bala
* Lightweight routing to reduce boilerplate for API Gateway REST/HTTP API, ALB and Lambda Function URLs.
* Support for CORS, binary and Gzip compression, Decimals JSON encoding and bring your own JSON serializer
* Built-in integration with [Event Source Data Classes utilities](../../utilities/data_classes.md){target="_blank"} for self-documented event schema
+* Works with micro function (one or a few routes) and monolithic functions (all routes)
## Getting started
@@ -353,14 +354,226 @@ For convenience, these are the default values when using `CORSConfig` to enable
???+ tip "Multiple origins?"
If you need to allow multiple origins, pass the additional origins using the `extra_origins` key.
-| Key | Value | Note |
-| -------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| **[allow_origin](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank" rel="nofollow"}**: `str` | `*` | Only use the default value for development. **Never use `*` for production** unless your use case requires it |
-| **[extra_origins](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank" rel="nofollow"}**: `List[str]` | `[]` | Additional origins to be allowed, in addition to the one specified in `allow_origin` |
-| **[allow_headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers){target="_blank" rel="nofollow"}**: `List[str]` | `[Authorization, Content-Type, X-Amz-Date, X-Api-Key, X-Amz-Security-Token]` | Additional headers will be appended to the default list for your convenience |
+| Key | Value | Note |
+| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| **[allow_origin](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank" rel="nofollow"}**: `str` | `*` | Only use the default value for development. **Never use `*` for production** unless your use case requires it |
+| **[extra_origins](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank" rel="nofollow"}**: `List[str]` | `[]` | Additional origins to be allowed, in addition to the one specified in `allow_origin` |
+| **[allow_headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers){target="_blank" rel="nofollow"}**: `List[str]` | `[Authorization, Content-Type, X-Amz-Date, X-Api-Key, X-Amz-Security-Token]` | Additional headers will be appended to the default list for your convenience |
| **[expose_headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers){target="_blank" rel="nofollow"}**: `List[str]` | `[]` | Any additional header beyond the [safe listed by CORS specification](https://developer.mozilla.org/en-US/docs/Glossary/CORS-safelisted_response_header){target="_blank" rel="nofollow"}. |
-| **[max_age](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age){target="_blank" rel="nofollow"}**: `int` | `` | Only for pre-flight requests if you choose to have your function to handle it instead of API Gateway |
-| **[allow_credentials](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials){target="_blank" rel="nofollow"}**: `bool` | `False` | Only necessary when you need to expose cookies, authorization headers or TLS client certificates. |
+| **[max_age](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age){target="_blank" rel="nofollow"}**: `int` | `` | Only for pre-flight requests if you choose to have your function to handle it instead of API Gateway |
+| **[allow_credentials](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials){target="_blank" rel="nofollow"}**: `bool` | `False` | Only necessary when you need to expose cookies, authorization headers or TLS client certificates. |
+
+### Middleware
+
+```mermaid
+stateDiagram
+ direction LR
+
+ EventHandler: GET /todo
+ Before: Before response
+ Next: next_middleware()
+ MiddlewareLoop: Middleware loop
+ AfterResponse: After response
+ MiddlewareFinished: Modified response
+ Response: Final response
+
+ EventHandler --> Middleware: Has middleware?
+ state MiddlewareLoop {
+ direction LR
+ Middleware --> Before
+ Before --> Next
+ Next --> Middleware: More middlewares?
+ Next --> AfterResponse
+ }
+ AfterResponse --> MiddlewareFinished
+ MiddlewareFinished --> Response
+ EventHandler --> Response: No middleware
+```
+
+A middleware is a function you register per route to **intercept** or **enrich** a **request before** or **after** any response.
+
+Each middleware function receives the following arguments:
+
+1. **app**. An Event Handler instance so you can access incoming request information, Lambda context, etc.
+2. **next_middleware**. A function to get the next middleware or route's response.
+
+Here's a sample middleware that extracts and injects correlation ID, using `APIGatewayRestResolver` (works for any [Resolver](#event-resolvers)):
+
+=== "middleware_getting_started.py"
+
+ ```python hl_lines="11 22 29" title="Your first middleware to extract and inject correlation ID"
+ --8<-- "examples/event_handler_rest/src/middleware_getting_started.py"
+ ```
+
+ 1. You can access current request like you normally would.
+ 2. [Shared context is available](#sharing-contextual-data) to any middleware, Router and App instances.
+ 3. Get response from the next middleware (if any) or from `/todos` route.
+ 4. You can manipulate headers, body, or status code before returning it.
+ 5. Register one or more middlewares in order of execution.
+
+=== "middleware_getting_started_output.json"
+
+ ```json hl_lines="9-10"
+ --8<-- "examples/event_handler_rest/src/middleware_getting_started_output.json"
+ ```
+
+#### Global middlewares
+
+
+
+You can use `app.use` to register middlewares that should always run regardless of the route, also known as global middlewares.
+
+Event Handler **calls global middlewares first**, then middlewares defined at the route level. Here's an example with both middlewares:
+
+=== "middleware_global_middlewares.py"
+
+ > Use [debug mode](#debug-mode) if you need to log request/response.
+
+ ```python hl_lines="10"
+ --8<-- "examples/event_handler_rest/src/middleware_global_middlewares.py"
+ ```
+
+ 1. A separate file where our middlewares are to keep this example focused.
+ 2. We register `log_request_response` as a global middleware to run before middleware.
+ ```mermaid
+ stateDiagram
+ direction LR
+
+ GlobalMiddleware: Log request response
+ RouteMiddleware: Inject correlation ID
+ EventHandler: Event Handler
+
+ EventHandler --> GlobalMiddleware
+ GlobalMiddleware --> RouteMiddleware
+ ```
+
+=== "middleware_global_middlewares_module.py"
+
+ ```python hl_lines="8"
+ --8<-- "examples/event_handler_rest/src/middleware_global_middlewares_module.py"
+ ```
+
+#### Returning early
+
+
+
+Imagine you want to stop processing a request if something is missing, or return immediately if you've seen this request before.
+
+In these scenarios, you short-circuit the middleware processing logic by returning a [Response object](#fine-grained-responses), or raising a [HTTP Error](#raising-http-errors). This signals to Event Handler to stop and run each `After` logic left in the chain all the way back.
+
+Here's an example where we prevent any request that doesn't include a correlation ID header:
+
+=== "middleware_early_return.py"
+
+ ```python hl_lines="12"
+ --8<-- "examples/event_handler_rest/src/middleware_early_return.py"
+ ```
+
+ 1. This middleware will raise an exception if correlation ID header is missing.
+ 2. This code section will not run if `enforce_correlation_id` returns early.
+
+=== "middleware_global_middlewares_module.py"
+
+ ```python hl_lines="35 38"
+ --8<-- "examples/event_handler_rest/src/middleware_global_middlewares_module.py"
+ ```
+
+ 1. Raising an exception OR returning a Response object early will short-circuit the middleware chain.
+
+=== "middleware_early_return_output.json"
+
+ ```python hl_lines="2-3"
+ --8<-- "examples/event_handler_rest/src/middleware_early_return_output.json"
+ ```
+
+#### Handling exceptions
+
+!!! tip "For catching exceptions more broadly, we recommend you use the [exception_handler](#exception-handling) decorator."
+
+By default, any unhandled exception in the middleware chain is eventually propagated as a HTTP 500 back to the client.
+
+While there isn't anything special on how to use [`try/catch`](https://docs.python.org/3/tutorial/errors.html#handling-exceptions){target="_blank" rel="nofollow"} for middlewares, it is important to visualize how Event Handler deals with them under the following scenarios:
+
+=== "Unhandled exception from route handler"
+
+ An exception wasn't caught by any middleware during `next_middleware()` block, therefore it propagates all the way back to the client as HTTP 500.
+
+
+ 
+ 
+
+ _Unhandled route exceptions propagate back to the client_
+
+
+=== "Route handler exception caught by a middleware"
+
+ An exception was only caught by the third middleware, resuming the normal execution of each `After` logic for the second and first middleware.
+
+
+ 
+ 
+
+ _Unhandled route exceptions propagate back to the client_
+
+
+=== "Middleware short-circuit by raising exception"
+
+ The third middleware short-circuited the chain by raising an exception and completely skipping the fourth middleware. Because we only caught it in the first middleware, it skipped the `After` logic in the second middleware.
+
+