From 2b9559864b772757f1f4a0f7f0b4f08272829f3b Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 10:23:37 +0100 Subject: [PATCH 01/45] initial streamable http server --- .../servers/simple-streamablehttp/README.md | 33 ++ .../mcp_simple_streamablehttp/__init__.py | 0 .../mcp_simple_streamablehttp/__main__.py | 4 + .../mcp_simple_streamablehttp/server.py | 167 +++++++ .../simple-streamablehttp/pyproject.toml | 47 ++ src/mcp/server/session.py | 18 +- src/mcp/server/streamableHttp.py | 415 ++++++++++++++++++ src/mcp/shared/session.py | 9 +- src/mcp/types.py | 1 + uv.lock | 41 +- 10 files changed, 727 insertions(+), 8 deletions(-) create mode 100644 examples/servers/simple-streamablehttp/README.md create mode 100644 examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__init__.py create mode 100644 examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py create mode 100644 examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py create mode 100644 examples/servers/simple-streamablehttp/pyproject.toml create mode 100644 src/mcp/server/streamableHttp.py diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md new file mode 100644 index 00000000..aa5e707a --- /dev/null +++ b/examples/servers/simple-streamablehttp/README.md @@ -0,0 +1,33 @@ +# MCP Simple StreamableHttp Server Example + +A simple MCP server example demonstrating the StreamableHttp transport, which enables HTTP-based communication with MCP servers using streaming. + +## Features + +- Uses the StreamableHTTP transport for server-client communication +- Task management with anyio task groups +- Ability to send multiple notifications over time to the client +- Proper resource cleanup and lifespan management + +## Usage + +Start the server on the default or custom port: + +```bash + +# Using custom port +uv run mcp-simple-streamablehttp --port 3000 + +# Custom logging level +uv run mcp-simple-streamablehttp --log-level DEBUG +``` + +The server exposes a tool named "start-notification-stream" that accepts three arguments: + +- `interval`: Time between notifications in seconds (e.g., 1.0) +- `count`: Number of notifications to send (e.g., 5) +- `caller`: Identifier string for the caller + +## Client + +You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use (Inspector)[https://github.com/modelcontextprotocol/inspector] \ No newline at end of file diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__init__.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py new file mode 100644 index 00000000..a6876bf9 --- /dev/null +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py @@ -0,0 +1,4 @@ +from .server import main + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py new file mode 100644 index 00000000..19a83790 --- /dev/null +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -0,0 +1,167 @@ +import contextlib +import logging +from uuid import uuid4 + +import anyio +import click +import mcp.types as types +from mcp.server.lowlevel import Server +from mcp.server.streamableHttp import StreamableHTTPServerTransport +from starlette.applications import Starlette +from starlette.routing import Mount + +# Configure logging +logger = logging.getLogger(__name__) + +# Global task group that will be initialized in the lifespan +task_group = None + + +@contextlib.asynccontextmanager +async def lifespan(app): + """Application lifespan context manager for managing task group.""" + global task_group + + async with anyio.create_task_group() as tg: + task_group = tg + logger.info("Application started, task group initialized!") + try: + yield + finally: + logger.info("Application shutting down, cleaning up resources...") + if task_group: + tg.cancel_scope.cancel() + task_group = None + logger.info("Resources cleaned up successfully.") + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on for HTTP") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", +) +def main( + port: int, + log_level: str, +) -> int: + # Configure logging + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + app = Server("mcp-streamable-http-demo") + + @app.call_tool() + async def call_tool( + name: str, arguments: dict + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ctx = app.request_context + interval = arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + await ctx.session.send_log_message( + level="info", + data=f"Notification {i+1}/{count} from caller: {caller}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + return [ + types.TextContent( + type="text", + text=( + f"Sent {count} notifications with {interval}s interval" + f" for caller: {caller}" + ), + ) + ] + + @app.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="start-notification-stream", + description=( + "Sends a stream of notifications with configurable count" + " and interval" + ), + inputSchema={ + "type": "object", + "required": ["interval", "count", "caller"], + "properties": { + "interval": { + "type": "number", + "description": "Interval between notifications in seconds", + }, + "count": { + "type": "number", + "description": "Number of notifications to send", + }, + "caller": { + "type": "string", + "description": ( + "Identifier of the caller to include in notifications" + ), + }, + }, + }, + ) + ] + + # Create a Streamable HTTP transport + http_transport = StreamableHTTPServerTransport( + mcp_session_id=uuid4().hex, + ) + + # We need to store the server instances between requests + server_instances = {} + + # ASGI handler for streamable HTTP connections + async def handle_streamable_http(scope, receive, send): + if http_transport.mcp_session_id in server_instances: + logger.debug("Session already exists, handling request directly") + await http_transport.handle_request(scope, receive, send) + else: + # Start new server instance for this session + async with http_transport.connect() as streams: + read_stream, write_stream = streams + + async def run_server(): + await app.run( + read_stream, write_stream, app.create_initialization_options() + ) + + if not task_group: + raise RuntimeError("Task group is not initialized") + + task_group.start_soon(run_server) + + # For initialization requests, store the server reference + if http_transport.mcp_session_id: + server_instances[http_transport.mcp_session_id] = True + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + + # Create an ASGI application using the transport + starlette_app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=handle_streamable_http), + ], + lifespan=lifespan, + ) + + import uvicorn + + uvicorn.run(starlette_app, host="0.0.0.0", port=port) + + return 0 diff --git a/examples/servers/simple-streamablehttp/pyproject.toml b/examples/servers/simple-streamablehttp/pyproject.toml new file mode 100644 index 00000000..de43bd2f --- /dev/null +++ b/examples/servers/simple-streamablehttp/pyproject.toml @@ -0,0 +1,47 @@ +[project] +name = "mcp-simple-streamablehttp" +version = "0.1.0" +description = "A simple MCP server exposing a website fetching tool with StreamableHttp transport" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +maintainers = [ + { name = "David Soria Parra", email = "davidsp@anthropic.com" }, + { name = "Justin Spahr-Summers", email = "justin@anthropic.com" }, +] +keywords = ["mcp", "llm", "automation", "web", "fetch", "http", "streamable"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["anyio>=4.5", "click>=8.1.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-simple-streamablehttp = "mcp_simple_streamablehttp.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_streamablehttp"] + +[tool.pyright] +include = ["mcp_simple_streamablehttp"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 88 +target-version = "py310" + +[tool.uv] +dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] \ No newline at end of file diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 568ecd4b..3a1f210d 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -179,7 +179,11 @@ async def _received_notification( ) async def send_log_message( - self, level: types.LoggingLevel, data: Any, logger: str | None = None + self, + level: types.LoggingLevel, + data: Any, + logger: str | None = None, + related_request_id: types.RequestId | None = None, ) -> None: """Send a log message notification.""" await self.send_notification( @@ -192,7 +196,8 @@ async def send_log_message( logger=logger, ), ) - ) + ), + related_request_id, ) async def send_resource_updated(self, uri: AnyUrl) -> None: @@ -261,7 +266,11 @@ async def send_ping(self) -> types.EmptyResult: ) async def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None + self, + progress_token: str | int, + progress: float, + total: float | None = None, + related_request_id: str | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( @@ -274,7 +283,8 @@ async def send_progress_notification( total=total, ), ) - ) + ), + related_request_id, ) async def send_resource_list_changed(self) -> None: diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py new file mode 100644 index 00000000..cfce6629 --- /dev/null +++ b/src/mcp/server/streamableHttp.py @@ -0,0 +1,415 @@ +""" +StreamableHTTP Server Transport Module + +This module implements an HTTP transport layer with Streamable HTTP. + +The transport handles bidirectional communication using HTTP requests and +responses, with streaming support for long-running operations. +""" + +import json +import logging +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Any + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError +from sse_starlette import EventSourceResponse +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from mcp.types import ( + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, +) + +logger = logging.getLogger(__name__) + +# Maximum size for incoming messages +MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB + + +class StreamableHTTPServerTransport: + """ + HTTP server transport with event streaming support for MCP. + + Handles POST requests containing JSON-RPC messages and provides + Server-Sent Events (SSE) responses for streaming communication. + """ + + # Server notification streams for POST requests as well as standalone SSE stream + _read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] | None + _write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + # Dictionary to track request-specific message streams + _request_streams: dict[str, MemoryObjectSendStream[JSONRPCMessage]] + + def __init__( + self, + mcp_session_id: str | None, + ): + """ + Initialize a new StreamableHTTP server transport. + + Args: + mcp_session_id: Optional session identifier for this connection + """ + self.mcp_session_id = mcp_session_id + self._request_streams = {} + + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + ASGI application entry point that handles all HTTP requests + + Args: + stream_id: Unique identifier for this stream + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function + """ + request = Request(scope, receive) + + if request.method == "POST": + await self._handle_post_request(scope, request, receive, send) + elif request.method == "GET": + await self._handle_get_request(request, send) + elif request.method == "DELETE": + await self._handle_delete_request(request, send) + else: + await self._handle_unsupported_request(send) + + async def _handle_post_request( + self, scope: Scope, request: Request, receive: Receive, send: Send + ) -> None: + """ + Handles POST requests containing JSON-RPC messages + + Args: + stream_id: Unique identifier for this stream + scope: ASGI scope + request: Starlette Request object + receive: ASGI receive function + send: ASGI send function + """ + body = await request.body() + writer = self._read_stream_writer + if writer is None: + raise ValueError( + "No read stream writer available. Ensure connect() is called first." + ) + return + try: + # Validate Accept header + accept_header = request.headers.get("accept", "") + if ( + "application/json" not in accept_header + or "text/event-stream" not in accept_header + ): + response = Response( + ( + "Not Acceptable: Client must accept both application/json and " + "text/event-stream" + ), + status_code=406, + headers={"Content-Type": "application/json"}, + ) + await response(scope, receive, send) + return + + # Validate Content-Type + content_type = request.headers.get("content-type", "") + if "application/json" not in content_type: + response = Response( + "Unsupported Media Type: Content-Type must be application/json", + status_code=415, + headers={"Content-Type": "application/json"}, + ) + await response(scope, receive, send) + return + + # Parse the body + body = await request.body() + if len(body) > MAXIMUM_MESSAGE_SIZE: + response = Response( + "Payload Too Large: Message exceeds maximum size", + status_code=413, + headers={"Content-Type": "application/json"}, + ) + await response(scope, receive, send) + return + + try: + raw_message = json.loads(body) + except json.JSONDecodeError as e: + response = Response( + f"Parse error: {str(e)}", + status_code=400, + headers={"Content-Type": "application/json"}, + ) + await response(scope, receive, send) + return + message = None + try: + message = JSONRPCMessage.model_validate(raw_message) + except ValidationError as e: + response = Response( + f"Validation error: {str(e)}", + status_code=400, + headers={"Content-Type": "application/json"}, + ) + await response(scope, receive, send) + return + if not message: + response = Response( + "Invalid Request: Message is empty", + status_code=400, + headers={"Content-Type": "application/json"}, + ) + await response(scope, receive, send) + return + + # Check if this is an initialization request + is_initialization_request = ( + isinstance(message.root, JSONRPCRequest) + and message.root.method == "initialize" + ) + + if is_initialization_request: + # TODO validate + logger.info("INITIALIZATION REQUEST") + # For non-initialization requests, validate the session + elif not await self._validate_session(request, send): + return + + is_request = isinstance(message.root, JSONRPCRequest) + + # For notifications and responses only, return 202 Accepted + if not is_request: + headers: dict[str, str] = {} + if self.mcp_session_id: + headers["mcp-session-id"] = self.mcp_session_id + + # Create response object and send it + response = Response("Accepted", status_code=202, headers=headers) + await response(scope, receive, send) + + # Process the message after sending the response + await writer.send(message) + + return + + # For requests, set up an SSE stream for the response + if is_request: + # Set up headers + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + } + + if self.mcp_session_id: + headers["mcp-session-id"] = self.mcp_session_id + + # For SSE responses, set up SSE stream + headers["Content-Type"] = "text/event-stream" + # Create SSE stream + sse_stream_writer, sse_stream_reader = ( + anyio.create_memory_object_stream[dict[str, Any]](0) + ) + + async def sse_writer(): + try: + # Create a request-specific message stream for this POST request + request_stream_writer, request_stream_reader = ( + anyio.create_memory_object_stream[JSONRPCMessage](0) + ) + + # Get the request ID from the incoming request message + request_id = None + if isinstance(message.root, JSONRPCRequest): + request_id = str(message.root.id) + # Register this stream for the request ID + if request_id: + self._request_streams[request_id] = ( + request_stream_writer + ) + + async with sse_stream_writer, request_stream_reader: + # Process messages from the request-specific stream + async for received_message in request_stream_reader: + # Send the message via SSE + related_request_id = None + + if isinstance( + received_message.root, JSONRPCNotification + ): + # Get related_request_id from params + params = received_message.root.params + if params and "related_request_id" in params: + related_request_id = params.get( + "related_request_id" + ) + logger.debug( + f"NOTIFICATION: {related_request_id}, " + f"{params.get('data')}" + ) + + # Build the event data + event_data = { + "event": "message", + "data": received_message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + await sse_stream_writer.send(event_data) + + # If response, remove from pending streams and close + if isinstance(received_message.root, JSONRPCResponse): + if request_id: + self._request_streams.pop(request_id, None) + break + except Exception as e: + logger.exception(f"Error in SSE writer: {e}") + finally: + logger.debug("Closing SSE writer") + # TODO + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=sse_writer, + headers=headers, + ) + + # Extract the request ID outside the try block for proper scope + outer_request_id = None + if isinstance(message.root, JSONRPCRequest): + outer_request_id = str(message.root.id) + + # Start the SSE response (this will send headers immediately) + try: + # First send the response to establish the SSE connection + async with anyio.create_task_group() as tg: + tg.start_soon(response, scope, receive, send) + + # Then send the message to be processed by the server + await writer.send(message) + except Exception: + logger.exception("SSE response error") + # Make sure to clean up the request stream if something goes wrong + if outer_request_id and outer_request_id in self._request_streams: + self._request_streams.pop(outer_request_id, None) + + except Exception as err: + logger.exception("Error handling POST request") + response = Response(f"Error handling POST request: {err}", status_code=500) + await response(scope, receive, send) + if writer: + await writer.send(err) + return + + async def _handle_get_request(self, request: Request, send: Send) -> None: + pass + + async def _handle_delete_request(self, request: Request, send: Send) -> None: + pass + + async def _handle_unsupported_request(self, send: Send) -> None: + pass + + async def _validate_session(self, request: Request, send: Send) -> bool: + # TODO + return True + + @asynccontextmanager + async def connect( + self, + ) -> AsyncGenerator[ + tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ], + None, + ]: + """ + Context manager that provides read and write streams for a connection + + Yields: + Tuple of (read_stream, write_stream) for bidirectional communication + """ + + # Create the memory streams for this connection + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + + write_stream: MemoryObjectSendStream[JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + JSONRPCMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + JSONRPCMessage + ](0) + + # Store the streams + self._read_stream_writer = read_stream_writer + self._write_stream_reader = write_stream_reader + + # Start a task group for message routing + async with anyio.create_task_group() as tg: + # Create a message router that distributes messages to request streams + async def message_router(): + try: + async for message in write_stream_reader: + # Determine which request stream(s) should receive this message + target_request_id = None + + # For responses, route based on the request ID + if isinstance(message.root, JSONRPCResponse): + target_request_id = str(message.root.id) + # For notifications, route by related_request_id if available + elif isinstance(message.root, JSONRPCNotification): + # Get related_request_id from params + params = message.root.params + if params and "related_request_id" in params: + related_id = params.get("related_request_id") + if related_id is not None: + target_request_id = str(related_id) + + # Send to the specific request stream if available + if ( + target_request_id + and target_request_id in self._request_streams + ): + try: + await self._request_streams[target_request_id].send( + message + ) + except ( + anyio.BrokenResourceError, + anyio.ClosedResourceError, + ): + # Stream might be closed, remove from registry + self._request_streams.pop(target_request_id, None) + except Exception as e: + logger.exception(f"Error in message router: {e}") + + # Start the message router + tg.start_soon(message_router) + + try: + # Yield the streams for the caller to use + yield read_stream, write_stream + finally: + # Clean up any remaining request streams + for stream in list(self._request_streams.values()): + try: + await stream.aclose() + except Exception: + pass + self._request_streams.clear() diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 05fd3ce3..1017bb98 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -267,16 +267,21 @@ async def send_request( else: return result_type.model_validate(response_or_error.result) - async def send_notification(self, notification: SendNotificationT) -> None: + async def send_notification( + self, + notification: SendNotificationT, + related_request_id: RequestId | None = None, + ) -> None: """ Emits a notification, which is a one-way message that does not expect a response. """ + if related_request_id is not None and notification.root.params is not None: + notification.root.params.related_request_id = related_request_id jsonrpc_notification = JSONRPCNotification( jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) async def _send_response( diff --git a/src/mcp/types.py b/src/mcp/types.py index bd71d51f..30500e31 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -58,6 +58,7 @@ class Meta(BaseModel): model_config = ConfigDict(extra="allow") meta: Meta | None = Field(alias="_meta", default=None) + related_request_id: RequestId | None = None """ This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications. diff --git a/uv.lock b/uv.lock index 78f46f47..65439e5c 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.10" [options] @@ -10,6 +9,7 @@ members = [ "mcp", "mcp-simple-prompt", "mcp-simple-resource", + "mcp-simple-streamablehttp", "mcp-simple-tool", ] @@ -487,6 +487,7 @@ wheels = [ [[package]] name = "mcp" +version = "1.6.1.dev12+70115b9" source = { editable = "." } dependencies = [ { name = "anyio" }, @@ -543,7 +544,6 @@ requires-dist = [ { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] -provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ @@ -628,6 +628,43 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mcp-simple-streamablehttp" +version = "0.1.0" +source = { editable = "examples/servers/simple-streamablehttp" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.1.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-tool" version = "0.1.0" From 3d790f8979bfd43d505151e024433b533376946b Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 12:17:33 +0100 Subject: [PATCH 02/45] add request validation and tests --- src/mcp/server/streamableHttp.py | 270 ++++++++++++++++---- tests/server/test_streamableHttp.py | 378 ++++++++++++++++++++++++++++ 2 files changed, 601 insertions(+), 47 deletions(-) create mode 100644 tests/server/test_streamableHttp.py diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index cfce6629..e65c6c46 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -11,6 +11,7 @@ import logging from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from http import HTTPStatus from typing import Any import anyio @@ -33,6 +34,14 @@ # Maximum size for incoming messages MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB +# Header names +MCP_SESSION_ID_HEADER = "mcp-session-id" +LAST_EVENT_ID_HEADER = "last-event-id" + +# Content types +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_SSE = "text/event-stream" + class StreamableHTTPServerTransport: """ @@ -61,6 +70,34 @@ def __init__( self.mcp_session_id = mcp_session_id self._request_streams = {} + def _create_error_response( + self, + message: str, + status_code: HTTPStatus, + headers: dict[str, str] | None = None, + ) -> Response: + """ + Create a standardized error response. + """ + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + return Response( + message, + status_code=status_code, + headers=response_headers, + ) + + def _get_session_id(self, request: Request) -> str | None: + """ + Extract the session ID from request headers. + """ + return request.headers.get(MCP_SESSION_ID_HEADER) + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """ ASGI application entry point that handles all HTTP requests @@ -80,7 +117,46 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No elif request.method == "DELETE": await self._handle_delete_request(request, send) else: - await self._handle_unsupported_request(send) + await self._handle_unsupported_request(request, send) + + def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: + """ + Check if the request accepts the required media types. + + Args: + request: The HTTP request + + Returns: + Tuple of (has_json, has_sse) indicating whether each media type is accepted + """ + accept_header = request.headers.get("accept", "") + accept_types = [media_type.strip() for media_type in accept_header.split(",")] + + has_json = any( + media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types + ) + has_sse = any( + media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types + ) + + return has_json, has_sse + + def _check_content_type(self, request: Request) -> bool: + """ + Check if the request has the correct Content-Type. + + Args: + request: The HTTP request + + Returns: + True if Content-Type is acceptable, False otherwise + """ + content_type = request.headers.get("content-type", "") + content_type_parts = [ + part.strip() for part in content_type.split(";")[0].split(",") + ] + + return any(part == CONTENT_TYPE_JSON for part in content_type_parts) async def _handle_post_request( self, scope: Scope, request: Request, receive: Receive, send: Send @@ -89,13 +165,11 @@ async def _handle_post_request( Handles POST requests containing JSON-RPC messages Args: - stream_id: Unique identifier for this stream scope: ASGI scope request: Starlette Request object receive: ASGI receive function send: ASGI send function """ - body = await request.body() writer = self._read_stream_writer if writer is None: raise ValueError( @@ -103,41 +177,34 @@ async def _handle_post_request( ) return try: - # Validate Accept header - accept_header = request.headers.get("accept", "") - if ( - "application/json" not in accept_header - or "text/event-stream" not in accept_header - ): - response = Response( + # Check Accept headers + has_json, has_sse = self._check_accept_headers(request) + if not (has_json and has_sse): + response = self._create_error_response( ( "Not Acceptable: Client must accept both application/json and " "text/event-stream" ), - status_code=406, - headers={"Content-Type": "application/json"}, + HTTPStatus.NOT_ACCEPTABLE, ) await response(scope, receive, send) return # Validate Content-Type - content_type = request.headers.get("content-type", "") - if "application/json" not in content_type: - response = Response( + if not self._check_content_type(request): + response = self._create_error_response( "Unsupported Media Type: Content-Type must be application/json", - status_code=415, - headers={"Content-Type": "application/json"}, + HTTPStatus.UNSUPPORTED_MEDIA_TYPE, ) await response(scope, receive, send) return - # Parse the body + # Parse the body - only read it once body = await request.body() if len(body) > MAXIMUM_MESSAGE_SIZE: - response = Response( + response = self._create_error_response( "Payload Too Large: Message exceeds maximum size", - status_code=413, - headers={"Content-Type": "application/json"}, + HTTPStatus.REQUEST_ENTITY_TOO_LARGE, ) await response(scope, receive, send) return @@ -145,29 +212,28 @@ async def _handle_post_request( try: raw_message = json.loads(body) except json.JSONDecodeError as e: - response = Response( + response = self._create_error_response( f"Parse error: {str(e)}", - status_code=400, - headers={"Content-Type": "application/json"}, + HTTPStatus.BAD_REQUEST, ) await response(scope, receive, send) return + message = None try: message = JSONRPCMessage.model_validate(raw_message) except ValidationError as e: - response = Response( + response = self._create_error_response( f"Validation error: {str(e)}", - status_code=400, - headers={"Content-Type": "application/json"}, + HTTPStatus.BAD_REQUEST, ) await response(scope, receive, send) return + if not message: - response = Response( + response = self._create_error_response( "Invalid Request: Message is empty", - status_code=400, - headers={"Content-Type": "application/json"}, + HTTPStatus.BAD_REQUEST, ) await response(scope, receive, send) return @@ -179,8 +245,19 @@ async def _handle_post_request( ) if is_initialization_request: - # TODO validate - logger.info("INITIALIZATION REQUEST") + # Check if the server already has an established session + if self.mcp_session_id: + # Check if request has a session ID + request_session_id = self._get_session_id(request) + + # If request has a session ID but doesn't match, return 404 + if request_session_id and request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return # For non-initialization requests, validate the session elif not await self._validate_session(request, send): return @@ -189,12 +266,11 @@ async def _handle_post_request( # For notifications and responses only, return 202 Accepted if not is_request: - headers: dict[str, str] = {} - if self.mcp_session_id: - headers["mcp-session-id"] = self.mcp_session_id - # Create response object and send it - response = Response("Accepted", status_code=202, headers=headers) + response = self._create_error_response( + "Accepted", + HTTPStatus.ACCEPTED, + ) await response(scope, receive, send) # Process the message after sending the response @@ -208,13 +284,11 @@ async def _handle_post_request( headers = { "Cache-Control": "no-cache, no-transform", "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, } if self.mcp_session_id: - headers["mcp-session-id"] = self.mcp_session_id - - # For SSE responses, set up SSE stream - headers["Content-Type"] = "text/event-stream" + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Create SSE stream sse_stream_writer, sse_stream_reader = ( anyio.create_memory_object_stream[dict[str, Any]](0) @@ -306,23 +380,125 @@ async def sse_writer(): except Exception as err: logger.exception("Error handling POST request") - response = Response(f"Error handling POST request: {err}", status_code=500) + response = self._create_error_response( + f"Error handling POST request: {err}", + HTTPStatus.INTERNAL_SERVER_ERROR, + ) await response(scope, receive, send) if writer: await writer.send(err) return async def _handle_get_request(self, request: Request, send: Send) -> None: - pass + """ + Handle GET requests for SSE stream establishment + + Args: + request: The HTTP request + send: ASGI send function + """ + # Validate session ID if server has one + if not await self._validate_session(request, send): + return + + # Validate Accept header - must include text/event-stream + _, has_sse = self._check_accept_headers(request) + + if not has_sse: + response = self._create_error_response( + "Not Acceptable: Client must accept text/event-stream", + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(request.scope, request.receive, send) + return + + # TODO: Implement SSE stream for GET requests + # For now, return 501 Not Implemented + response = self._create_error_response( + "SSE stream from GET request not implemented yet", + HTTPStatus.NOT_IMPLEMENTED, + ) + await response(request.scope, request.receive, send) async def _handle_delete_request(self, request: Request, send: Send) -> None: - pass + """ + Handle DELETE requests for explicit session termination + + Args: + request: The HTTP request + send: ASGI send function + """ + # Validate session ID + if not self.mcp_session_id: + # If no session ID set, return Method Not Allowed + response = self._create_error_response( + "Method Not Allowed: Session termination not supported", + HTTPStatus.METHOD_NOT_ALLOWED, + ) + await response(request.scope, request.receive, send) + return + if not await self._validate_session(request, send): + return + # TODO : Implement session termination logic - async def _handle_unsupported_request(self, send: Send) -> None: - pass + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: + """ + Handle unsupported HTTP methods + + Args: + request: The HTTP request + send: ASGI send function + """ + headers = { + "Content-Type": CONTENT_TYPE_JSON, + "Allow": "GET, POST, DELETE", + } + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + response = Response( + "Method Not Allowed", + status_code=HTTPStatus.METHOD_NOT_ALLOWED, + headers=headers, + ) + await response(request.scope, request.receive, send) async def _validate_session(self, request: Request, send: Send) -> bool: - # TODO + """ + Validate the session ID in the request. + + Args: + request: The HTTP request + send: ASGI send function + + Returns: + bool: True if session is valid, False otherwise + """ + if not self.mcp_session_id: + # If we're not using session IDs, return True + return True + + # Get the session ID from the request headers + request_session_id = self._get_session_id(request) + + # If no session ID provided but required, return error + if not request_session_id: + response = self._create_error_response( + "Bad Request: Missing session ID", + HTTPStatus.BAD_REQUEST, + ) + await response(request.scope, request.receive, send) + return False + + # If session ID doesn't match, return error + if request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(request.scope, request.receive, send) + return False + return True @asynccontextmanager diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py new file mode 100644 index 00000000..6296f22c --- /dev/null +++ b/tests/server/test_streamableHttp.py @@ -0,0 +1,378 @@ +""" +Tests for the StreamableHTTP server transport validation. + +This file contains tests for request validation in the StreamableHTTP transport. +""" + +import socket +import time +from collections.abc import Generator +from multiprocessing import Process + +import anyio +import pytest +import requests +import uvicorn +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Route + +from mcp.server.streamableHttp import ( + MCP_SESSION_ID_HEADER, + StreamableHTTPServerTransport, +) +from mcp.types import JSONRPCMessage + +# Test constants +SERVER_NAME = "test_streamable_http_server" +TEST_SESSION_ID = "test-session-id-12345" + + +# App handler class for testing validation (not a pytest test class) +class StreamableAppHandler: + def __init__(self, session_id=None): + self.transport = StreamableHTTPServerTransport(mcp_session_id=session_id) + self.started = False + self.read_stream = None + self.write_stream = None + + async def startup(self): + """Initialize the transport streams.""" + # Create real memory streams to satisfy type checking + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + JSONRPCMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + JSONRPCMessage + ](0) + + # Assign the streams to the transport + self.transport._read_stream_writer = read_stream_writer + self.transport._write_stream_reader = write_stream_reader + + # Store the streams so they don't get garbage collected + self.read_stream = read_stream + self.write_stream = write_stream + + self.started = True + print("Transport streams initialized") + + async def handle_request(self, request: Request): + """Handle incoming requests by validating and responding.""" + # Make sure transport is initialized + if not self.started: + await self.startup() + + # Let the transport handle the request validation and response + try: + await self.transport.handle_request( + request.scope, request.receive, request._send + ) + except Exception as e: + print(f"Error handling request: {e}") + # Make sure we provide an error response + response = Response( + status_code=500, + content=f"Server error: {str(e)}", + media_type="text/plain", + ) + await response(request.scope, request.receive, request._send) + + +@pytest.fixture +def server_port() -> int: + """Find an available port for the test server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def server_url(server_port: int) -> str: + """Get the URL for the test server.""" + return f"http://127.0.0.1:{server_port}" + + +def create_app(session_id=None) -> Starlette: + """Create a Starlette application for testing.""" + # Create our test app handler + app_handler = StreamableAppHandler(session_id=session_id) + + # Define a startup event to ensure the transport is initialized + async def on_startup(): + """Initialize the transport on application startup.""" + print("Initializing transport streams...") + await app_handler.startup() + app_handler.started = True + print("Transport initialized") + + app = Starlette( + debug=True, # Enable debug mode for better error messages + routes=[ + Route( + "/mcp", + endpoint=app_handler.handle_request, + methods=["GET", "POST", "DELETE"], + ), + ], + on_startup=[on_startup], + ) + + return app + + +def run_server(port: int, session_id=None) -> None: + """Run the test server.""" + print(f"Starting test server on port {port} with session_id={session_id}") + + # Create app with simpler configuration + app = create_app(session_id) + + # Configure to use a single worker and simpler settings + config = uvicorn.Config( + app=app, + host="127.0.0.1", + port=port, + log_level="info", # Use info to see startup messages + limit_concurrency=10, + timeout_keep_alive=2, + access_log=False, + ) + + # Start the server + server = uvicorn.Server(config=config) + + # This is important to catch exceptions and prevent test hangs + try: + print("Server starting...") + server.run() + except Exception as e: + print(f"ERROR: Server failed to run: {e}") + import traceback + + traceback.print_exc() + + print("Server shutdown") + + +@pytest.fixture +def basic_server(server_port: int) -> Generator[None, None, None]: + """Start a basic server without session ID.""" + # Start server process + process = Process(target=run_server, kwargs={"port": server_port}, daemon=True) + process.start() + + # Wait for server to start + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + # Clean up + process.terminate() + process.join(timeout=1) + if process.is_alive(): + process.kill() + + +@pytest.fixture +def session_server(server_port: int) -> Generator[str, None, None]: + """Start a server with session ID.""" + # Start server process + process = Process( + target=run_server, + kwargs={"port": server_port, "session_id": TEST_SESSION_ID}, + daemon=True, + ) + process.start() + + # Wait for server to start + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield TEST_SESSION_ID + + # Clean up + process.terminate() + process.join(timeout=1) + if process.is_alive(): + process.kill() + + +# Basic request validation tests +def test_accept_header_validation(basic_server, server_url): + """Test that Accept header is properly validated.""" + # Test without Accept header + response = requests.post( + f"{server_url}/mcp", + headers={"Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + # Test with only application/json + response = requests.post( + f"{server_url}/mcp", + headers={"Accept": "application/json", "Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + + # Test with only text/event-stream + response = requests.post( + f"{server_url}/mcp", + headers={"Accept": "text/event-stream", "Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + + +def test_content_type_validation(basic_server, server_url): + """Test that Content-Type header is properly validated.""" + # Test with incorrect Content-Type + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "text/plain", + }, + data="This is not JSON", + ) + assert response.status_code == 415 + assert "Unsupported Media Type" in response.text + + +def test_json_validation(basic_server, server_url): + """Test that JSON content is properly validated.""" + # Test with invalid JSON + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + data="this is not valid json", + ) + assert response.status_code == 400 + assert "Parse error" in response.text + + # Test with valid JSON but invalid JSON-RPC + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"foo": "bar"}, + ) + assert response.status_code == 400 + assert "Validation error" in response.text + + +def test_method_not_allowed(basic_server, server_url): + """Test that unsupported HTTP methods are rejected.""" + # Test with unsupported method (PUT) + response = requests.put( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 405 + assert "Method Not Allowed" in response.text + + +def test_get_request_validation(basic_server, server_url): + """Test GET request validation for SSE streams.""" + # Test GET without Accept header + response = requests.get(f"{server_url}/mcp") + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + # Test GET with wrong Accept header + response = requests.get( + f"{server_url}/mcp", + headers={"Accept": "application/json"}, + ) + assert response.status_code == 406 + + +def test_session_validation(session_server, server_url): + """Test session ID validation.""" + # session_id not used directly in this test + + # Test without session ID + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, + ) + assert response.status_code == 400 + assert "Missing session ID" in response.text + + # Test with invalid session ID + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: "invalid-session-id", + }, + json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, + ) + assert response.status_code == 404 + assert "Invalid or expired session ID" in response.text + + +def test_delete_request(session_server, server_url): + """Test DELETE request for session termination.""" + # session_id not used directly in this test + + # Test without session ID + response = requests.delete(f"{server_url}/mcp") + assert response.status_code == 400 + assert "Missing session ID" in response.text + + # Test with invalid session ID + response = requests.delete( + f"{server_url}/mcp", + headers={MCP_SESSION_ID_HEADER: "invalid-session-id"}, + ) + assert response.status_code == 404 + assert "Invalid or expired session ID" in response.text + + +def test_delete_without_session_support(basic_server, server_url): + """Test DELETE request when server doesn't support sessions.""" + # Server without session support should reject DELETE + response = requests.delete(f"{server_url}/mcp") + assert response.status_code == 405 + assert "Method Not Allowed" in response.text From 27bc01ec4bb63f398316ed7648dbc99108e0176f Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 14:30:35 +0100 Subject: [PATCH 03/45] session management --- .../mcp_simple_streamablehttp/server.py | 78 ++++++++++++------- src/mcp/server/streamableHttp.py | 21 ++++- tests/server/test_streamableHttp.py | 60 +++++++++++++- 3 files changed, 128 insertions(+), 31 deletions(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 19a83790..3dc972b7 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -1,13 +1,19 @@ import contextlib import logging +from http import HTTPStatus from uuid import uuid4 import anyio import click import mcp.types as types from mcp.server.lowlevel import Server -from mcp.server.streamableHttp import StreamableHTTPServerTransport +from mcp.server.streamableHttp import ( + MCP_SESSION_ID_HEADER, + StreamableHTTPServerTransport, +) from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response from starlette.routing import Mount # Configure logging @@ -116,40 +122,56 @@ async def list_tools() -> list[types.Tool]: ) ] - # Create a Streamable HTTP transport - http_transport = StreamableHTTPServerTransport( - mcp_session_id=uuid4().hex, - ) - # We need to store the server instances between requests server_instances = {} + # Lock to prevent race conditions when creating new sessions + session_creation_lock = anyio.Lock() # ASGI handler for streamable HTTP connections async def handle_streamable_http(scope, receive, send): - if http_transport.mcp_session_id in server_instances: + request = Request(scope, receive) + request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + if ( + request_mcp_session_id is not None + and request_mcp_session_id in server_instances + ): + transport = server_instances[request_mcp_session_id] logger.debug("Session already exists, handling request directly") - await http_transport.handle_request(scope, receive, send) + await transport.handle_request(scope, receive, send) + elif request_mcp_session_id is None: + # try to establish new session + logger.debug("Creating new transport") + # Use lock to prevent race conditions when creating new sessions + async with session_creation_lock: + new_session_id = uuid4().hex + http_transport = StreamableHTTPServerTransport( + mcp_session_id=new_session_id, + ) + async with http_transport.connect() as streams: + read_stream, write_stream = streams + + async def run_server(): + await app.run( + read_stream, + write_stream, + app.create_initialization_options(), + ) + + if not task_group: + raise RuntimeError("Task group is not initialized") + + # Store the instance before starting the task to prevent races + server_instances[http_transport.mcp_session_id] = http_transport + task_group.start_soon(run_server) + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) else: - # Start new server instance for this session - async with http_transport.connect() as streams: - read_stream, write_stream = streams - - async def run_server(): - await app.run( - read_stream, write_stream, app.create_initialization_options() - ) - - if not task_group: - raise RuntimeError("Task group is not initialized") - - task_group.start_soon(run_server) - - # For initialization requests, store the server reference - if http_transport.mcp_session_id: - server_instances[http_transport.mcp_session_id] = True - - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) + response = Response( + "Bad Request: No valid session ID provided", + status_code=HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) # Create an ASGI application using the transport starlette_app = Starlette( diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index e65c6c46..8b1498d7 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -9,6 +9,7 @@ import json import logging +import re from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from http import HTTPStatus @@ -42,6 +43,10 @@ CONTENT_TYPE_JSON = "application/json" CONTENT_TYPE_SSE = "text/event-stream" +# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) +# Pattern ensures entire string contains only valid characters by using ^ and $ anchors +SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") + class StreamableHTTPServerTransport: """ @@ -65,8 +70,20 @@ def __init__( Initialize a new StreamableHTTP server transport. Args: - mcp_session_id: Optional session identifier for this connection + mcp_session_id: Optional session identifier for this connection. + Must contain only visible ASCII characters (0x21-0x7E). + + Raises: + ValueError: If the session ID contains invalid characters. """ + if mcp_session_id is not None and ( + not SESSION_ID_PATTERN.match(mcp_session_id) or + SESSION_ID_PATTERN.fullmatch(mcp_session_id) is None + ): + raise ValueError( + "Session ID must only contain visible ASCII characters (0x21-0x7E)" + ) + self.mcp_session_id = mcp_session_id self._request_streams = {} @@ -439,7 +456,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: return if not await self._validate_session(request, send): return - # TODO : Implement session termination logic + # TODO : Implement session termination logic async def _handle_unsupported_request(self, request: Request, send: Send) -> None: """ diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index 6296f22c..bf0128d1 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -13,7 +13,6 @@ import pytest import requests import uvicorn -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -21,6 +20,7 @@ from mcp.server.streamableHttp import ( MCP_SESSION_ID_HEADER, + SESSION_ID_PATTERN, StreamableHTTPServerTransport, ) from mcp.types import JSONRPCMessage @@ -376,3 +376,61 @@ def test_delete_without_session_support(basic_server, server_url): response = requests.delete(f"{server_url}/mcp") assert response.status_code == 405 assert "Method Not Allowed" in response.text + + +def test_session_id_pattern(): + """Test that SESSION_ID_PATTERN correctly validates session IDs.""" + # Valid session IDs (visible ASCII characters from 0x21 to 0x7E) + valid_session_ids = [ + "test-session-id", + "1234567890", + "session!@#$%^&*()_+-=[]{}|;:,.<>?/", + "~`", + ] + + for session_id in valid_session_ids: + assert SESSION_ID_PATTERN.match(session_id) is not None + # Ensure fullmatch matches too (whole string) + assert SESSION_ID_PATTERN.fullmatch(session_id) is not None + + # Invalid session IDs + invalid_session_ids = [ + "", # Empty string + " test", # Space (0x20) + "test\t", # Tab + "test\n", # Newline + "test\r", # Carriage return + "test" + chr(0x7F), # DEL character + "test" + chr(0x80), # Extended ASCII + "test" + chr(0x00), # Null character + "test" + chr(0x20), # Space (0x20) + ] + + for session_id in invalid_session_ids: + # For invalid IDs, either match will fail or fullmatch will fail + if SESSION_ID_PATTERN.match(session_id) is not None: + # If match succeeds, fullmatch should fail (partial match case) + assert SESSION_ID_PATTERN.fullmatch(session_id) is None + + +def test_streamable_http_transport_init_validation(): + """Test that StreamableHTTPServerTransport validates session ID on initialization.""" + # Valid session ID should initialize without errors + valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") + assert valid_transport.mcp_session_id == "valid-id" + + # None should be accepted + none_transport = StreamableHTTPServerTransport(mcp_session_id=None) + assert none_transport.mcp_session_id is None + + # Invalid session ID should raise ValueError + with pytest.raises(ValueError) as excinfo: + StreamableHTTPServerTransport(mcp_session_id="invalid id with space") + assert "Session ID must only contain visible ASCII characters" in str(excinfo.value) + + # Test with control characters + with pytest.raises(ValueError): + StreamableHTTPServerTransport(mcp_session_id="test\nid") + + with pytest.raises(ValueError): + StreamableHTTPServerTransport(mcp_session_id="test\n") From 3c4cf109c2534306105ed7d656bcfe5eacd0d2c0 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 17:10:02 +0100 Subject: [PATCH 04/45] terminations of a session --- src/mcp/server/streamableHttp.py | 62 ++++- tests/server/test_streamableHttp.py | 351 ++++++++++++++++++---------- 2 files changed, 281 insertions(+), 132 deletions(-) diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 8b1498d7..0dd73e50 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -77,8 +77,8 @@ def __init__( ValueError: If the session ID contains invalid characters. """ if mcp_session_id is not None and ( - not SESSION_ID_PATTERN.match(mcp_session_id) or - SESSION_ID_PATTERN.fullmatch(mcp_session_id) is None + not SESSION_ID_PATTERN.match(mcp_session_id) + or SESSION_ID_PATTERN.fullmatch(mcp_session_id) is None ): raise ValueError( "Session ID must only contain visible ASCII characters (0x21-0x7E)" @@ -86,6 +86,7 @@ def __init__( self.mcp_session_id = mcp_session_id self._request_streams = {} + self._terminated = False def _create_error_response( self, @@ -126,6 +127,14 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No send: ASGI send function """ request = Request(scope, receive) + if self._terminated: + # If the session has been terminated, return 404 Not Found + response = self._create_error_response( + "Not Found: Session has been terminated", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return if request.method == "POST": await self._handle_post_request(scope, request, receive, send) @@ -192,7 +201,6 @@ async def _handle_post_request( raise ValueError( "No read stream writer available. Ensure connect() is called first." ) - return try: # Check Accept headers has_json, has_sse = self._check_accept_headers(request) @@ -417,7 +425,6 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Validate session ID if server has one if not await self._validate_session(request, send): return - # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) @@ -454,9 +461,46 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: ) await response(request.scope, request.receive, send) return + if not await self._validate_session(request, send): return - # TODO : Implement session termination logic + + # Terminate the session + self._terminate_session() + + # Return success response + response = self._create_error_response( + "Session terminated", + HTTPStatus.OK, + ) + await response(request.scope, request.receive, send) + + def _terminate_session(self) -> None: + """ + Terminate the current session, closing all streams and marking as terminated. + + Once terminated, all requests with this session ID will receive 404 Not Found. + """ + + self._terminated = True + logger.info(f"Terminating session: {self.mcp_session_id}") + + # We need a copy of the keys to avoid modification during iteration + request_stream_keys = list(self._request_streams.keys()) + + # Close all request streams (synchronously) + for key in request_stream_keys: + try: + # Get the stream + stream = self._request_streams.get(key) + if stream: + # We must use close() here, not aclose() since this is a sync method + stream.close() + except Exception as e: + logger.debug(f"Error closing stream {key} during termination: {e}") + + # Clear the request streams dictionary immediately + self._request_streams.clear() async def _handle_unsupported_request(self, request: Request, send: Send) -> None: """ @@ -599,10 +643,16 @@ async def message_router(): # Yield the streams for the caller to use yield read_stream, write_stream finally: - # Clean up any remaining request streams for stream in list(self._request_streams.values()): try: await stream.aclose() except Exception: pass self._request_streams.clear() + # Clean up read/write streams + if self._read_stream_writer: + try: + await self._read_stream_writer.aclose() + except Exception: + pass + self._read_stream_writer = None diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index bf0128d1..eb7a5390 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -4,81 +4,87 @@ This file contains tests for request validation in the StreamableHTTP transport. """ +import multiprocessing import socket import time -from collections.abc import Generator -from multiprocessing import Process - +from collections.abc import AsyncGenerator, Generator +from http import HTTPStatus +from uuid import uuid4 +import contextlib import anyio import pytest import requests import uvicorn +from pydantic import AnyUrl from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response -from starlette.routing import Route +from starlette.routing import Mount, Route +from mcp.server import Server from mcp.server.streamableHttp import ( MCP_SESSION_ID_HEADER, SESSION_ID_PATTERN, StreamableHTTPServerTransport, ) -from mcp.types import JSONRPCMessage +from mcp.shared.exceptions import McpError +from mcp.types import ( + EmptyResult, + ErrorData, + JSONRPCMessage, + TextContent, + TextResourceContents, + Tool, +) # Test constants SERVER_NAME = "test_streamable_http_server" TEST_SESSION_ID = "test-session-id-12345" +INIT_REQUEST = { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": "2025-03-26", + "capabilities": {}, + }, + "id": "init-1", +} + + +# Test server implementation that follows MCP protocol +class ServerTest(Server): + def __init__(self): + super().__init__(SERVER_NAME) + + @self.read_resource() + async def handle_read_resource(uri: AnyUrl) -> str | bytes: + if uri.scheme == "foobar": + return f"Read {uri.host}" + elif uri.scheme == "slow": + # Simulate a slow resource + await anyio.sleep(2.0) + return f"Slow response from {uri.host}" + + raise McpError( + error=ErrorData( + code=404, message="OOPS! no resource with that URI was found" + ) + ) + @self.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="test_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}}, + ) + ] -# App handler class for testing validation (not a pytest test class) -class StreamableAppHandler: - def __init__(self, session_id=None): - self.transport = StreamableHTTPServerTransport(mcp_session_id=session_id) - self.started = False - self.read_stream = None - self.write_stream = None - - async def startup(self): - """Initialize the transport streams.""" - # Create real memory streams to satisfy type checking - read_stream_writer, read_stream = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception - ](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[ - JSONRPCMessage - ](0) - - # Assign the streams to the transport - self.transport._read_stream_writer = read_stream_writer - self.transport._write_stream_reader = write_stream_reader - - # Store the streams so they don't get garbage collected - self.read_stream = read_stream - self.write_stream = write_stream - - self.started = True - print("Transport streams initialized") - - async def handle_request(self, request: Request): - """Handle incoming requests by validating and responding.""" - # Make sure transport is initialized - if not self.started: - await self.startup() - - # Let the transport handle the request validation and response - try: - await self.transport.handle_request( - request.scope, request.receive, request._send - ) - except Exception as e: - print(f"Error handling request: {e}") - # Make sure we provide an error response - response = Response( - status_code=500, - content=f"Server error: {str(e)}", - media_type="text/plain", - ) - await response(request.scope, request.receive, request._send) + @self.call_tool() + async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + return [TextContent(type="text", text=f"Called {name}")] @pytest.fixture @@ -96,28 +102,93 @@ def server_url(server_port: int) -> str: def create_app(session_id=None) -> Starlette: - """Create a Starlette application for testing.""" - # Create our test app handler - app_handler = StreamableAppHandler(session_id=session_id) - - # Define a startup event to ensure the transport is initialized - async def on_startup(): - """Initialize the transport on application startup.""" - print("Initializing transport streams...") - await app_handler.startup() - app_handler.started = True - print("Transport initialized") + """Create a Starlette application for testing that matches the example server.""" + # Create server instance + server = ServerTest() + + # Store the server instances between requests for session management + server_instances = {} + # Lock to prevent race conditions when creating new sessions + session_creation_lock = anyio.Lock() + # Task group for running server instances + task_group = None + + @contextlib.asynccontextmanager + async def lifespan(app): + """Application lifespan context manager for managing task group.""" + nonlocal task_group + + async with anyio.create_task_group() as tg: + task_group = tg + print("Application started, task group initialized!") + try: + yield + finally: + print("Application shutting down, cleaning up resources...") + if task_group: + tg.cancel_scope.cancel() + task_group = None + print("Resources cleaned up successfully.") + + # ASGI handler for streamable HTTP connections + async def handle_streamable_http(scope, receive, send): + request = Request(scope, receive) + request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + + # Use existing transport if session ID matches + if ( + request_mcp_session_id is not None + and request_mcp_session_id in server_instances + ): + transport = server_instances[request_mcp_session_id] + print("Session already exists, handling request directly") + await transport.handle_request(scope, receive, send) + elif session_id is None or request_mcp_session_id is None: + async with session_creation_lock: + # For tests with fixed session ID + new_session_id = session_id if session_id else uuid4().hex + + http_transport = StreamableHTTPServerTransport( + mcp_session_id=new_session_id, + ) + + async with http_transport.connect() as streams: + read_stream, write_stream = streams + + async def run_server(): + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) + + if task_group is None: + response = Response( + "Internal Server Error: Task group is not initialized", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) + return + + # Store the instance before starting the task to prevent races + server_instances[http_transport.mcp_session_id] = http_transport + task_group.start_soon(run_server) + + await http_transport.handle_request(scope, receive, send) + else: + response = Response( + "Bad Request: No valid session ID provided", + status_code=HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) + # Create an ASGI application app = Starlette( - debug=True, # Enable debug mode for better error messages + debug=True, routes=[ - Route( - "/mcp", - endpoint=app_handler.handle_request, - methods=["GET", "POST", "DELETE"], - ), + Mount("/mcp", app=handle_streamable_http), ], - on_startup=[on_startup], + lifespan=lifespan, ) return app @@ -127,17 +198,15 @@ def run_server(port: int, session_id=None) -> None: """Run the test server.""" print(f"Starting test server on port {port} with session_id={session_id}") - # Create app with simpler configuration app = create_app(session_id) - - # Configure to use a single worker and simpler settings + # Configure server config = uvicorn.Config( app=app, host="127.0.0.1", port=port, - log_level="info", # Use info to see startup messages + log_level="info", limit_concurrency=10, - timeout_keep_alive=2, + timeout_keep_alive=5, access_log=False, ) @@ -161,7 +230,9 @@ def run_server(port: int, session_id=None) -> None: def basic_server(server_port: int) -> Generator[None, None, None]: """Start a basic server without session ID.""" # Start server process - process = Process(target=run_server, kwargs={"port": server_port}, daemon=True) + process = multiprocessing.Process( + target=run_server, kwargs={"port": server_port}, daemon=True + ) process.start() # Wait for server to start @@ -191,7 +262,7 @@ def basic_server(server_port: int) -> Generator[None, None, None]: def session_server(server_port: int) -> Generator[str, None, None]: """Start a server with session ID.""" # Start server process - process = Process( + process = multiprocessing.Process( target=run_server, kwargs={"port": server_port, "session_id": TEST_SESSION_ID}, daemon=True, @@ -309,17 +380,20 @@ def test_method_not_allowed(basic_server, server_url): def test_get_request_validation(basic_server, server_url): """Test GET request validation for SSE streams.""" - # Test GET without Accept header - response = requests.get(f"{server_url}/mcp") - assert response.status_code == 406 - assert "Not Acceptable" in response.text - # Test GET with wrong Accept header - response = requests.get( + response = requests.post( f"{server_url}/mcp", - headers={"Accept": "application/json"}, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, ) + # Test GET without Accept header + assert response.status_code == 200 + response = requests.get(f"{server_url}/mcp") assert response.status_code == 406 + assert "Not Acceptable" in response.text def test_session_validation(session_server, server_url): @@ -338,45 +412,6 @@ def test_session_validation(session_server, server_url): assert response.status_code == 400 assert "Missing session ID" in response.text - # Test with invalid session ID - response = requests.post( - f"{server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: "invalid-session-id", - }, - json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, - ) - assert response.status_code == 404 - assert "Invalid or expired session ID" in response.text - - -def test_delete_request(session_server, server_url): - """Test DELETE request for session termination.""" - # session_id not used directly in this test - - # Test without session ID - response = requests.delete(f"{server_url}/mcp") - assert response.status_code == 400 - assert "Missing session ID" in response.text - - # Test with invalid session ID - response = requests.delete( - f"{server_url}/mcp", - headers={MCP_SESSION_ID_HEADER: "invalid-session-id"}, - ) - assert response.status_code == 404 - assert "Invalid or expired session ID" in response.text - - -def test_delete_without_session_support(basic_server, server_url): - """Test DELETE request when server doesn't support sessions.""" - # Server without session support should reject DELETE - response = requests.delete(f"{server_url}/mcp") - assert response.status_code == 405 - assert "Method Not Allowed" in response.text - def test_session_id_pattern(): """Test that SESSION_ID_PATTERN correctly validates session IDs.""" @@ -414,7 +449,7 @@ def test_session_id_pattern(): def test_streamable_http_transport_init_validation(): - """Test that StreamableHTTPServerTransport validates session ID on initialization.""" + """Test that StreamableHTTPServerTransport validates session ID on init.""" # Valid session ID should initialize without errors valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") assert valid_transport.mcp_session_id == "valid-id" @@ -434,3 +469,67 @@ def test_streamable_http_transport_init_validation(): with pytest.raises(ValueError): StreamableHTTPServerTransport(mcp_session_id="test\n") + + +def test_delete_request(session_server, server_url): + """Test DELETE request for session termination.""" + session_id = session_server + + # First, send an initialize request to properly initialize the server + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Test without session ID + response = requests.delete(f"{server_url}/mcp") + assert response.status_code == 400 + assert "Missing session ID" in response.text + + # Test valid session termination + response = requests.delete( + f"{server_url}/mcp", + headers={MCP_SESSION_ID_HEADER: session_id}, + ) + # assert response.status_code == 200 + assert "Session terminated" in response.text + + +def test_session_termination(session_server, server_url): + """Test session termination via DELETE and subsequent request handling.""" + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Now terminate the session + session_id = session_server + response = requests.delete( + f"{server_url}/mcp", + headers={MCP_SESSION_ID_HEADER: session_id}, + ) + assert response.status_code == 200 + assert "Session terminated" in response.text + + # Try to use the terminated session + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "ping", "id": 2}, + ) + assert response.status_code == 404 + assert "Session has been terminated" in response.text From bce74b3e148038a38324d730d5506aac055b6e05 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 19:25:21 +0100 Subject: [PATCH 05/45] fix cleaning up --- src/mcp/server/streamableHttp.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 0dd73e50..a8ee6f9b 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -649,10 +649,3 @@ async def message_router(): except Exception: pass self._request_streams.clear() - # Clean up read/write streams - if self._read_stream_writer: - try: - await self._read_stream_writer.aclose() - except Exception: - pass - self._read_stream_writer = None From 201157912d1d588283ed91fd05567b8b003ef891 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 20:51:46 +0100 Subject: [PATCH 06/45] add happy path test --- tests/server/test_streamableHttp.py | 64 +++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index eb7a5390..51b88c0c 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -4,13 +4,14 @@ This file contains tests for request validation in the StreamableHTTP transport. """ +import contextlib import multiprocessing import socket import time -from collections.abc import AsyncGenerator, Generator +from collections.abc import Generator from http import HTTPStatus from uuid import uuid4 -import contextlib + import anyio import pytest import requests @@ -19,7 +20,7 @@ from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response -from starlette.routing import Mount, Route +from starlette.routing import Mount from mcp.server import Server from mcp.server.streamableHttp import ( @@ -29,11 +30,8 @@ ) from mcp.shared.exceptions import McpError from mcp.types import ( - EmptyResult, ErrorData, - JSONRPCMessage, TextContent, - TextResourceContents, Tool, ) @@ -106,11 +104,9 @@ def create_app(session_id=None) -> Starlette: # Create server instance server = ServerTest() - # Store the server instances between requests for session management server_instances = {} # Lock to prevent race conditions when creating new sessions session_creation_lock = anyio.Lock() - # Task group for running server instances task_group = None @contextlib.asynccontextmanager @@ -130,7 +126,6 @@ async def lifespan(app): task_group = None print("Resources cleaned up successfully.") - # ASGI handler for streamable HTTP connections async def handle_streamable_http(scope, receive, send): request = Request(scope, receive) request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) @@ -141,12 +136,11 @@ async def handle_streamable_http(scope, receive, send): and request_mcp_session_id in server_instances ): transport = server_instances[request_mcp_session_id] - print("Session already exists, handling request directly") + await transport.handle_request(scope, receive, send) - elif session_id is None or request_mcp_session_id is None: + elif request_mcp_session_id is None: async with session_creation_lock: - # For tests with fixed session ID - new_session_id = session_id if session_id else uuid4().hex + new_session_id = uuid4().hex http_transport = StreamableHTTPServerTransport( mcp_session_id=new_session_id, @@ -156,11 +150,14 @@ async def handle_streamable_http(scope, receive, send): read_stream, write_stream = streams async def run_server(): - await server.run( - read_stream, - write_stream, - server.create_initialization_options(), - ) + try: + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) + except Exception as e: + print(f"Server exception: {e}") if task_group is None: response = Response( @@ -533,3 +530,34 @@ def test_session_termination(session_server, server_url): ) assert response.status_code == 404 assert "Session has been terminated" in response.text + + +def test_response(basic_server, server_url): + """Test response handling for a valid request.""" + mcp_url = f"{server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Now terminate the session + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + + # Try to use the terminated session + tools_response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, + stream=True, # Important for SSE + ) + assert tools_response.status_code == 200 + assert tools_response.headers.get("Content-Type") == "text/event-stream" From 2cebf087d6b5e965198b8a3bf57248cf8aa7aa31 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 21:58:21 +0100 Subject: [PATCH 07/45] tests --- tests/server/test_streamableHttp.py | 70 +++-------------------------- 1 file changed, 5 insertions(+), 65 deletions(-) diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index 51b88c0c..b375fdc9 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -301,22 +301,6 @@ def test_accept_header_validation(basic_server, server_url): assert response.status_code == 406 assert "Not Acceptable" in response.text - # Test with only application/json - response = requests.post( - f"{server_url}/mcp", - headers={"Accept": "application/json", "Content-Type": "application/json"}, - json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, - ) - assert response.status_code == 406 - - # Test with only text/event-stream - response = requests.post( - f"{server_url}/mcp", - headers={"Accept": "text/event-stream", "Content-Type": "application/json"}, - json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, - ) - assert response.status_code == 406 - def test_content_type_validation(basic_server, server_url): """Test that Content-Type header is properly validated.""" @@ -347,6 +331,9 @@ def test_json_validation(basic_server, server_url): assert response.status_code == 400 assert "Parse error" in response.text + +def test_json_parsing(basic_server, server_url): + """Test that JSON content is properly parse.""" # Test with valid JSON but invalid JSON-RPC response = requests.post( f"{server_url}/mcp", @@ -375,24 +362,6 @@ def test_method_not_allowed(basic_server, server_url): assert "Method Not Allowed" in response.text -def test_get_request_validation(basic_server, server_url): - """Test GET request validation for SSE streams.""" - - response = requests.post( - f"{server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - # Test GET without Accept header - assert response.status_code == 200 - response = requests.get(f"{server_url}/mcp") - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - def test_session_validation(session_server, server_url): """Test session ID validation.""" # session_id not used directly in this test @@ -468,36 +437,7 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_delete_request(session_server, server_url): - """Test DELETE request for session termination.""" - session_id = session_server - - # First, send an initialize request to properly initialize the server - response = requests.post( - f"{server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - - # Test without session ID - response = requests.delete(f"{server_url}/mcp") - assert response.status_code == 400 - assert "Missing session ID" in response.text - - # Test valid session termination - response = requests.delete( - f"{server_url}/mcp", - headers={MCP_SESSION_ID_HEADER: session_id}, - ) - # assert response.status_code == 200 - assert "Session terminated" in response.text - - -def test_session_termination(session_server, server_url): +def test_session_termination(basic_server, server_url): """Test session termination via DELETE and subsequent request handling.""" response = requests.post( f"{server_url}/mcp", @@ -510,7 +450,7 @@ def test_session_termination(session_server, server_url): assert response.status_code == 200 # Now terminate the session - session_id = session_server + session_id = response.headers.get(MCP_SESSION_ID_HEADER) response = requests.delete( f"{server_url}/mcp", headers={MCP_SESSION_ID_HEADER: session_id}, From 6c9c320a38654c9145fe326d9e308e2893e8d9e3 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 22:19:26 +0100 Subject: [PATCH 08/45] json mode --- .../servers/simple-streamablehttp/README.md | 3 + .../mcp_simple_streamablehttp/server.py | 9 +- src/mcp/server/streamableHttp.py | 295 ++++++++++++------ tests/server/test_streamableHttp.py | 56 +++- 4 files changed, 258 insertions(+), 105 deletions(-) diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md index aa5e707a..5125c3eb 100644 --- a/examples/servers/simple-streamablehttp/README.md +++ b/examples/servers/simple-streamablehttp/README.md @@ -20,6 +20,9 @@ uv run mcp-simple-streamablehttp --port 3000 # Custom logging level uv run mcp-simple-streamablehttp --log-level DEBUG + +# Enable JSON responses instead of SSE streams +uv run mcp-simple-streamablehttp --json-response ``` The server exposes a tool named "start-notification-stream" that accepts three arguments: diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 3dc972b7..c39a3720 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -48,9 +48,16 @@ async def lifespan(app): default="INFO", help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", ) +@click.option( + "--json-response", + is_flag=True, + default=False, + help="Enable JSON responses instead of SSE streams", +) def main( port: int, log_level: str, + json_response: bool, ) -> int: # Configure logging logging.basicConfig( @@ -145,7 +152,7 @@ async def handle_streamable_http(scope, receive, send): async with session_creation_lock: new_session_id = uuid4().hex http_transport = StreamableHTTPServerTransport( - mcp_session_id=new_session_id, + mcp_session_id=new_session_id, is_json_response_enabled=json_response ) async with http_transport.connect() as streams: read_stream, write_stream = streams diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index a8ee6f9b..b6ef396c 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -54,6 +54,7 @@ class StreamableHTTPServerTransport: Handles POST requests containing JSON-RPC messages and provides Server-Sent Events (SSE) responses for streaming communication. + When configured, can also return JSON responses instead of SSE streams. """ # Server notification streams for POST requests as well as standalone SSE stream @@ -65,6 +66,7 @@ class StreamableHTTPServerTransport: def __init__( self, mcp_session_id: str | None, + is_json_response_enabled: bool = False, ): """ Initialize a new StreamableHTTP server transport. @@ -72,6 +74,8 @@ def __init__( Args: mcp_session_id: Optional session identifier for this connection. Must contain only visible ASCII characters (0x21-0x7E). + is_json_response_enabled: If True, return JSON responses for requests + instead of SSE streams. Default is False. Raises: ValueError: If the session ID contains invalid characters. @@ -85,6 +89,7 @@ def __init__( ) self.mcp_session_id = mcp_session_id + self.is_json_response_enabled = is_json_response_enabled self._request_streams = {} self._terminated = False @@ -110,6 +115,36 @@ def _create_error_response( headers=response_headers, ) + def _create_json_response( + self, + response_message: JSONRPCMessage, + status_code: HTTPStatus = HTTPStatus.OK, + headers: dict[str, str] | None = None, + ) -> Response: + """ + Create a JSON response from a JSONRPCMessage. + + Args: + response_message: The JSON-RPC message to include in the response + status_code: HTTP status code (default: 200 OK) + headers: Additional headers to include + + Returns: + A Starlette Response object with the JSON-RPC message + """ + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + return Response( + response_message.model_dump_json(by_alias=True, exclude_none=True), + status_code=status_code, + headers=response_headers, + ) + def _get_session_id(self, request: Request) -> str | None: """ Extract the session ID from request headers. @@ -303,105 +338,183 @@ async def _handle_post_request( return - # For requests, set up an SSE stream for the response + # For requests, determine whether to return JSON or set up SSE stream if is_request: - # Set up headers - headers = { - "Cache-Control": "no-cache, no-transform", - "Connection": "keep-alive", - "Content-Type": CONTENT_TYPE_SSE, - } + if self.is_json_response_enabled: + # JSON response mode - create a response future + request_id = None + if isinstance(message.root, JSONRPCRequest): + request_id = str(message.root.id) + + if not request_id: + # Should not happen for valid JSONRPCRequest, but handle just in case + response = self._create_error_response( + "Invalid Request: Missing request ID", + HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) + return - if self.mcp_session_id: - headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - # Create SSE stream - sse_stream_writer, sse_stream_reader = ( - anyio.create_memory_object_stream[dict[str, Any]](0) - ) + # Create promise stream for getting response + request_stream_writer, request_stream_reader = ( + anyio.create_memory_object_stream[JSONRPCMessage](0) + ) - async def sse_writer(): - try: - # Create a request-specific message stream for this POST request - request_stream_writer, request_stream_reader = ( - anyio.create_memory_object_stream[JSONRPCMessage](0) - ) + # Register this stream for the request ID + self._request_streams[request_id] = request_stream_writer - # Get the request ID from the incoming request message - request_id = None - if isinstance(message.root, JSONRPCRequest): - request_id = str(message.root.id) - # Register this stream for the request ID - if request_id: - self._request_streams[request_id] = ( - request_stream_writer - ) + # Process the message + await writer.send(message) - async with sse_stream_writer, request_stream_reader: - # Process messages from the request-specific stream - async for received_message in request_stream_reader: - # Send the message via SSE - related_request_id = None - - if isinstance( - received_message.root, JSONRPCNotification - ): - # Get related_request_id from params - params = received_message.root.params - if params and "related_request_id" in params: - related_request_id = params.get( - "related_request_id" - ) - logger.debug( - f"NOTIFICATION: {related_request_id}, " - f"{params.get('data')}" - ) - - # Build the event data - event_data = { - "event": "message", - "data": received_message.model_dump_json( - by_alias=True, exclude_none=True - ), - } - - await sse_stream_writer.send(event_data) - - # If response, remove from pending streams and close - if isinstance(received_message.root, JSONRPCResponse): - if request_id: - self._request_streams.pop(request_id, None) - break + try: + # Process messages from the request-specific stream + # We need to collect all messages until we get a response + response_message = None + + # Use similar approach to SSE writer for consistency + async for received_message in request_stream_reader: + # If it's a response, this is what we're waiting for + if isinstance(received_message.root, JSONRPCResponse): + response_message = received_message + break + # For notifications, we need to keep waiting for the actual response + elif isinstance(received_message.root, JSONRPCNotification): + # Just process it and continue waiting + logger.debug( + f"Received notification while waiting for response: {received_message.root.method}" + ) + continue + + # At this point we should have a response + if response_message: + # Create JSON response + response = self._create_json_response(response_message) + await response(scope, receive, send) + else: + # This shouldn't happen in normal operation + logger.error("No response message received before stream closed") + response = self._create_error_response( + "Error processing request: No response received", + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) except Exception as e: - logger.exception(f"Error in SSE writer: {e}") + logger.exception(f"Error processing JSON response: {e}") + response = self._create_error_response( + f"Error processing request: {str(e)}", + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) finally: - logger.debug("Closing SSE writer") - # TODO - - # Create and start EventSourceResponse - response = EventSourceResponse( - content=sse_stream_reader, - data_sender_callable=sse_writer, - headers=headers, - ) - - # Extract the request ID outside the try block for proper scope - outer_request_id = None - if isinstance(message.root, JSONRPCRequest): - outer_request_id = str(message.root.id) + # Clean up the request stream + if request_id in self._request_streams: + self._request_streams.pop(request_id, None) + await request_stream_reader.aclose() + await request_stream_writer.aclose() + else: + # SSE stream mode (original behavior) + # Set up headers + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + # Create SSE stream + sse_stream_writer, sse_stream_reader = ( + anyio.create_memory_object_stream[dict[str, Any]](0) + ) + + async def sse_writer(): + try: + # Create a request-specific message stream for this POST request + request_stream_writer, request_stream_reader = ( + anyio.create_memory_object_stream[JSONRPCMessage](0) + ) + + # Get the request ID from the incoming request message + request_id = None + if isinstance(message.root, JSONRPCRequest): + request_id = str(message.root.id) + # Register this stream for the request ID + if request_id: + self._request_streams[request_id] = ( + request_stream_writer + ) + + async with sse_stream_writer, request_stream_reader: + # Process messages from the request-specific stream + async for received_message in request_stream_reader: + # Send the message via SSE + related_request_id = None + + if isinstance( + received_message.root, JSONRPCNotification + ): + # Get related_request_id from params + params = received_message.root.params + if params and "related_request_id" in params: + related_request_id = params.get( + "related_request_id" + ) + logger.debug( + f"NOTIFICATION: {related_request_id}, " + f"{params.get('data')}" + ) + + # Build the event data + event_data = { + "event": "message", + "data": received_message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + await sse_stream_writer.send(event_data) + + # If response, remove from pending streams and close + if isinstance( + received_message.root, JSONRPCResponse + ): + if request_id: + self._request_streams.pop(request_id, None) + break + except Exception as e: + logger.exception(f"Error in SSE writer: {e}") + finally: + logger.debug("Closing SSE writer") + # TODO + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=sse_writer, + headers=headers, + ) + + # Extract the request ID outside the try block for proper scope + outer_request_id = None + if isinstance(message.root, JSONRPCRequest): + outer_request_id = str(message.root.id) + + # Start the SSE response (this will send headers immediately) + try: + # First send the response to establish the SSE connection + async with anyio.create_task_group() as tg: + tg.start_soon(response, scope, receive, send) - # Start the SSE response (this will send headers immediately) - try: - # First send the response to establish the SSE connection - async with anyio.create_task_group() as tg: - tg.start_soon(response, scope, receive, send) - - # Then send the message to be processed by the server - await writer.send(message) - except Exception: - logger.exception("SSE response error") - # Make sure to clean up the request stream if something goes wrong - if outer_request_id and outer_request_id in self._request_streams: - self._request_streams.pop(outer_request_id, None) + # Then send the message to be processed by the server + await writer.send(message) + except Exception: + logger.exception("SSE response error") + # Make sure to clean up the request stream if something goes wrong + if ( + outer_request_id + and outer_request_id in self._request_streams + ): + self._request_streams.pop(outer_request_id, None) except Exception as err: logger.exception("Error handling POST request") diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index b375fdc9..42c416c5 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -99,8 +99,13 @@ def server_url(server_port: int) -> str: return f"http://127.0.0.1:{server_port}" -def create_app(session_id=None) -> Starlette: - """Create a Starlette application for testing that matches the example server.""" +def create_app(session_id=None, is_json_response_enabled=False) -> Starlette: + """Create a Starlette application for testing that matches the example server. + + Args: + session_id: Optional session ID to use for the server. + is_json_response_enabled: If True, use JSON responses instead of SSE streams. + """ # Create server instance server = ServerTest() @@ -144,6 +149,7 @@ async def handle_streamable_http(scope, receive, send): http_transport = StreamableHTTPServerTransport( mcp_session_id=new_session_id, + is_json_response_enabled=is_json_response_enabled, ) async with http_transport.connect() as streams: @@ -191,11 +197,20 @@ async def run_server(): return app -def run_server(port: int, session_id=None) -> None: - """Run the test server.""" - print(f"Starting test server on port {port} with session_id={session_id}") +def run_server(port: int, session_id=None, is_json_response_enabled=False) -> None: + """Run the test server. + + Args: + port: Port to listen on. + session_id: Optional session ID to use for the server. + is_json_response_enabled: If True, use JSON responses instead of SSE streams. + """ + print( + f"Starting test server on port {port} with " + f"session_id={session_id}, json_enabled={is_json_response_enabled}" + ) - app = create_app(session_id) + app = create_app(session_id, is_json_response_enabled) # Configure server config = uvicorn.Config( app=app, @@ -256,12 +271,12 @@ def basic_server(server_port: int) -> Generator[None, None, None]: @pytest.fixture -def session_server(server_port: int) -> Generator[str, None, None]: - """Start a server with session ID.""" - # Start server process +def json_response_server(server_port: int) -> Generator[None, None, None]: + """Start a server with JSON response enabled.""" + # Start server process with is_json_response_enabled=True process = multiprocessing.Process( target=run_server, - kwargs={"port": server_port, "session_id": TEST_SESSION_ID}, + kwargs={"port": server_port, "is_json_response_enabled": True}, daemon=True, ) process.start() @@ -280,7 +295,7 @@ def session_server(server_port: int) -> Generator[str, None, None]: else: raise RuntimeError(f"Server failed to start after {max_attempts} attempts") - yield TEST_SESSION_ID + yield # Clean up process.terminate() @@ -362,7 +377,7 @@ def test_method_not_allowed(basic_server, server_url): assert "Method Not Allowed" in response.text -def test_session_validation(session_server, server_url): +def test_session_validation(basic_server, server_url): """Test session ID validation.""" # session_id not used directly in this test @@ -497,7 +512,22 @@ def test_response(basic_server, server_url): MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier }, json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, - stream=True, # Important for SSE + stream=True, ) assert tools_response.status_code == 200 assert tools_response.headers.get("Content-Type") == "text/event-stream" + + +def test_json_response(json_response_server, server_url): + """Test response handling when is_json_response_enabled is True.""" + mcp_url = f"{server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" From ede8cde91c938db4a64bcccb78387bc79e713d86 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 10:40:52 +0100 Subject: [PATCH 09/45] clean up --- .../mcp_simple_streamablehttp/__main__.py | 2 +- .../mcp_simple_streamablehttp/server.py | 3 +- src/mcp/server/streamableHttp.py | 192 +++++------------- uv.lock | 1 - 4 files changed, 58 insertions(+), 140 deletions(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py index a6876bf9..f5f6e402 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py @@ -1,4 +1,4 @@ from .server import main if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index c39a3720..88249baf 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -152,7 +152,8 @@ async def handle_streamable_http(scope, receive, send): async with session_creation_lock: new_session_id = uuid4().hex http_transport = StreamableHTTPServerTransport( - mcp_session_id=new_session_id, is_json_response_enabled=json_response + mcp_session_id=new_session_id, + is_json_response_enabled=json_response, ) async with http_transport.connect() as streams: read_stream, write_stream = streams diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index b6ef396c..2bc528b0 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -52,14 +52,15 @@ class StreamableHTTPServerTransport: """ HTTP server transport with event streaming support for MCP. - Handles POST requests containing JSON-RPC messages and provides - Server-Sent Events (SSE) responses for streaming communication. - When configured, can also return JSON responses instead of SSE streams. + Handles JSON-RPC messages in HTTP POST requests with SSE streaming. + Supports optional JSON responses and session management. """ # Server notification streams for POST requests as well as standalone SSE stream - _read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] | None - _write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + _read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] | None = ( + None + ) + _write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] | None = None # Dictionary to track request-specific message streams _request_streams: dict[str, MemoryObjectSendStream[JSONRPCMessage]] @@ -67,7 +68,7 @@ def __init__( self, mcp_session_id: str | None, is_json_response_enabled: bool = False, - ): + ) -> None: """ Initialize a new StreamableHTTP server transport. @@ -80,9 +81,8 @@ def __init__( Raises: ValueError: If the session ID contains invalid characters. """ - if mcp_session_id is not None and ( - not SESSION_ID_PATTERN.match(mcp_session_id) - or SESSION_ID_PATTERN.fullmatch(mcp_session_id) is None + if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch( + mcp_session_id ): raise ValueError( "Session ID must only contain visible ASCII characters (0x21-0x7E)" @@ -93,15 +93,13 @@ def __init__( self._request_streams = {} self._terminated = False - def _create_error_response( + def _create_server_response( self, message: str, status_code: HTTPStatus, headers: dict[str, str] | None = None, ) -> Response: - """ - Create a standardized error response. - """ + """Create a standardized server response.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} if headers: response_headers.update(headers) @@ -121,17 +119,7 @@ def _create_json_response( status_code: HTTPStatus = HTTPStatus.OK, headers: dict[str, str] | None = None, ) -> Response: - """ - Create a JSON response from a JSONRPCMessage. - - Args: - response_message: The JSON-RPC message to include in the response - status_code: HTTP status code (default: 200 OK) - headers: Additional headers to include - - Returns: - A Starlette Response object with the JSON-RPC message - """ + """Create a JSON response from a JSONRPCMessage""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} if headers: response_headers.update(headers) @@ -146,25 +134,15 @@ def _create_json_response( ) def _get_session_id(self, request: Request) -> str | None: - """ - Extract the session ID from request headers. - """ + """Extract the session ID from request headers.""" return request.headers.get(MCP_SESSION_ID_HEADER) async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: - """ - ASGI application entry point that handles all HTTP requests - - Args: - stream_id: Unique identifier for this stream - scope: ASGI scope - receive: ASGI receive function - send: ASGI send function - """ + """Application entry point that handles all HTTP requests""" request = Request(scope, receive) if self._terminated: # If the session has been terminated, return 404 Not Found - response = self._create_error_response( + response = self._create_server_response( "Not Found: Session has been terminated", HTTPStatus.NOT_FOUND, ) @@ -181,15 +159,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: - """ - Check if the request accepts the required media types. - - Args: - request: The HTTP request - - Returns: - Tuple of (has_json, has_sse) indicating whether each media type is accepted - """ + """Check if the request accepts the required media types.""" accept_header = request.headers.get("accept", "") accept_types = [media_type.strip() for media_type in accept_header.split(",")] @@ -203,15 +173,7 @@ def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: return has_json, has_sse def _check_content_type(self, request: Request) -> bool: - """ - Check if the request has the correct Content-Type. - - Args: - request: The HTTP request - - Returns: - True if Content-Type is acceptable, False otherwise - """ + """Check if the request has the correct Content-Type.""" content_type = request.headers.get("content-type", "") content_type_parts = [ part.strip() for part in content_type.split(";")[0].split(",") @@ -222,15 +184,7 @@ def _check_content_type(self, request: Request) -> bool: async def _handle_post_request( self, scope: Scope, request: Request, receive: Receive, send: Send ) -> None: - """ - Handles POST requests containing JSON-RPC messages - - Args: - scope: ASGI scope - request: Starlette Request object - receive: ASGI receive function - send: ASGI send function - """ + """Handle POST requests containing JSON-RPC messages.""" writer = self._read_stream_writer if writer is None: raise ValueError( @@ -240,7 +194,7 @@ async def _handle_post_request( # Check Accept headers has_json, has_sse = self._check_accept_headers(request) if not (has_json and has_sse): - response = self._create_error_response( + response = self._create_server_response( ( "Not Acceptable: Client must accept both application/json and " "text/event-stream" @@ -252,7 +206,7 @@ async def _handle_post_request( # Validate Content-Type if not self._check_content_type(request): - response = self._create_error_response( + response = self._create_server_response( "Unsupported Media Type: Content-Type must be application/json", HTTPStatus.UNSUPPORTED_MEDIA_TYPE, ) @@ -262,7 +216,7 @@ async def _handle_post_request( # Parse the body - only read it once body = await request.body() if len(body) > MAXIMUM_MESSAGE_SIZE: - response = self._create_error_response( + response = self._create_server_response( "Payload Too Large: Message exceeds maximum size", HTTPStatus.REQUEST_ENTITY_TOO_LARGE, ) @@ -272,32 +226,23 @@ async def _handle_post_request( try: raw_message = json.loads(body) except json.JSONDecodeError as e: - response = self._create_error_response( + response = self._create_server_response( f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, ) await response(scope, receive, send) return - message = None try: message = JSONRPCMessage.model_validate(raw_message) except ValidationError as e: - response = self._create_error_response( + response = self._create_server_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, ) await response(scope, receive, send) return - if not message: - response = self._create_error_response( - "Invalid Request: Message is empty", - HTTPStatus.BAD_REQUEST, - ) - await response(scope, receive, send) - return - # Check if this is an initialization request is_initialization_request = ( isinstance(message.root, JSONRPCRequest) @@ -312,7 +257,7 @@ async def _handle_post_request( # If request has a session ID but doesn't match, return 404 if request_session_id and request_session_id != self.mcp_session_id: - response = self._create_error_response( + response = self._create_server_response( "Not Found: Invalid or expired session ID", HTTPStatus.NOT_FOUND, ) @@ -327,7 +272,7 @@ async def _handle_post_request( # For notifications and responses only, return 202 Accepted if not is_request: # Create response object and send it - response = self._create_error_response( + response = self._create_server_response( "Accepted", HTTPStatus.ACCEPTED, ) @@ -347,8 +292,8 @@ async def _handle_post_request( request_id = str(message.root.id) if not request_id: - # Should not happen for valid JSONRPCRequest, but handle just in case - response = self._create_error_response( + # Should not happen for valid JSONRPCRequest, but handle it + response = self._create_server_response( "Invalid Request: Missing request ID", HTTPStatus.BAD_REQUEST, ) @@ -370,20 +315,19 @@ async def _handle_post_request( # Process messages from the request-specific stream # We need to collect all messages until we get a response response_message = None - + # Use similar approach to SSE writer for consistency async for received_message in request_stream_reader: # If it's a response, this is what we're waiting for if isinstance(received_message.root, JSONRPCResponse): response_message = received_message break - # For notifications, we need to keep waiting for the actual response + # For notifications, keep waiting for the actual response elif isinstance(received_message.root, JSONRPCNotification): # Just process it and continue waiting logger.debug( - f"Received notification while waiting for response: {received_message.root.method}" + f"Notification: {received_message.root.method}" ) - continue # At this point we should have a response if response_message: @@ -392,15 +336,17 @@ async def _handle_post_request( await response(scope, receive, send) else: # This shouldn't happen in normal operation - logger.error("No response message received before stream closed") - response = self._create_error_response( + logger.error( + "No response message received before stream closed" + ) + response = self._create_server_response( "Error processing request: No response received", HTTPStatus.INTERNAL_SERVER_ERROR, ) await response(scope, receive, send) except Exception as e: logger.exception(f"Error processing JSON response: {e}") - response = self._create_error_response( + response = self._create_server_response( f"Error processing request: {str(e)}", HTTPStatus.INTERNAL_SERVER_ERROR, ) @@ -428,14 +374,14 @@ async def _handle_post_request( ) async def sse_writer(): + # Get the request ID from the incoming request message + request_id = None try: - # Create a request-specific message stream for this POST request + # Create a request-specific message stream for this POST request_stream_writer, request_stream_reader = ( anyio.create_memory_object_stream[JSONRPCMessage](0) ) - # Get the request ID from the incoming request message - request_id = None if isinstance(message.root, JSONRPCRequest): request_id = str(message.root.id) # Register this stream for the request ID @@ -485,7 +431,9 @@ async def sse_writer(): logger.exception(f"Error in SSE writer: {e}") finally: logger.debug("Closing SSE writer") - # TODO + # Clean up the request-specific streams + if request_id and request_id in self._request_streams: + self._request_streams.pop(request_id, None) # Create and start EventSourceResponse response = EventSourceResponse( @@ -509,7 +457,7 @@ async def sse_writer(): await writer.send(message) except Exception: logger.exception("SSE response error") - # Make sure to clean up the request stream if something goes wrong + # Clean up the request stream if something goes wrong if ( outer_request_id and outer_request_id in self._request_streams @@ -518,7 +466,7 @@ async def sse_writer(): except Exception as err: logger.exception("Error handling POST request") - response = self._create_error_response( + response = self._create_server_response( f"Error handling POST request: {err}", HTTPStatus.INTERNAL_SERVER_ERROR, ) @@ -528,13 +476,7 @@ async def sse_writer(): return async def _handle_get_request(self, request: Request, send: Send) -> None: - """ - Handle GET requests for SSE stream establishment - - Args: - request: The HTTP request - send: ASGI send function - """ + """Handle GET requests for SSE stream establishment.""" # Validate session ID if server has one if not await self._validate_session(request, send): return @@ -542,7 +484,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: _, has_sse = self._check_accept_headers(request) if not has_sse: - response = self._create_error_response( + response = self._create_server_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, ) @@ -551,24 +493,18 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # TODO: Implement SSE stream for GET requests # For now, return 501 Not Implemented - response = self._create_error_response( + response = self._create_server_response( "SSE stream from GET request not implemented yet", HTTPStatus.NOT_IMPLEMENTED, ) await response(request.scope, request.receive, send) async def _handle_delete_request(self, request: Request, send: Send) -> None: - """ - Handle DELETE requests for explicit session termination - - Args: - request: The HTTP request - send: ASGI send function - """ + """Handle DELETE requests for explicit session termination.""" # Validate session ID if not self.mcp_session_id: # If no session ID set, return Method Not Allowed - response = self._create_error_response( + response = self._create_server_response( "Method Not Allowed: Session termination not supported", HTTPStatus.METHOD_NOT_ALLOWED, ) @@ -581,16 +517,14 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: # Terminate the session self._terminate_session() - # Return success response - response = self._create_error_response( + response = self._create_server_response( "Session terminated", HTTPStatus.OK, ) await response(request.scope, request.receive, send) def _terminate_session(self) -> None: - """ - Terminate the current session, closing all streams and marking as terminated. + """Terminate the current session, closing all streams. Once terminated, all requests with this session ID will receive 404 Not Found. """ @@ -616,13 +550,7 @@ def _terminate_session(self) -> None: self._request_streams.clear() async def _handle_unsupported_request(self, request: Request, send: Send) -> None: - """ - Handle unsupported HTTP methods - - Args: - request: The HTTP request - send: ASGI send function - """ + """Handle unsupported HTTP methods.""" headers = { "Content-Type": CONTENT_TYPE_JSON, "Allow": "GET, POST, DELETE", @@ -638,16 +566,7 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non await response(request.scope, request.receive, send) async def _validate_session(self, request: Request, send: Send) -> bool: - """ - Validate the session ID in the request. - - Args: - request: The HTTP request - send: ASGI send function - - Returns: - bool: True if session is valid, False otherwise - """ + """Validate the session ID in the request.""" if not self.mcp_session_id: # If we're not using session IDs, return True return True @@ -657,7 +576,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # If no session ID provided but required, return error if not request_session_id: - response = self._create_error_response( + response = self._create_server_response( "Bad Request: Missing session ID", HTTPStatus.BAD_REQUEST, ) @@ -666,7 +585,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # If session ID doesn't match, return error if request_session_id != self.mcp_session_id: - response = self._create_error_response( + response = self._create_server_response( "Not Found: Invalid or expired session ID", HTTPStatus.NOT_FOUND, ) @@ -685,8 +604,7 @@ async def connect( ], None, ]: - """ - Context manager that provides read and write streams for a connection + """Context manager that provides read and write streams for a connection. Yields: Tuple of (read_stream, write_stream) for bidirectional communication diff --git a/uv.lock b/uv.lock index 65439e5c..6618ea36 100644 --- a/uv.lock +++ b/uv.lock @@ -487,7 +487,6 @@ wheels = [ [[package]] name = "mcp" -version = "1.6.1.dev12+70115b9" source = { editable = "." } dependencies = [ { name = "anyio" }, From 2a3bed8e50d19a572ad3fe9a82e2ad347d2d1c57 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 11:01:51 +0100 Subject: [PATCH 10/45] fix example server --- .../mcp_simple_streamablehttp/server.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 88249baf..eec5edb4 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -155,25 +155,24 @@ async def handle_streamable_http(scope, receive, send): mcp_session_id=new_session_id, is_json_response_enabled=json_response, ) - async with http_transport.connect() as streams: - read_stream, write_stream = streams + server_instances[http_transport.mcp_session_id] = http_transport + async with http_transport.connect() as streams: + read_stream, write_stream = streams - async def run_server(): - await app.run( - read_stream, - write_stream, - app.create_initialization_options(), - ) + async def run_server(): + await app.run( + read_stream, + write_stream, + app.create_initialization_options(), + ) - if not task_group: - raise RuntimeError("Task group is not initialized") + if not task_group: + raise RuntimeError("Task group is not initialized") - # Store the instance before starting the task to prevent races - server_instances[http_transport.mcp_session_id] = http_transport - task_group.start_soon(run_server) + task_group.start_soon(run_server) - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) else: response = Response( "Bad Request: No valid session ID provided", From 0456b1bd1c8f5dd2eaf6650ac8b23413c6614322 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 11:07:15 +0100 Subject: [PATCH 11/45] return 405 for get stream --- .../servers/simple-streamablehttp/pyproject.toml | 13 +------------ src/mcp/server/streamableHttp.py | 5 ++--- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/examples/servers/simple-streamablehttp/pyproject.toml b/examples/servers/simple-streamablehttp/pyproject.toml index de43bd2f..c35887d1 100644 --- a/examples/servers/simple-streamablehttp/pyproject.toml +++ b/examples/servers/simple-streamablehttp/pyproject.toml @@ -1,23 +1,12 @@ [project] name = "mcp-simple-streamablehttp" version = "0.1.0" -description = "A simple MCP server exposing a website fetching tool with StreamableHttp transport" +description = "A simple MCP server exposing a StreamableHttp transport for testing" readme = "README.md" requires-python = ">=3.10" authors = [{ name = "Anthropic, PBC." }] -maintainers = [ - { name = "David Soria Parra", email = "davidsp@anthropic.com" }, - { name = "Justin Spahr-Summers", email = "justin@anthropic.com" }, -] keywords = ["mcp", "llm", "automation", "web", "fetch", "http", "streamable"] license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] dependencies = ["anyio>=4.5", "click>=8.1.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"] [project.scripts] diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 2bc528b0..09b94395 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -492,10 +492,10 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: return # TODO: Implement SSE stream for GET requests - # For now, return 501 Not Implemented + # For now, return 405 Method Not Allowed response = self._create_server_response( "SSE stream from GET request not implemented yet", - HTTPStatus.NOT_IMPLEMENTED, + HTTPStatus.METHOD_NOT_ALLOWED, ) await response(request.scope, request.receive, send) @@ -514,7 +514,6 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: if not await self._validate_session(request, send): return - # Terminate the session self._terminate_session() response = self._create_server_response( From 97ca48dc2dd00ca56e99b955c04151084c1d3801 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 11:59:46 +0100 Subject: [PATCH 12/45] speed up tests --- tests/server/test_streamableHttp.py | 139 +++++++++++++++------------- 1 file changed, 75 insertions(+), 64 deletions(-) diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index 42c416c5..063ad82b 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -85,25 +85,10 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: return [TextContent(type="text", text=f"Called {name}")] -@pytest.fixture -def server_port() -> int: - """Find an available port for the test server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - """Get the URL for the test server.""" - return f"http://127.0.0.1:{server_port}" - - -def create_app(session_id=None, is_json_response_enabled=False) -> Starlette: +def create_app(is_json_response_enabled=False) -> Starlette: """Create a Starlette application for testing that matches the example server. Args: - session_id: Optional session ID to use for the server. is_json_response_enabled: If True, use JSON responses instead of SSE streams. """ # Create server instance @@ -197,20 +182,19 @@ async def run_server(): return app -def run_server(port: int, session_id=None, is_json_response_enabled=False) -> None: +def run_server(port: int, is_json_response_enabled=False) -> None: """Run the test server. Args: port: Port to listen on. - session_id: Optional session ID to use for the server. is_json_response_enabled: If True, use JSON responses instead of SSE streams. """ print( f"Starting test server on port {port} with " - f"session_id={session_id}, json_enabled={is_json_response_enabled}" + f"json_enabled={is_json_response_enabled}" ) - app = create_app(session_id, is_json_response_enabled) + app = create_app(is_json_response_enabled) # Configure server config = uvicorn.Config( app=app, @@ -238,22 +222,38 @@ def run_server(port: int, session_id=None, is_json_response_enabled=False) -> No print("Server shutdown") +# Test fixtures - using same approach as SSE tests @pytest.fixture -def basic_server(server_port: int) -> Generator[None, None, None]: - """Start a basic server without session ID.""" - # Start server process - process = multiprocessing.Process( - target=run_server, kwargs={"port": server_port}, daemon=True +def basic_server_port() -> int: + """Find an available port for the basic server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def json_server_port() -> int: + """Find an available port for the JSON response server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def basic_server(basic_server_port: int) -> Generator[None, None, None]: + """Start a basic server.""" + proc = multiprocessing.Process( + target=run_server, kwargs={"port": basic_server_port}, daemon=True ) - process.start() + proc.start() - # Wait for server to start + # Wait for server to be running max_attempts = 20 attempt = 0 while attempt < max_attempts: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", server_port)) + s.connect(("127.0.0.1", basic_server_port)) break except ConnectionRefusedError: time.sleep(0.1) @@ -264,30 +264,29 @@ def basic_server(server_port: int) -> Generator[None, None, None]: yield # Clean up - process.terminate() - process.join(timeout=1) - if process.is_alive(): - process.kill() + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") @pytest.fixture -def json_response_server(server_port: int) -> Generator[None, None, None]: +def json_response_server(json_server_port: int) -> Generator[None, None, None]: """Start a server with JSON response enabled.""" - # Start server process with is_json_response_enabled=True - process = multiprocessing.Process( + proc = multiprocessing.Process( target=run_server, - kwargs={"port": server_port, "is_json_response_enabled": True}, + kwargs={"port": json_server_port, "is_json_response_enabled": True}, daemon=True, ) - process.start() + proc.start() - # Wait for server to start + # Wait for server to be running max_attempts = 20 attempt = 0 while attempt < max_attempts: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", server_port)) + s.connect(("127.0.0.1", json_server_port)) break except ConnectionRefusedError: time.sleep(0.1) @@ -298,18 +297,30 @@ def json_response_server(server_port: int) -> Generator[None, None, None]: yield # Clean up - process.terminate() - process.join(timeout=1) - if process.is_alive(): - process.kill() + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") + + +@pytest.fixture +def basic_server_url(basic_server_port: int) -> str: + """Get the URL for the basic test server.""" + return f"http://127.0.0.1:{basic_server_port}" + + +@pytest.fixture +def json_server_url(json_server_port: int) -> str: + """Get the URL for the JSON response test server.""" + return f"http://127.0.0.1:{json_server_port}" # Basic request validation tests -def test_accept_header_validation(basic_server, server_url): +def test_accept_header_validation(basic_server, basic_server_url): """Test that Accept header is properly validated.""" # Test without Accept header response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={"Content-Type": "application/json"}, json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, ) @@ -317,11 +328,11 @@ def test_accept_header_validation(basic_server, server_url): assert "Not Acceptable" in response.text -def test_content_type_validation(basic_server, server_url): +def test_content_type_validation(basic_server, basic_server_url): """Test that Content-Type header is properly validated.""" # Test with incorrect Content-Type response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "text/plain", @@ -332,11 +343,11 @@ def test_content_type_validation(basic_server, server_url): assert "Unsupported Media Type" in response.text -def test_json_validation(basic_server, server_url): +def test_json_validation(basic_server, basic_server_url): """Test that JSON content is properly validated.""" # Test with invalid JSON response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", @@ -347,11 +358,11 @@ def test_json_validation(basic_server, server_url): assert "Parse error" in response.text -def test_json_parsing(basic_server, server_url): +def test_json_parsing(basic_server, basic_server_url): """Test that JSON content is properly parse.""" # Test with valid JSON but invalid JSON-RPC response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", @@ -362,11 +373,11 @@ def test_json_parsing(basic_server, server_url): assert "Validation error" in response.text -def test_method_not_allowed(basic_server, server_url): +def test_method_not_allowed(basic_server, basic_server_url): """Test that unsupported HTTP methods are rejected.""" # Test with unsupported method (PUT) response = requests.put( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", @@ -377,13 +388,13 @@ def test_method_not_allowed(basic_server, server_url): assert "Method Not Allowed" in response.text -def test_session_validation(basic_server, server_url): +def test_session_validation(basic_server, basic_server_url): """Test session ID validation.""" # session_id not used directly in this test # Test without session ID response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", @@ -452,10 +463,10 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server, server_url): +def test_session_termination(basic_server, basic_server_url): """Test session termination via DELETE and subsequent request handling.""" response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", @@ -467,7 +478,7 @@ def test_session_termination(basic_server, server_url): # Now terminate the session session_id = response.headers.get(MCP_SESSION_ID_HEADER) response = requests.delete( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={MCP_SESSION_ID_HEADER: session_id}, ) assert response.status_code == 200 @@ -475,7 +486,7 @@ def test_session_termination(basic_server, server_url): # Try to use the terminated session response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", @@ -487,9 +498,9 @@ def test_session_termination(basic_server, server_url): assert "Session has been terminated" in response.text -def test_response(basic_server, server_url): +def test_response(basic_server, basic_server_url): """Test response handling for a valid request.""" - mcp_url = f"{server_url}/mcp" + mcp_url = f"{basic_server_url}/mcp" response = requests.post( mcp_url, headers={ @@ -518,9 +529,9 @@ def test_response(basic_server, server_url): assert tools_response.headers.get("Content-Type") == "text/event-stream" -def test_json_response(json_response_server, server_url): +def test_json_response(json_response_server, json_server_url): """Test response handling when is_json_response_enabled is True.""" - mcp_url = f"{server_url}/mcp" + mcp_url = f"{json_server_url}/mcp" response = requests.post( mcp_url, headers={ @@ -530,4 +541,4 @@ def test_json_response(json_response_server, server_url): json=INIT_REQUEST, ) assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" + assert response.headers.get("Content-Type") == "application/json" \ No newline at end of file From f738cbfca5cb9bb94aabe8e762b71a736e21c33e Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 17:04:49 +0100 Subject: [PATCH 13/45] stateless implemetation --- .../simple-streamablehttp-stateless/README.md | 62 +++++++ .../__init__.py | 0 .../__main__.py | 4 + .../server.py | 171 ++++++++++++++++++ .../pyproject.toml | 36 ++++ src/mcp/server/lowlevel/server.py | 11 +- src/mcp/server/session.py | 7 +- tests/server/test_streamableHttp.py | 2 +- uv.lock | 39 ++++ 9 files changed, 329 insertions(+), 3 deletions(-) create mode 100644 examples/servers/simple-streamablehttp-stateless/README.md create mode 100644 examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__init__.py create mode 100644 examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py create mode 100644 examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py create mode 100644 examples/servers/simple-streamablehttp-stateless/pyproject.toml diff --git a/examples/servers/simple-streamablehttp-stateless/README.md b/examples/servers/simple-streamablehttp-stateless/README.md new file mode 100644 index 00000000..e282f4e4 --- /dev/null +++ b/examples/servers/simple-streamablehttp-stateless/README.md @@ -0,0 +1,62 @@ +# MCP Simple StreamableHttp Stateless Server Example + +A stateless MCP server example demonstrating the StreamableHttp transport without maintaining session state. This example is ideal for understanding how to deploy MCP servers in multi-node environments where requests can be routed to any instance. + +## Features + +- Uses the StreamableHTTP transport in stateless mode (mcp_session_id=None) +- Each request creates a new ephemeral connection +- No session state maintained between requests +- Task lifecycle scoped to individual requests +- Suitable for deployment in multi-node environments + +## Key Differences from Stateful Version + +1. **No Session Management**: The server explicitly sets `mcp_session_id=None` when creating the transport +2. **Request Scoped**: Each request creates its own server instance and task group +3. **Immediate Cleanup**: Resources are cleaned up after each request completes +4. **Rejcts Session IDs**: If a client sends a session ID, the server rejects it with a BAD_REQUEST +5. **Stateless Deployments**: Can be deployed to multiple nodes behind a load balancer + +## Usage + +Start the server: + +```bash +# Using default port 3000 +uv run mcp-simple-streamablehttp-stateless + +# Using custom port +uv run mcp-simple-streamablehttp-stateless --port 3000 + +# Custom logging level +uv run mcp-simple-streamablehttp-stateless --log-level DEBUG + +# Enable JSON responses instead of SSE streams +uv run mcp-simple-streamablehttp-stateless --json-response +``` + +The server exposes a tool named "start-notification-stream" that accepts three arguments: + +- `interval`: Time between notifications in seconds (e.g., 1.0) +- `count`: Number of notifications to send (e.g., 5) +- `caller`: Identifier string for the caller + +## Client Considerations + +When connecting to a stateless server: +1. Do not send the `X-MCP-Session-ID` header +2. Each request is independent with no shared state +3. Suitable for one-shot operations or when state can be maintained client-side +4. Works well with load balancers that distribute requests across multiple instances + +## Deployment Benefits + +1. **Horizontal Scaling**: Deploy multiple instances behind a load balancer +2. **No Session Affinity**: Requests can be routed to any instance +3. **Simplified Infrastructure**: No session storage or sticky sessions required +4. **Cloud Native**: Works well in containerized and serverless environments + +## Client + +You can connect to this server using an HTTP client. For now, only the TypeScript SDK has streamable HTTP client examples, or you can use [Inspector](https://github.com/modelcontextprotocol/inspector) for testing. \ No newline at end of file diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__init__.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py new file mode 100644 index 00000000..f5f6e402 --- /dev/null +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py @@ -0,0 +1,4 @@ +from .server import main + +if __name__ == "__main__": + main() diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py new file mode 100644 index 00000000..fe22e7c3 --- /dev/null +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -0,0 +1,171 @@ +import contextlib +import logging + +import anyio +import click +import mcp.types as types +from mcp.server.lowlevel import Server +from mcp.server.streamableHttp import ( + StreamableHTTPServerTransport, +) +from starlette.applications import Starlette +from starlette.routing import Mount + +logger = logging.getLogger(__name__) +# Global task group that will be initialized in the lifespan +task_group = None + + +@contextlib.asynccontextmanager +async def lifespan(app): + """Application lifespan context manager for managing task group.""" + global task_group + + async with anyio.create_task_group() as tg: + task_group = tg + logger.info("Application started, task group initialized!") + try: + yield + finally: + logger.info("Application shutting down, cleaning up resources...") + if task_group: + tg.cancel_scope.cancel() + task_group = None + logger.info("Resources cleaned up successfully.") + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on for HTTP") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", +) +@click.option( + "--json-response", + is_flag=True, + default=False, + help="Enable JSON responses instead of SSE streams", +) +def main( + port: int, + log_level: str, + json_response: bool, +) -> int: + # Configure logging + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + app = Server("mcp-streamable-http-stateless-demo") + + @app.call_tool() + async def call_tool( + name: str, arguments: dict + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ctx = app.request_context + interval = arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + await ctx.session.send_log_message( + level="info", + data=f"Notification {i+1}/{count} from caller: {caller}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + return [ + types.TextContent( + type="text", + text=( + f"Sent {count} notifications with {interval}s interval" + f" for caller: {caller}" + ), + ) + ] + + @app.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="start-notification-stream", + description=( + "Sends a stream of notifications with configurable count" + " and interval" + ), + inputSchema={ + "type": "object", + "required": ["interval", "count", "caller"], + "properties": { + "interval": { + "type": "number", + "description": "Interval between notifications in seconds", + }, + "count": { + "type": "number", + "description": "Number of notifications to send", + }, + "caller": { + "type": "string", + "description": ( + "Identifier of the caller to include in notifications" + ), + }, + }, + }, + ) + ] + + # ASGI handler for stateless HTTP connections + async def handle_streamable_http(scope, receive, send): + logger.debug("Creating new transport") + # Use lock to prevent race conditions when creating new sessions + http_transport = StreamableHTTPServerTransport( + mcp_session_id=None, + is_json_response_enabled=json_response, + ) + async with http_transport.connect() as streams: + read_stream, write_stream = streams + + if not task_group: + raise RuntimeError("Task group is not initialized") + + async def run_server(): + await app.run( + read_stream, + write_stream, + app.create_initialization_options(), + # This allows the server to run without waiting for initialization + require_initialization=False, + ) + + # Start server task + task_group.start_soon(run_server) + + # Small delay to allow the server task to start + # This helps prevent race conditions in stateless mode + await anyio.sleep(0.001) + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + + # Create an ASGI application using the transport + starlette_app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=handle_streamable_http), + ], + lifespan=lifespan, + ) + + import uvicorn + + uvicorn.run(starlette_app, host="0.0.0.0", port=port) + + return 0 diff --git a/examples/servers/simple-streamablehttp-stateless/pyproject.toml b/examples/servers/simple-streamablehttp-stateless/pyproject.toml new file mode 100644 index 00000000..d2b08945 --- /dev/null +++ b/examples/servers/simple-streamablehttp-stateless/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "mcp-simple-streamablehttp-stateless" +version = "0.1.0" +description = "A simple MCP server exposing a StreamableHttp transport in stateless mode" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "automation", "web", "fetch", "http", "streamable", "stateless"] +license = { text = "MIT" } +dependencies = ["anyio>=4.5", "click>=8.1.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-simple-streamablehttp-stateless = "mcp_simple_streamablehttp_stateless.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_streamablehttp_stateless"] + +[tool.pyright] +include = ["mcp_simple_streamablehttp_stateless"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 88 +target-version = "py310" + +[tool.uv] +dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] \ No newline at end of file diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index dbaff305..4f33af19 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -479,11 +479,20 @@ async def run( # but also make tracing exceptions much easier during testing and when using # in-process servers. raise_exceptions: bool = False, + # When True, the server will wait for the client to send an initialization + # message before processing any other messages. + # False should be used for stateless servers. + require_initialization: bool = True, ): async with AsyncExitStack() as stack: lifespan_context = await stack.enter_async_context(self.lifespan(self)) session = await stack.enter_async_context( - ServerSession(read_stream, write_stream, initialization_options) + ServerSession( + read_stream, + write_stream, + initialization_options, + require_initialization, + ) ) async with anyio.create_task_group() as tg: diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 3a1f210d..ea9d3ec8 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -85,11 +85,15 @@ def __init__( read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[types.JSONRPCMessage], init_options: InitializationOptions, + require_initialization: bool = True, ) -> None: super().__init__( read_stream, write_stream, types.ClientRequest, types.ClientNotification ) - self._initialization_state = InitializationState.NotInitialized + if require_initialization: + self._initialization_state = InitializationState.NotInitialized + else: + self._initialization_state = InitializationState.Initialized self._init_options = init_options self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( anyio.create_memory_object_stream[ServerRequestResponder](0) @@ -171,6 +175,7 @@ async def _received_notification( await anyio.lowlevel.checkpoint() match notification.root: case types.InitializedNotification(): + print("INITIALIZED") self._initialization_state = InitializationState.Initialized case _: if self._initialization_state != InitializationState.Initialized: diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index 063ad82b..e23059d6 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -541,4 +541,4 @@ def test_json_response(json_response_server, json_server_url): json=INIT_REQUEST, ) assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" \ No newline at end of file + assert response.headers.get("Content-Type") == "application/json" diff --git a/uv.lock b/uv.lock index 6618ea36..44dfdc83 100644 --- a/uv.lock +++ b/uv.lock @@ -10,6 +10,7 @@ members = [ "mcp-simple-prompt", "mcp-simple-resource", "mcp-simple-streamablehttp", + "mcp-simple-streamablehttp-stateless", "mcp-simple-tool", ] @@ -487,6 +488,7 @@ wheels = [ [[package]] name = "mcp" +version = "1.6.1.dev22+6c9c320" source = { editable = "." } dependencies = [ { name = "anyio" }, @@ -664,6 +666,43 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mcp-simple-streamablehttp-stateless" +version = "0.1.0" +source = { editable = "examples/servers/simple-streamablehttp-stateless" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.1.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-tool" version = "0.1.0" From 92d42875746f8b417b9b981de9ec491e1178c6be Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 17:07:41 +0100 Subject: [PATCH 14/45] format --- tests/server/test_streamableHttp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index 063ad82b..e23059d6 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -541,4 +541,4 @@ def test_json_response(json_response_server, json_server_url): json=INIT_REQUEST, ) assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" \ No newline at end of file + assert response.headers.get("Content-Type") == "application/json" From aa9f6e5f3dac9808a830da9ceba17392635a1c42 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 17:25:07 +0100 Subject: [PATCH 15/45] uv lock --- uv.lock | 2 ++ 1 file changed, 2 insertions(+) diff --git a/uv.lock b/uv.lock index 6618ea36..3ea01ff8 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" [options] @@ -543,6 +544,7 @@ requires-dist = [ { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] +provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ From 45723eab97027f50071373c470258cb7c5778854 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 17:30:51 +0100 Subject: [PATCH 16/45] simplify readme --- .../simple-streamablehttp-stateless/README.md | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/examples/servers/simple-streamablehttp-stateless/README.md b/examples/servers/simple-streamablehttp-stateless/README.md index e282f4e4..2abb6061 100644 --- a/examples/servers/simple-streamablehttp-stateless/README.md +++ b/examples/servers/simple-streamablehttp-stateless/README.md @@ -10,13 +10,6 @@ A stateless MCP server example demonstrating the StreamableHttp transport withou - Task lifecycle scoped to individual requests - Suitable for deployment in multi-node environments -## Key Differences from Stateful Version - -1. **No Session Management**: The server explicitly sets `mcp_session_id=None` when creating the transport -2. **Request Scoped**: Each request creates its own server instance and task group -3. **Immediate Cleanup**: Resources are cleaned up after each request completes -4. **Rejcts Session IDs**: If a client sends a session ID, the server rejects it with a BAD_REQUEST -5. **Stateless Deployments**: Can be deployed to multiple nodes behind a load balancer ## Usage @@ -42,20 +35,6 @@ The server exposes a tool named "start-notification-stream" that accepts three a - `count`: Number of notifications to send (e.g., 5) - `caller`: Identifier string for the caller -## Client Considerations - -When connecting to a stateless server: -1. Do not send the `X-MCP-Session-ID` header -2. Each request is independent with no shared state -3. Suitable for one-shot operations or when state can be maintained client-side -4. Works well with load balancers that distribute requests across multiple instances - -## Deployment Benefits - -1. **Horizontal Scaling**: Deploy multiple instances behind a load balancer -2. **No Session Affinity**: Requests can be routed to any instance -3. **Simplified Infrastructure**: No session storage or sticky sessions required -4. **Cloud Native**: Works well in containerized and serverless environments ## Client From 6b7a616a6e5a51b1ba3102d10693c0dee28bf1d2 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 17:31:16 +0100 Subject: [PATCH 17/45] clean up --- src/mcp/server/session.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ea9d3ec8..28f5ddaf 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -175,7 +175,6 @@ async def _received_notification( await anyio.lowlevel.checkpoint() match notification.root: case types.InitializedNotification(): - print("INITIALIZED") self._initialization_state = InitializationState.Initialized case _: if self._initialization_state != InitializationState.Initialized: From b1be6913953bebcecd8490b67c542c768321b465 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Tue, 22 Apr 2025 09:56:21 +0100 Subject: [PATCH 18/45] get sse --- .../mcp_simple_streamablehttp/server.py | 4 + src/mcp/server/streamableHttp.py | 117 +++++++++++++++--- tests/server/test_streamableHttp.py | 89 +++++++++++++ 3 files changed, 190 insertions(+), 20 deletions(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index eec5edb4..59f51263 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -11,6 +11,7 @@ MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport, ) +from pydantic import AnyUrl from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -87,6 +88,9 @@ async def call_tool( if i < count - 1: # Don't wait after the last notification await anyio.sleep(interval) + # This will send a resource notificaiton though standalone SSE + # established by GET request + await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource")) return [ types.TextContent( type="text", diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 09b94395..4e59886c 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -43,6 +43,9 @@ CONTENT_TYPE_JSON = "application/json" CONTENT_TYPE_SSE = "text/event-stream" +# Special key for the standalone GET stream +GET_STREAM_KEY = "_GET_stream" + # Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) # Pattern ensures entire string contains only valid characters by using ^ and $ anchors SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") @@ -476,10 +479,19 @@ async def sse_writer(): return async def _handle_get_request(self, request: Request, send: Send) -> None: - """Handle GET requests for SSE stream establishment.""" - # Validate session ID if server has one - if not await self._validate_session(request, send): - return + """ + Handle GET request to establish SSE. + + This allows the server to communicate to the client without the client + first sending data via HTTP POST. The server can send JSON-RPC requests + and notifications on this stream. + """ + writer = self._read_stream_writer + if writer is None: + raise ValueError( + "No read stream writer available. Ensure connect() is called first." + ) + # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) @@ -491,13 +503,80 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - # TODO: Implement SSE stream for GET requests - # For now, return 405 Method Not Allowed - response = self._create_server_response( - "SSE stream from GET request not implemented yet", - HTTPStatus.METHOD_NOT_ALLOWED, + if not await self._validate_session(request, send): + return + + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Check if we already have an active GET stream + if GET_STREAM_KEY in self._request_streams: + response = self._create_server_response( + "Conflict: Only one SSE stream is allowed per session", + HTTPStatus.CONFLICT, + ) + await response(request.scope, request.receive, send) + return + + # Create SSE stream + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ + dict[str, Any] + ](0) + + async def standalone_sse_writer(): + try: + # Create a standalone message stream for server-initiated messages + standalone_stream_writer, standalone_stream_reader = ( + anyio.create_memory_object_stream[JSONRPCMessage](0) + ) + + # Register this stream using the special key + self._request_streams[GET_STREAM_KEY] = standalone_stream_writer + + async with sse_stream_writer, standalone_stream_reader: + # Process messages from the standalone stream + async for received_message in standalone_stream_reader: + # For the standalone stream, we handle: + # - JSONRPCNotification (server can send notifications to client) + # - JSONRPCRequest (server can send requests to client) + # We should NOT receive JSONRPCResponse + + # Send the message via SSE + event_data = { + "event": "message", + "data": received_message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + await sse_stream_writer.send(event_data) + except Exception as e: + logger.exception(f"Error in standalone SSE writer: {e}") + finally: + logger.debug("Closing standalone SSE writer") + # Remove the stream from request_streams + self._request_streams.pop(GET_STREAM_KEY, None) + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=standalone_sse_writer, + headers=headers, ) - await response(request.scope, request.receive, send) + + try: + # This will send headers immediately and establish the SSE connection + await response(request.scope, request.receive, send) + except Exception as e: + logger.exception(f"Error in standalone SSE response: {e}") + # Clean up the request stream + self._request_streams.pop(GET_STREAM_KEY, None) async def _handle_delete_request(self, request: Request, send: Send) -> None: """Handle DELETE requests for explicit session termination.""" @@ -639,22 +718,20 @@ async def message_router(): # For responses, route based on the request ID if isinstance(message.root, JSONRPCResponse): target_request_id = str(message.root.id) - # For notifications, route by related_request_id if available - elif isinstance(message.root, JSONRPCNotification): - # Get related_request_id from params + # For notifications and requests, handle routing logic + elif isinstance( + message.root, JSONRPCNotification + ) or isinstance(message.root, JSONRPCRequest): params = message.root.params if params and "related_request_id" in params: related_id = params.get("related_request_id") if related_id is not None: target_request_id = str(related_id) - # Send to the specific request stream if available - if ( - target_request_id - and target_request_id in self._request_streams - ): + request_stream_id = target_request_id or GET_STREAM_KEY + if request_stream_id in self._request_streams: try: - await self._request_streams[target_request_id].send( + await self._request_streams[request_stream_id].send( message ) except ( @@ -662,7 +739,7 @@ async def message_router(): anyio.ClosedResourceError, ): # Stream might be closed, remove from registry - self._request_streams.pop(target_request_id, None) + self._request_streams.pop(request_stream_id, None) except Exception as e: logger.exception(f"Error in message router: {e}") diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index e23059d6..dc576d31 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -542,3 +542,92 @@ def test_json_response(json_response_server, json_server_url): ) assert response.status_code == 200 assert response.headers.get("Content-Type") == "application/json" + + +def test_get_sse_stream(basic_server, basic_server_url): + """Test establishing an SSE stream via GET request.""" + # First, we need to initialize a session + mcp_url = f"{basic_server_url}/mcp" + init_response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Now attempt to establish an SSE stream via GET + get_response = requests.get( + mcp_url, + headers={ + "Accept": "text/event-stream", + MCP_SESSION_ID_HEADER: session_id, + }, + stream=True, + ) + + # Verify we got a successful response with the right content type + assert get_response.status_code == 200 + assert get_response.headers.get("Content-Type") == "text/event-stream" + + # Test that a second GET request gets rejected (only one stream allowed) + second_get = requests.get( + mcp_url, + headers={ + "Accept": "text/event-stream", + MCP_SESSION_ID_HEADER: session_id, + }, + stream=True, + ) + + # Should get CONFLICT (409) since there's already a stream + # Note: This might fail if the first stream fully closed before this runs, + # but generally it should work in the test environment where it runs quickly + assert second_get.status_code == 409 + + +def test_get_validation(basic_server, basic_server_url): + """Test validation for GET requests.""" + # First, we need to initialize a session + mcp_url = f"{basic_server_url}/mcp" + init_response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test without Accept header + response = requests.get( + mcp_url, + headers={ + MCP_SESSION_ID_HEADER: session_id, + }, + stream=True, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + # Test with wrong Accept header + response = requests.get( + mcp_url, + headers={ + "Accept": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text From 201ec99ce66492db8beaca29802dc6ca1dcbaf80 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Tue, 22 Apr 2025 10:33:08 +0100 Subject: [PATCH 19/45] uv lock --- uv.lock | 1 - 1 file changed, 1 deletion(-) diff --git a/uv.lock b/uv.lock index 5a1b5764..a113ed92 100644 --- a/uv.lock +++ b/uv.lock @@ -489,7 +489,6 @@ wheels = [ [[package]] name = "mcp" -version = "1.6.1.dev22+6c9c320" source = { editable = "." } dependencies = [ { name = "anyio" }, From 46ec72d0ecdf0b6fe01699c7ccff4d2ed65aa31c Mon Sep 17 00:00:00 2001 From: ihrpr Date: Tue, 22 Apr 2025 21:12:07 +0100 Subject: [PATCH 20/45] clean up --- src/mcp/server/streamableHttp.py | 393 +++++++++++++--------------- tests/server/test_streamableHttp.py | 1 - 2 files changed, 179 insertions(+), 215 deletions(-) diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 09b94395..34a272f6 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -24,10 +24,17 @@ from starlette.types import Receive, Scope, Send from mcp.types import ( + INTERNAL_ERROR, + INVALID_PARAMS, + INVALID_REQUEST, + PARSE_ERROR, + ErrorData, + JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + RequestId, ) logger = logging.getLogger(__name__) @@ -61,8 +68,6 @@ class StreamableHTTPServerTransport: None ) _write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] | None = None - # Dictionary to track request-specific message streams - _request_streams: dict[str, MemoryObjectSendStream[JSONRPCMessage]] def __init__( self, @@ -90,16 +95,19 @@ def __init__( self.mcp_session_id = mcp_session_id self.is_json_response_enabled = is_json_response_enabled - self._request_streams = {} + self._request_streams: dict[ + RequestId, MemoryObjectSendStream[JSONRPCMessage] + ] = {} self._terminated = False - def _create_server_response( + def _create_error_response( self, - message: str, + error_message: str, status_code: HTTPStatus, + error_code: int = INVALID_REQUEST, headers: dict[str, str] | None = None, ) -> Response: - """Create a standardized server response.""" + """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} if headers: response_headers.update(headers) @@ -107,15 +115,25 @@ def _create_server_response( if self.mcp_session_id: response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + # Return a properly formatted JSON error response + error_response = JSONRPCError( + jsonrpc="2.0", + id="server-error", # We don't have a request ID for general errors + error=ErrorData( + code=error_code, + message=error_message, + ), + ) + return Response( - message, + error_response.model_dump_json(by_alias=True, exclude_none=True), status_code=status_code, headers=response_headers, ) def _create_json_response( self, - response_message: JSONRPCMessage, + response_message: JSONRPCMessage | None, status_code: HTTPStatus = HTTPStatus.OK, headers: dict[str, str] | None = None, ) -> Response: @@ -128,7 +146,9 @@ def _create_json_response( response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id return Response( - response_message.model_dump_json(by_alias=True, exclude_none=True), + response_message.model_dump_json(by_alias=True, exclude_none=True) + if response_message + else None, status_code=status_code, headers=response_headers, ) @@ -142,7 +162,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No request = Request(scope, receive) if self._terminated: # If the session has been terminated, return 404 Not Found - response = self._create_server_response( + response = self._create_error_response( "Not Found: Session has been terminated", HTTPStatus.NOT_FOUND, ) @@ -194,7 +214,7 @@ async def _handle_post_request( # Check Accept headers has_json, has_sse = self._check_accept_headers(request) if not (has_json and has_sse): - response = self._create_server_response( + response = self._create_error_response( ( "Not Acceptable: Client must accept both application/json and " "text/event-stream" @@ -206,7 +226,7 @@ async def _handle_post_request( # Validate Content-Type if not self._check_content_type(request): - response = self._create_server_response( + response = self._create_error_response( "Unsupported Media Type: Content-Type must be application/json", HTTPStatus.UNSUPPORTED_MEDIA_TYPE, ) @@ -216,7 +236,7 @@ async def _handle_post_request( # Parse the body - only read it once body = await request.body() if len(body) > MAXIMUM_MESSAGE_SIZE: - response = self._create_server_response( + response = self._create_error_response( "Payload Too Large: Message exceeds maximum size", HTTPStatus.REQUEST_ENTITY_TOO_LARGE, ) @@ -226,9 +246,8 @@ async def _handle_post_request( try: raw_message = json.loads(body) except json.JSONDecodeError as e: - response = self._create_server_response( - f"Parse error: {str(e)}", - HTTPStatus.BAD_REQUEST, + response = self._create_error_response( + f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR ) await response(scope, receive, send) return @@ -236,9 +255,10 @@ async def _handle_post_request( try: message = JSONRPCMessage.model_validate(raw_message) except ValidationError as e: - response = self._create_server_response( + response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, + INVALID_PARAMS, ) await response(scope, receive, send) return @@ -257,7 +277,7 @@ async def _handle_post_request( # If request has a session ID but doesn't match, return 404 if request_session_id and request_session_id != self.mcp_session_id: - response = self._create_server_response( + response = self._create_error_response( "Not Found: Invalid or expired session ID", HTTPStatus.NOT_FOUND, ) @@ -267,13 +287,11 @@ async def _handle_post_request( elif not await self._validate_session(request, send): return - is_request = isinstance(message.root, JSONRPCRequest) - # For notifications and responses only, return 202 Accepted - if not is_request: + if not isinstance(message.root, JSONRPCRequest): # Create response object and send it - response = self._create_server_response( - "Accepted", + response = self._create_json_response( + None, HTTPStatus.ACCEPTED, ) await response(scope, receive, send) @@ -283,192 +301,141 @@ async def _handle_post_request( return - # For requests, determine whether to return JSON or set up SSE stream - if is_request: - if self.is_json_response_enabled: - # JSON response mode - create a response future - request_id = None - if isinstance(message.root, JSONRPCRequest): - request_id = str(message.root.id) - - if not request_id: - # Should not happen for valid JSONRPCRequest, but handle it - response = self._create_server_response( - "Invalid Request: Missing request ID", - HTTPStatus.BAD_REQUEST, - ) - await response(scope, receive, send) - return - - # Create promise stream for getting response - request_stream_writer, request_stream_reader = ( - anyio.create_memory_object_stream[JSONRPCMessage](0) - ) - - # Register this stream for the request ID - self._request_streams[request_id] = request_stream_writer - - # Process the message - await writer.send(message) + # Extract the request ID outside the try block for proper scope + request_id = str(message.root.id) + # Create promise stream for getting response + request_stream_writer, request_stream_reader = ( + anyio.create_memory_object_stream[JSONRPCMessage](0) + ) - try: - # Process messages from the request-specific stream - # We need to collect all messages until we get a response - response_message = None - - # Use similar approach to SSE writer for consistency - async for received_message in request_stream_reader: - # If it's a response, this is what we're waiting for - if isinstance(received_message.root, JSONRPCResponse): - response_message = received_message - break - # For notifications, keep waiting for the actual response - elif isinstance(received_message.root, JSONRPCNotification): - # Just process it and continue waiting - logger.debug( - f"Notification: {received_message.root.method}" - ) + # Register this stream for the request ID + self._request_streams[request_id] = request_stream_writer - # At this point we should have a response - if response_message: - # Create JSON response - response = self._create_json_response(response_message) - await response(scope, receive, send) + if self.is_json_response_enabled: + # Process the message + await writer.send(message) + try: + # Process messages from the request-specific stream + # We need to collect all messages until we get a response + response_message = None + + # Use similar approach to SSE writer for consistency + async for received_message in request_stream_reader: + # If it's a response, this is what we're waiting for + if isinstance( + received_message.root, JSONRPCResponse | JSONRPCError + ): + response_message = received_message + break + # For notifications and request, keep waiting else: - # This shouldn't happen in normal operation - logger.error( - "No response message received before stream closed" - ) - response = self._create_server_response( - "Error processing request: No response received", - HTTPStatus.INTERNAL_SERVER_ERROR, - ) - await response(scope, receive, send) - except Exception as e: - logger.exception(f"Error processing JSON response: {e}") - response = self._create_server_response( - f"Error processing request: {str(e)}", + logger.debug(f"received: {received_message.root.method}") + + # At this point we should have a response + if response_message: + # Create JSON response + response = self._create_json_response(response_message) + await response(scope, receive, send) + else: + # This shouldn't happen in normal operation + logger.error( + "No response message received before stream closed" + ) + response = self._create_error_response( + "Error processing request: No response received", HTTPStatus.INTERNAL_SERVER_ERROR, ) await response(scope, receive, send) - finally: - # Clean up the request stream - if request_id in self._request_streams: - self._request_streams.pop(request_id, None) - await request_stream_reader.aclose() - await request_stream_writer.aclose() - else: - # SSE stream mode (original behavior) - # Set up headers - headers = { - "Cache-Control": "no-cache, no-transform", - "Connection": "keep-alive", - "Content-Type": CONTENT_TYPE_SSE, - } - - if self.mcp_session_id: - headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - # Create SSE stream - sse_stream_writer, sse_stream_reader = ( - anyio.create_memory_object_stream[dict[str, Any]](0) - ) - - async def sse_writer(): - # Get the request ID from the incoming request message - request_id = None - try: - # Create a request-specific message stream for this POST - request_stream_writer, request_stream_reader = ( - anyio.create_memory_object_stream[JSONRPCMessage](0) - ) - - if isinstance(message.root, JSONRPCRequest): - request_id = str(message.root.id) - # Register this stream for the request ID - if request_id: - self._request_streams[request_id] = ( - request_stream_writer - ) - - async with sse_stream_writer, request_stream_reader: - # Process messages from the request-specific stream - async for received_message in request_stream_reader: - # Send the message via SSE - related_request_id = None - - if isinstance( - received_message.root, JSONRPCNotification - ): - # Get related_request_id from params - params = received_message.root.params - if params and "related_request_id" in params: - related_request_id = params.get( - "related_request_id" - ) - logger.debug( - f"NOTIFICATION: {related_request_id}, " - f"{params.get('data')}" - ) - - # Build the event data - event_data = { - "event": "message", - "data": received_message.model_dump_json( - by_alias=True, exclude_none=True - ), - } - - await sse_stream_writer.send(event_data) - - # If response, remove from pending streams and close - if isinstance( - received_message.root, JSONRPCResponse - ): - if request_id: - self._request_streams.pop(request_id, None) - break - except Exception as e: - logger.exception(f"Error in SSE writer: {e}") - finally: - logger.debug("Closing SSE writer") - # Clean up the request-specific streams - if request_id and request_id in self._request_streams: - self._request_streams.pop(request_id, None) - - # Create and start EventSourceResponse - response = EventSourceResponse( - content=sse_stream_reader, - data_sender_callable=sse_writer, - headers=headers, + except Exception as e: + logger.exception(f"Error processing JSON response: {e}") + response = self._create_error_response( + f"Error processing request: {str(e)}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, ) + await response(scope, receive, send) + finally: + # Clean up the request stream + if request_id in self._request_streams: + self._request_streams.pop(request_id, None) + await request_stream_reader.aclose() + await request_stream_writer.aclose() + else: + # Create SSE stream + sse_stream_writer, sse_stream_reader = ( + anyio.create_memory_object_stream[dict[str, Any]](0) + ) - # Extract the request ID outside the try block for proper scope - outer_request_id = None - if isinstance(message.root, JSONRPCRequest): - outer_request_id = str(message.root.id) - - # Start the SSE response (this will send headers immediately) + async def sse_writer(): + # Get the request ID from the incoming request message try: - # First send the response to establish the SSE connection - async with anyio.create_task_group() as tg: - tg.start_soon(response, scope, receive, send) + async with sse_stream_writer, request_stream_reader: + # Process messages from the request-specific stream + async for received_message in request_stream_reader: + # Build the event data + event_data = { + "event": "message", + "data": received_message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + await sse_stream_writer.send(event_data) + + # If response, remove from pending streams and close + if isinstance( + received_message.root, + JSONRPCResponse | JSONRPCError, + ): + if request_id: + self._request_streams.pop(request_id, None) + break + except Exception as e: + logger.exception(f"Error in SSE writer: {e}") + finally: + logger.debug("Closing SSE writer") + # Clean up the request-specific streams + if request_id and request_id in self._request_streams: + self._request_streams.pop(request_id, None) - # Then send the message to be processed by the server - await writer.send(message) - except Exception: - logger.exception("SSE response error") - # Clean up the request stream if something goes wrong - if ( - outer_request_id - and outer_request_id in self._request_streams - ): - self._request_streams.pop(outer_request_id, None) + # Create and start EventSourceResponse + # SSE stream mode (original behavior) + # Set up headers + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + **( + {MCP_SESSION_ID_HEADER: self.mcp_session_id} + if self.mcp_session_id + else {} + ), + } + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=sse_writer, + headers=headers, + ) + + # Start the SSE response (this will send headers immediately) + try: + # First send the response to establish the SSE connection + async with anyio.create_task_group() as tg: + tg.start_soon(response, scope, receive, send) + # Then send the message to be processed by the server + await writer.send(message) + except Exception: + logger.exception("SSE response error") + # Clean up the request stream if something goes wrong + if request_id and request_id in self._request_streams: + self._request_streams.pop(request_id, None) except Exception as err: logger.exception("Error handling POST request") - response = self._create_server_response( + response = self._create_error_response( f"Error handling POST request: {err}", HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, ) await response(scope, receive, send) if writer: @@ -484,7 +451,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: _, has_sse = self._check_accept_headers(request) if not has_sse: - response = self._create_server_response( + response = self._create_error_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, ) @@ -493,7 +460,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # TODO: Implement SSE stream for GET requests # For now, return 405 Method Not Allowed - response = self._create_server_response( + response = self._create_error_response( "SSE stream from GET request not implemented yet", HTTPStatus.METHOD_NOT_ALLOWED, ) @@ -504,7 +471,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: # Validate session ID if not self.mcp_session_id: # If no session ID set, return Method Not Allowed - response = self._create_server_response( + response = self._create_error_response( "Method Not Allowed: Session termination not supported", HTTPStatus.METHOD_NOT_ALLOWED, ) @@ -516,8 +483,8 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: self._terminate_session() - response = self._create_server_response( - "Session terminated", + response = self._create_json_response( + None, HTTPStatus.OK, ) await response(request.scope, request.receive, send) @@ -557,9 +524,9 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non if self.mcp_session_id: headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - response = Response( + response = self._create_error_response( "Method Not Allowed", - status_code=HTTPStatus.METHOD_NOT_ALLOWED, + HTTPStatus.METHOD_NOT_ALLOWED, headers=headers, ) await response(request.scope, request.receive, send) @@ -575,7 +542,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # If no session ID provided but required, return error if not request_session_id: - response = self._create_server_response( + response = self._create_error_response( "Bad Request: Missing session ID", HTTPStatus.BAD_REQUEST, ) @@ -584,7 +551,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # If session ID doesn't match, return error if request_session_id != self.mcp_session_id: - response = self._create_server_response( + response = self._create_error_response( "Not Found: Invalid or expired session ID", HTTPStatus.NOT_FOUND, ) @@ -635,18 +602,16 @@ async def message_router(): async for message in write_stream_reader: # Determine which request stream(s) should receive this message target_request_id = None - - # For responses, route based on the request ID - if isinstance(message.root, JSONRPCResponse): + if isinstance( + message.root, JSONRPCNotification | JSONRPCRequest + ): + # Extract related_request_id from params if it exists + if (params := getattr(message.root, "params", None)) and ( + related_id := params.get("related_request_id") + ) is not None: + target_request_id = str(related_id) + else: target_request_id = str(message.root.id) - # For notifications, route by related_request_id if available - elif isinstance(message.root, JSONRPCNotification): - # Get related_request_id from params - params = message.root.params - if params and "related_request_id" in params: - related_id = params.get("related_request_id") - if related_id is not None: - target_request_id = str(related_id) # Send to the specific request stream if available if ( diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index e23059d6..8904bf4f 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -482,7 +482,6 @@ def test_session_termination(basic_server, basic_server_url): headers={MCP_SESSION_ID_HEADER: session_id}, ) assert response.status_code == 200 - assert "Session terminated" in response.text # Try to use the terminated session response = requests.post( From c2be5afdb306a78d78c1d5f538e26a5a9e09c305 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 23 Apr 2025 09:35:33 +0100 Subject: [PATCH 21/45] streamable http client --- src/mcp/client/streamableHttp.py | 266 ++++++++++++++++++ .../{server => shared}/test_streamableHttp.py | 185 +++++++++++- 2 files changed, 438 insertions(+), 13 deletions(-) create mode 100644 src/mcp/client/streamableHttp.py rename tests/{server => shared}/test_streamableHttp.py (73%) diff --git a/src/mcp/client/streamableHttp.py b/src/mcp/client/streamableHttp.py new file mode 100644 index 00000000..d8453d77 --- /dev/null +++ b/src/mcp/client/streamableHttp.py @@ -0,0 +1,266 @@ +""" +StreamableHTTP Client Transport Module + +This module implements the StreamableHTTP transport for MCP clients, +providing support for HTTP POST requests with optional SSE streaming responses +and session management. +""" + +import logging +from contextlib import asynccontextmanager +from typing import Any + +import anyio +import httpx +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from httpx_sse import EventSource, aconnect_sse + +from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest + +logger = logging.getLogger(__name__) + +# Header names +MCP_SESSION_ID_HEADER = "mcp-session-id" +LAST_EVENT_ID_HEADER = "last-event-id" + +# Content types +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_SSE = "text/event-stream" + + +@asynccontextmanager +async def streamablehttp_client( + url: str, + headers: dict[str, Any] | None = None, + timeout: float = 30, + sse_read_timeout: float = 60 * 5, +): + """ + Client transport for StreamableHTTP. + + `sse_read_timeout` determines how long (in seconds) the client will wait for a new + event before disconnecting. All other HTTP operations are controlled by `timeout`. + """ + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + + write_stream: MemoryObjectSendStream[JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + JSONRPCMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + JSONRPCMessage + ](0) + + async with anyio.create_task_group() as tg: + try: + logger.info(f"Connecting to StreamableHTTP endpoint: {url}") + # Set up headers with required Accept header + request_headers = { + "Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_SSE}", + "Content-Type": CONTENT_TYPE_JSON, + **(headers or {}), + } + + # Track session ID if provided by server + session_id: str | None = None + + async with httpx.AsyncClient( + headers=request_headers, timeout=timeout, follow_redirects=True + ) as client: + + async def post_writer(): + nonlocal session_id + try: + async with write_stream_reader: + async for message in write_stream_reader: + # Add session ID to headers if we have one + post_headers = request_headers.copy() + if session_id: + post_headers[MCP_SESSION_ID_HEADER] = session_id + + logger.debug(f"Sending client message: {message}") + + # Handle initial initialization request + is_initialization = ( + isinstance(message.root, JSONRPCRequest) + and message.root.method == "initialize" + ) + if ( + isinstance(message.root, JSONRPCNotification) + and message.root.method + == "notifications/initialized" + ): + tg.start_soon(get_stream) + + async with client.stream( + "POST", + url, + json=message.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + headers=post_headers, + ) as response: + if response.status_code == 202: + logger.debug("Received 202 Accepted") + continue + # Check for 404 (session expired/invalid) + if response.status_code == 404: + if is_initialization and session_id: + logger.info( + "Session expired, retrying without ID" + ) + session_id = None + post_headers.pop( + MCP_SESSION_ID_HEADER, None + ) + # Retry with client.stream + async with client.stream( + "POST", + url, + json=message.model_dump( + by_alias=True, + mode="json", + exclude_none=True, + ), + headers=post_headers, + ) as new_response: + response = new_response + else: + response.raise_for_status() + + response.raise_for_status() + + # Extract session ID from response headers + if is_initialization: + new_session_id = response.headers.get( + MCP_SESSION_ID_HEADER + ) + if new_session_id: + session_id = new_session_id + logger.info( + f"Received session ID: {session_id}" + ) + + # Handle different response types + content_type = response.headers.get( + "content-type", "" + ).lower() + + if content_type.startswith(CONTENT_TYPE_JSON): + try: + content = await response.aread() + json_message = ( + JSONRPCMessage.model_validate_json( + content + ) + ) + await read_stream_writer.send(json_message) + except Exception as exc: + logger.error( + f"Error parsing JSON response: {exc}" + ) + await read_stream_writer.send(exc) + + elif content_type.startswith(CONTENT_TYPE_SSE): + # Parse SSE events from the response + try: + event_source = EventSource(response) + async for sse in event_source.aiter_sse(): + if sse.event == "message": + try: + await read_stream_writer.send( + JSONRPCMessage.model_validate_json( + sse.data + ) + ) + except Exception as exc: + logger.exception( + "Error parsing message" + ) + await read_stream_writer.send( + exc + ) + else: + logger.warning( + f"Unknown event: {sse.event}" + ) + + except Exception as e: + logger.exception( + "Error reading SSE stream:" + ) + await read_stream_writer.send(e) + + else: + # For 202 Accepted with no body + if response.status_code == 202: + logger.debug("Received 202 Accepted") + continue + + error_msg = ( + f"Unexpected content type: {content_type}" + ) + logger.error(error_msg) + await read_stream_writer.send( + ValueError(error_msg) + ) + + except Exception as exc: + logger.error(f"Error in post_writer: {exc}") + await read_stream_writer.send(exc) + finally: + await read_stream_writer.aclose() + await write_stream.aclose() + + async def get_stream(): + """ + Optional GET stream for server-initiated messages + """ + nonlocal session_id + try: + # Only attempt GET if we have a session ID + if not session_id: + return + + get_headers = request_headers.copy() + get_headers[MCP_SESSION_ID_HEADER] = session_id + + async with aconnect_sse( + client, "GET", url, headers=get_headers + ) as event_source: + event_source.response.raise_for_status() + logger.debug("GET SSE connection established") + + async for sse in event_source.aiter_sse(): + if sse.event == "message": + try: + message = JSONRPCMessage.model_validate_json( + sse.data + ) + logger.debug(f"GET message: {message}") + await read_stream_writer.send(message) + except Exception as exc: + logger.error( + f"Error parsing GET message: {exc}" + ) + await read_stream_writer.send(exc) + else: + logger.warning( + f"Unknown SSE event from GET: {sse.event}" + ) + except Exception as exc: + # GET stream is optional, so don't propagate errors + logger.debug(f"GET stream error (non-fatal): {exc}") + + tg.start_soon(post_writer) + + try: + yield read_stream, write_stream + finally: + tg.cancel_scope.cancel() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() diff --git a/tests/server/test_streamableHttp.py b/tests/shared/test_streamableHttp.py similarity index 73% rename from tests/server/test_streamableHttp.py rename to tests/shared/test_streamableHttp.py index f612575c..416be8e0 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/shared/test_streamableHttp.py @@ -1,7 +1,7 @@ """ -Tests for the StreamableHTTP server transport validation. +Tests for the StreamableHTTP server and client transport. -This file contains tests for request validation in the StreamableHTTP transport. +Contains tests for both server and client sides of the StreamableHTTP transport. """ import contextlib @@ -13,6 +13,7 @@ from uuid import uuid4 import anyio +import httpx import pytest import requests import uvicorn @@ -22,6 +23,8 @@ from starlette.responses import Response from starlette.routing import Mount +from mcp.client.session import ClientSession +from mcp.client.streamableHttp import streamablehttp_client from mcp.server import Server from mcp.server.streamableHttp import ( MCP_SESSION_ID_HEADER, @@ -29,11 +32,7 @@ StreamableHTTPServerTransport, ) from mcp.shared.exceptions import McpError -from mcp.types import ( - ErrorData, - TextContent, - Tool, -) +from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool # Test constants SERVER_NAME = "test_streamable_http_server" @@ -64,11 +63,7 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes: await anyio.sleep(2.0) return f"Slow response from {uri.host}" - raise McpError( - error=ErrorData( - code=404, message="OOPS! no resource with that URI was found" - ) - ) + raise ValueError(f"Unknown resource: {uri}") @self.list_tools() async def handle_list_tools() -> list[Tool]: @@ -77,11 +72,23 @@ async def handle_list_tools() -> list[Tool]: name="test_tool", description="A test tool", inputSchema={"type": "object", "properties": {}}, - ) + ), + Tool( + name="test_tool_with_standalone_notification", + description="A test tool that sends a notification", + inputSchema={"type": "object", "properties": {}}, + ), ] @self.call_tool() async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + # When the tool is called, send a notification to test GET stream + if name == "test_tool_with_standalone_notification": + ctx = self.request_context + await ctx.session.send_resource_updated( + uri=AnyUrl("http://test_resource") + ) + return [TextContent(type="text", text=f"Called {name}")] @@ -630,3 +637,155 @@ def test_get_validation(basic_server, basic_server_url): ) assert response.status_code == 406 assert "Not Acceptable" in response.text + + +# Client-specific fixtures +@pytest.fixture +async def http_client(basic_server, basic_server_url): + """Create test client matching the SSE test pattern.""" + async with httpx.AsyncClient(base_url=basic_server_url) as client: + yield client + + +@pytest.fixture +async def initialized_client_session(basic_server, basic_server_url): + """Create initialized StreamableHTTP client session.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session + + +@pytest.mark.anyio +async def test_streamablehttp_client_basic_connection(basic_server, basic_server_url): + """Test basic client connection with initialization.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as streams: + async with ClientSession(*streams) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + +@pytest.mark.anyio +async def test_streamablehttp_client_resource_read(initialized_client_session): + """Test client resource read functionality.""" + response = await initialized_client_session.read_resource( + uri=AnyUrl("foobar://test-resource") + ) + assert len(response.contents) == 1 + assert response.contents[0].uri == AnyUrl("foobar://test-resource") + assert response.contents[0].text == "Read test-resource" + + +@pytest.mark.anyio +async def test_streamablehttp_client_tool_invocation(initialized_client_session): + """Test client tool invocation.""" + # First list tools + tools = await initialized_client_session.list_tools() + assert len(tools.tools) == 2 + assert tools.tools[0].name == "test_tool" + + # Call the tool + result = await initialized_client_session.call_tool("test_tool", {}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Called test_tool" + + +@pytest.mark.anyio +async def test_streamablehttp_client_error_handling(initialized_client_session): + """Test error handling in client.""" + with pytest.raises(McpError) as exc_info: + await initialized_client_session.read_resource( + uri=AnyUrl("unknown://test-error") + ) + assert exc_info.value.error.code == 0 + assert "Unknown resource: unknown://test-error" in exc_info.value.error.message + + +@pytest.mark.anyio +async def test_streamablehttp_client_session_persistence( + basic_server, basic_server_url +): + """Test that session ID persists across requests.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as streams: + async with ClientSession(*streams) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make multiple requests to verify session persistence + tools = await session.list_tools() + assert len(tools.tools) == 2 + + # Read a resource + resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) + assert isinstance(resource.contents[0], TextResourceContents) is True + content = resource.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Read test-persist" + + +@pytest.mark.anyio +async def test_streamablehttp_client_json_response( + json_response_server, json_server_url +): + """Test client with JSON response mode.""" + async with streamablehttp_client(f"{json_server_url}/mcp") as streams: + async with ClientSession(*streams) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + # Check tool listing + tools = await session.list_tools() + assert len(tools.tools) == 2 + + # Call a tool and verify JSON response handling + result = await session.call_tool("test_tool", {}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Called test_tool" + + +@pytest.mark.anyio +async def test_streamablehttp_client_get_stream(basic_server, basic_server_url): + """Test GET stream functionality for server-initiated messages.""" + import mcp.types as types + from mcp.shared.session import RequestResponder + + notifications_received = [] + + # Define message handler to capture notifications + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, types.ServerNotification): + notifications_received.append(message) + + async with streamablehttp_client(f"{basic_server_url}/mcp") as streams: + async with ClientSession(*streams, message_handler=message_handler) as session: + # Initialize the session - this triggers the GET stream setup + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Call the special tool that sends a notification + await session.call_tool("test_tool_with_standalone_notification", {}) + + # Verify we received the notification + assert len(notifications_received) > 0 + + # Verify the notification is a ResourceUpdatedNotification + resource_update_found = False + for notif in notifications_received: + if isinstance(notif.root, types.ResourceUpdatedNotification): + assert str(notif.root.params.uri) == "http://test_resource/" + resource_update_found = True + + assert ( + resource_update_found + ), "ResourceUpdatedNotification not received via GET stream" From 9b096dc558f02d4a903e9215037d302e9554d62a Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 23 Apr 2025 09:44:58 +0100 Subject: [PATCH 22/45] add comments to server example where we use related_request_id --- .../mcp_simple_streamablehttp/server.py | 5 +++++ src/mcp/server/fastmcp/server.py | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index eec5edb4..e7bc4430 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -82,6 +82,11 @@ async def call_tool( level="info", data=f"Notification {i+1}/{count} from caller: {caller}", logger="notification_stream", + # Associates this notification with the original request + # Ensures notifications are sent to the correct response stream + # Without this, notifications will either go to: + # - a standalone SSE stream (if GET request is supported) + # - nowhere (if GET request isn't supported) related_request_id=ctx.request_id, ) if i < count - 1: # Don't wait after the last notification diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index f3bb2586..008b235f 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -466,6 +466,7 @@ async def run_stdio_async(self) -> None: async def run_sse_async(self) -> None: """Run the server using SSE transport.""" import uvicorn + starlette_app = self.sse_app() config = uvicorn.Config( @@ -673,7 +674,10 @@ async def log( **extra: Additional structured data to include """ await self.request_context.session.send_log_message( - level=level, data=message, logger=logger_name + level=level, + data=message, + logger=logger_name, + related_request_id=self.request_id, ) @property From a0a9c5b4e5f7c6b39118405a3230b9e2bc66175e Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 23 Apr 2025 10:18:06 +0100 Subject: [PATCH 23/45] small fixes --- examples/servers/simple-streamablehttp/README.md | 1 + src/mcp/server/streamableHttp.py | 5 ----- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md index 5125c3eb..e5aaa652 100644 --- a/examples/servers/simple-streamablehttp/README.md +++ b/examples/servers/simple-streamablehttp/README.md @@ -5,6 +5,7 @@ A simple MCP server example demonstrating the StreamableHttp transport, which en ## Features - Uses the StreamableHTTP transport for server-client communication +- Supports REST API operations (POST, GET, DELETE) for `/mcp` endpoint - Task management with anyio task groups - Ability to send multiple notifications over time to the client - Proper resource cleanup and lifespan management diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 34a272f6..2d536a40 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -577,11 +577,6 @@ async def connect( """ # Create the memory streams for this connection - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] - - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream[ JSONRPCMessage | Exception From a5ac2e09df6df2eda17ea73559004a789a3d1f3c Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 23 Apr 2025 11:04:38 +0100 Subject: [PATCH 24/45] use mta field for related_request_id --- src/mcp/server/streamableHttp.py | 11 +++++++---- src/mcp/shared/session.py | 19 +++++++++++++++++-- src/mcp/types.py | 1 - tests/client/test_logging_callback.py | 12 +++++++++--- tests/server/fastmcp/test_server.py | 22 ++++++++++++++++++---- 5 files changed, 51 insertions(+), 14 deletions(-) diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 2d536a40..2e0f7090 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -600,10 +600,13 @@ async def message_router(): if isinstance( message.root, JSONRPCNotification | JSONRPCRequest ): - # Extract related_request_id from params if it exists - if (params := getattr(message.root, "params", None)) and ( - related_id := params.get("related_request_id") - ) is not None: + # Extract related_request_id from meta if it exists + if ( + (params := getattr(message.root, "params", None)) + and (meta := params.get("_meta")) + and (related_id := meta.get("related_request_id")) + is not None + ): target_request_id = str(related_id) else: target_request_id = str(message.root.id) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 1017bb98..368524f9 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -6,7 +6,6 @@ from typing import Any, Generic, TypeVar import anyio -import anyio.lowlevel import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel @@ -24,6 +23,7 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + NotificationParams, RequestParams, ServerNotification, ServerRequest, @@ -276,8 +276,23 @@ async def send_notification( Emits a notification, which is a one-way message that does not expect a response. """ + # Some transport implementations may need to set the related_request_id + # to attribute to the notifications to the request that triggered + # them. + # Update notification meta with related request ID if provided if related_request_id is not None and notification.root.params is not None: - notification.root.params.related_request_id = related_request_id + # Create meta if it doesn't exist + if notification.root.params.meta is None: + # Create meta dict with related_request_id + meta_dict = {"related_request_id": related_request_id} + + else: + # Update existing meta with model_validate to properly handle extra fields + meta_dict = notification.root.params.meta.model_dump( + by_alias=True, mode="json", exclude_none=True + ) + meta_dict["related_request_id"] = related_request_id + notification.root.params.meta = NotificationParams.Meta(**meta_dict) jsonrpc_notification = JSONRPCNotification( jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True), diff --git a/src/mcp/types.py b/src/mcp/types.py index 30500e31..bd71d51f 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -58,7 +58,6 @@ class Meta(BaseModel): model_config = ConfigDict(extra="allow") meta: Meta | None = Field(alias="_meta", default=None) - related_request_id: RequestId | None = None """ This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications. diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 797f817e..588fa649 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -9,6 +9,7 @@ from mcp.shared.session import RequestResponder from mcp.types import ( LoggingMessageNotificationParams, + NotificationParams, TextContent, ) @@ -78,6 +79,11 @@ async def message_handler( ) assert log_result.isError is False assert len(logging_collector.log_messages) == 1 - assert logging_collector.log_messages[0] == LoggingMessageNotificationParams( - level="info", logger="test_logger", data="Test log message" - ) + # Create meta object with related_request_id added dynamically + meta = NotificationParams.Meta() + setattr(meta, "related_request_id", "2") + log = logging_collector.log_messages[0] + assert log.level == "info" + assert log.logger == "test_logger" + assert log.data == "Test log message" + assert log.meta == meta diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index e76e59c5..772c4152 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -544,14 +544,28 @@ async def logging_tool(msg: str, ctx: Context) -> str: assert mock_log.call_count == 4 mock_log.assert_any_call( - level="debug", data="Debug message", logger=None + level="debug", + data="Debug message", + logger=None, + related_request_id="1", ) - mock_log.assert_any_call(level="info", data="Info message", logger=None) mock_log.assert_any_call( - level="warning", data="Warning message", logger=None + level="info", + data="Info message", + logger=None, + related_request_id="1", ) mock_log.assert_any_call( - level="error", data="Error message", logger=None + level="warning", + data="Warning message", + logger=None, + related_request_id="1", + ) + mock_log.assert_any_call( + level="error", + data="Error message", + logger=None, + related_request_id="1", ) @pytest.mark.anyio From 2e615f36b7e515ad9efaf2cff46bfc1a2fa00f46 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 23 Apr 2025 11:10:20 +0100 Subject: [PATCH 25/45] unrelated test and format --- src/mcp/shared/session.py | 6 +----- tests/issues/test_188_concurrency.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 368524f9..3a01cb04 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -277,17 +277,13 @@ async def send_notification( a response. """ # Some transport implementations may need to set the related_request_id - # to attribute to the notifications to the request that triggered - # them. - # Update notification meta with related request ID if provided + # to attribute to the notifications to the request that triggered them. if related_request_id is not None and notification.root.params is not None: # Create meta if it doesn't exist if notification.root.params.meta is None: - # Create meta dict with related_request_id meta_dict = {"related_request_id": related_request_id} else: - # Update existing meta with model_validate to properly handle extra fields meta_dict = notification.root.params.meta.model_dump( by_alias=True, mode="json", exclude_none=True ) diff --git a/tests/issues/test_188_concurrency.py b/tests/issues/test_188_concurrency.py index 2aa6c49c..d0a86885 100644 --- a/tests/issues/test_188_concurrency.py +++ b/tests/issues/test_188_concurrency.py @@ -35,7 +35,7 @@ async def slow_resource(): end_time = anyio.current_time() duration = end_time - start_time - assert duration < 3 * _sleep_time_seconds + assert duration < 6 * _sleep_time_seconds print(duration) From 110526d546b38bd7a4dc1479ea5e46a13edea629 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 23 Apr 2025 11:14:56 +0100 Subject: [PATCH 26/45] clean up --- src/mcp/client/streamableHttp.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/mcp/client/streamableHttp.py b/src/mcp/client/streamableHttp.py index d8453d77..f1e996fa 100644 --- a/src/mcp/client/streamableHttp.py +++ b/src/mcp/client/streamableHttp.py @@ -12,7 +12,6 @@ import anyio import httpx -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import EventSource, aconnect_sse from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest @@ -41,11 +40,6 @@ async def streamablehttp_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. """ - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] - - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream[ JSONRPCMessage | Exception From 7ffd5bab401852294aa0306c4e50d4f7a0d7ef89 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 23 Apr 2025 13:49:21 +0100 Subject: [PATCH 27/45] terminate session --- src/mcp/client/streamableHttp.py | 63 +++++++++++++++++++--- tests/shared/test_streamableHttp.py | 84 +++++++++++++++++++++++++---- 2 files changed, 130 insertions(+), 17 deletions(-) diff --git a/src/mcp/client/streamableHttp.py b/src/mcp/client/streamableHttp.py index f1e996fa..f72cd74d 100644 --- a/src/mcp/client/streamableHttp.py +++ b/src/mcp/client/streamableHttp.py @@ -14,7 +14,13 @@ import httpx from httpx_sse import EventSource, aconnect_sse -from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest +from mcp.types import ( + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, +) logger = logging.getLogger(__name__) @@ -39,6 +45,9 @@ async def streamablehttp_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + + Yields: + Tuple of (read_stream, write_stream, terminate_callback) """ read_stream_writer, read_stream = anyio.create_memory_object_stream[ @@ -122,9 +131,19 @@ async def post_writer(): headers=post_headers, ) as new_response: response = new_response - else: - response.raise_for_status() - + elif isinstance(message.root, JSONRPCRequest): + jsonrpc_error = JSONRPCError( + jsonrpc="2.0", + id=message.root.id, + error=ErrorData( + code=32600, + message="Session terminated", + ), + ) + await read_stream_writer.send( + JSONRPCMessage(jsonrpc_error) + ) + continue response.raise_for_status() # Extract session ID from response headers @@ -204,7 +223,6 @@ async def post_writer(): except Exception as exc: logger.error(f"Error in post_writer: {exc}") - await read_stream_writer.send(exc) finally: await read_stream_writer.aclose() await write_stream.aclose() @@ -223,7 +241,11 @@ async def get_stream(): get_headers[MCP_SESSION_ID_HEADER] = session_id async with aconnect_sse( - client, "GET", url, headers=get_headers + client, + "GET", + url, + headers=get_headers, + timeout=httpx.Timeout(timeout, read=sse_read_timeout), ) as event_source: event_source.response.raise_for_status() logger.debug("GET SSE connection established") @@ -251,8 +273,35 @@ async def get_stream(): tg.start_soon(post_writer) + async def terminate_session(): + """ + Terminate the session by sending a DELETE request. + """ + nonlocal session_id + if not session_id: + return # No session to terminate + + try: + delete_headers = request_headers.copy() + delete_headers[MCP_SESSION_ID_HEADER] = session_id + + response = await client.delete( + url, + headers=delete_headers, + ) + + if response.status_code == 405: + # Server doesn't allow client-initiated termination + logger.debug("Server does not allow session termination") + elif response.status_code != 200: + logger.warning( + f"Session termination failed: {response.status_code}" + ) + except Exception as exc: + logger.warning(f"Session termination failed: {exc}") + try: - yield read_stream, write_stream + yield read_stream, write_stream, terminate_session finally: tg.cancel_scope.cancel() finally: diff --git a/tests/shared/test_streamableHttp.py b/tests/shared/test_streamableHttp.py index 416be8e0..aef0bc96 100644 --- a/tests/shared/test_streamableHttp.py +++ b/tests/shared/test_streamableHttp.py @@ -650,8 +650,15 @@ async def http_client(basic_server, basic_server_url): @pytest.fixture async def initialized_client_session(basic_server, basic_server_url): """Create initialized StreamableHTTP client session.""" - async with streamablehttp_client(f"{basic_server_url}/mcp") as streams: - async with ClientSession(*streams) as session: + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: await session.initialize() yield session @@ -659,8 +666,15 @@ async def initialized_client_session(basic_server, basic_server_url): @pytest.mark.anyio async def test_streamablehttp_client_basic_connection(basic_server, basic_server_url): """Test basic client connection with initialization.""" - async with streamablehttp_client(f"{basic_server_url}/mcp") as streams: - async with ClientSession(*streams) as session: + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) @@ -709,8 +723,15 @@ async def test_streamablehttp_client_session_persistence( basic_server, basic_server_url ): """Test that session ID persists across requests.""" - async with streamablehttp_client(f"{basic_server_url}/mcp") as streams: - async with ClientSession(*streams) as session: + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) @@ -732,8 +753,15 @@ async def test_streamablehttp_client_json_response( json_response_server, json_server_url ): """Test client with JSON response mode.""" - async with streamablehttp_client(f"{json_server_url}/mcp") as streams: - async with ClientSession(*streams) as session: + async with streamablehttp_client(f"{json_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) @@ -767,8 +795,14 @@ async def message_handler( if isinstance(message, types.ServerNotification): notifications_received.append(message) - async with streamablehttp_client(f"{basic_server_url}/mcp") as streams: - async with ClientSession(*streams, message_handler=message_handler) as session: + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, write_stream, message_handler=message_handler + ) as session: # Initialize the session - this triggers the GET stream setup result = await session.initialize() assert isinstance(result, InitializeResult) @@ -789,3 +823,33 @@ async def message_handler( assert ( resource_update_found ), "ResourceUpdatedNotification not received via GET stream" + + +@pytest.mark.anyio +async def test_streamablehttp_client_session_termination( + basic_server, basic_server_url +): + """Test client session termination functionality.""" + + # Create the streamablehttp_client with a custom httpx client to capture headers + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + terminate_session, + ): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make a request to confirm session is working + tools = await session.list_tools() + assert len(tools.tools) == 2 + + # After exiting ClientSession context, explicitly terminate the session + await terminate_session() + with pytest.raises( + McpError, + match="Session terminated", + ): + await session.list_tools() From 029ec56fec5d342e874afb42522fd17a328c7831 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 23 Apr 2025 13:51:42 +0100 Subject: [PATCH 28/45] format --- src/mcp/server/streamableHttp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index d748cabd..5763fa7c 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -510,7 +510,7 @@ async def standalone_sse_writer(): # Process messages from the standalone stream async for received_message in standalone_stream_reader: # For the standalone stream, we handle: - # - JSONRPCNotification (server can send notifications to client) + # - JSONRPCNotification (server sends notifications to client) # - JSONRPCRequest (server can send requests to client) # We should NOT receive JSONRPCResponse From 58745c7d977c708e4241b1c60df03d3c2204698b Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 25 Apr 2025 09:21:22 +0100 Subject: [PATCH 29/45] remove useless sleep --- .../mcp_simple_streamablehttp_stateless/server.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index fe22e7c3..ae4432cb 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -148,10 +148,6 @@ async def run_server(): # Start server task task_group.start_soon(run_server) - # Small delay to allow the server task to start - # This helps prevent race conditions in stateless mode - await anyio.sleep(0.001) - # Handle the HTTP request and return the response await http_transport.handle_request(scope, receive, send) From 138792978c7249e93a43157a9598a72dd04ed3c4 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 25 Apr 2025 09:42:54 +0100 Subject: [PATCH 30/45] rename require_initialization to standalone_mode --- .../mcp_simple_streamablehttp_stateless/server.py | 5 +++-- src/mcp/server/lowlevel/server.py | 11 ++++++----- src/mcp/server/session.py | 8 ++++---- uv.lock | 2 +- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index ae4432cb..da8158a9 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -141,8 +141,9 @@ async def run_server(): read_stream, write_stream, app.create_initialization_options(), - # This allows the server to run without waiting for initialization - require_initialization=False, + # Runs in standalone mode for stateless deployments + # where clients perform initialization with any node + standalone_mode=True, ) # Start server task diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 4f33af19..b47f5305 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -479,10 +479,11 @@ async def run( # but also make tracing exceptions much easier during testing and when using # in-process servers. raise_exceptions: bool = False, - # When True, the server will wait for the client to send an initialization - # message before processing any other messages. - # False should be used for stateless servers. - require_initialization: bool = True, + # When True, the server runs in standalone mode for stateless deployments where + # clients can perform initialization with any node. The client must still follow + # the initialization lifecycle, but can do so with any available node + # rather than requiring initialization for each connection. + standalone_mode: bool = False, ): async with AsyncExitStack() as stack: lifespan_context = await stack.enter_async_context(self.lifespan(self)) @@ -491,7 +492,7 @@ async def run( read_stream, write_stream, initialization_options, - require_initialization, + standalone_mode=standalone_mode, ) ) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 28f5ddaf..07e5a315 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -85,15 +85,15 @@ def __init__( read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[types.JSONRPCMessage], init_options: InitializationOptions, - require_initialization: bool = True, + standalone_mode: bool = False, ) -> None: super().__init__( read_stream, write_stream, types.ClientRequest, types.ClientNotification ) - if require_initialization: - self._initialization_state = InitializationState.NotInitialized - else: + if standalone_mode: self._initialization_state = InitializationState.Initialized + else: + self._initialization_state = InitializationState.NotInitialized self._init_options = init_options self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( anyio.create_memory_object_stream[ServerRequestResponder](0) diff --git a/uv.lock b/uv.lock index 5a1b5764..7eb34e9e 100644 --- a/uv.lock +++ b/uv.lock @@ -489,7 +489,7 @@ wheels = [ [[package]] name = "mcp" -version = "1.6.1.dev22+6c9c320" + source = { editable = "." } dependencies = [ { name = "anyio" }, From 6482120530b9ed70a48cf3c433822a74fd448020 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 25 Apr 2025 10:31:14 +0100 Subject: [PATCH 31/45] remove redundant check for initialize and session --- src/mcp/client/streamableHttp.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/src/mcp/client/streamableHttp.py b/src/mcp/client/streamableHttp.py index f72cd74d..14237d2d 100644 --- a/src/mcp/client/streamableHttp.py +++ b/src/mcp/client/streamableHttp.py @@ -111,27 +111,7 @@ async def post_writer(): continue # Check for 404 (session expired/invalid) if response.status_code == 404: - if is_initialization and session_id: - logger.info( - "Session expired, retrying without ID" - ) - session_id = None - post_headers.pop( - MCP_SESSION_ID_HEADER, None - ) - # Retry with client.stream - async with client.stream( - "POST", - url, - json=message.model_dump( - by_alias=True, - mode="json", - exclude_none=True, - ), - headers=post_headers, - ) as new_response: - response = new_response - elif isinstance(message.root, JSONRPCRequest): + if isinstance(message.root, JSONRPCRequest): jsonrpc_error = JSONRPCError( jsonrpc="2.0", id=message.root.id, From 9a6da2ed1d8ce419fe3fdc9d1fa9b739bf1fa272 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 25 Apr 2025 10:32:24 +0100 Subject: [PATCH 32/45] ruff check --- src/mcp/server/streamableHttp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index eb5e31bd..8faff016 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -510,8 +510,8 @@ async def standalone_sse_writer(): # Process messages from the standalone stream async for received_message in standalone_stream_reader: # For the standalone stream, we handle: - # - JSONRPCNotification (server can send notifications to client) - # - JSONRPCRequest (server can send requests to client) + # - JSONRPCNotification (server sends notifications to client) + # - JSONRPCRequest (server sends requests to client) # We should NOT receive JSONRPCResponse # Send the message via SSE From 3f5fd7ea6bedcf73f4b038ddbfabcbd6e9abbd86 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 25 Apr 2025 10:51:47 +0100 Subject: [PATCH 33/45] support for resumability - server --- .../servers/simple-streamablehttp/README.md | 20 +- .../mcp_simple_streamablehttp/event_store.py | 75 +++++++ .../mcp_simple_streamablehttp/server.py | 22 +- src/mcp/server/streamableHttp.py | 199 +++++++++++++++++- 4 files changed, 306 insertions(+), 10 deletions(-) create mode 100644 examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md index e5aaa652..f850b728 100644 --- a/examples/servers/simple-streamablehttp/README.md +++ b/examples/servers/simple-streamablehttp/README.md @@ -9,6 +9,7 @@ A simple MCP server example demonstrating the StreamableHttp transport, which en - Task management with anyio task groups - Ability to send multiple notifications over time to the client - Proper resource cleanup and lifespan management +- Resumability support via InMemoryEventStore ## Usage @@ -32,6 +33,23 @@ The server exposes a tool named "start-notification-stream" that accepts three a - `count`: Number of notifications to send (e.g., 5) - `caller`: Identifier string for the caller +## Resumability Support + +This server includes resumability support through the InMemoryEventStore. This enables clients to: + +- Reconnect to the server after a disconnection +- Resume event streaming from where they left off using the Last-Event-ID header + + +The server will: +- Generate unique event IDs for each SSE message +- Store events in memory for later replay +- Replay missed events when a client reconnects with a Last-Event-ID header + +Note: The InMemoryEventStore is designed for demonstration purposes only. For production use, consider implementing a persistent storage solution. + + + ## Client -You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use (Inspector)[https://github.com/modelcontextprotocol/inspector] \ No newline at end of file +You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use [Inspector](https://github.com/modelcontextprotocol/inspector) \ No newline at end of file diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py new file mode 100644 index 00000000..a887b97a --- /dev/null +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py @@ -0,0 +1,75 @@ +""" +In-memory event store for demonstrating resumability functionality. + +This is a simple implementation intended for examples and testing, +not for production use where a persistent storage solution would be more appropriate. +""" + +import logging +import time +from collections.abc import Awaitable, Callable +from uuid import uuid4 + +from mcp.server.streamableHttp import EventId, EventStore, StreamId +from mcp.types import JSONRPCMessage + +logger = logging.getLogger(__name__) + + +class InMemoryEventStore(EventStore): + """ + Simple in-memory implementation of the EventStore interface for resumability. + This is primarily intended for examples and testing, not for production use + where a persistent storage solution would be more appropriate. + """ + + def __init__(self): + self.events: dict[ + str, tuple[str, JSONRPCMessage, float] + ] = {} # event_id -> (stream_id, message, timestamp) + + async def store_event( + self, stream_id: StreamId, message: JSONRPCMessage + ) -> EventId: + """Stores an event with a generated event ID.""" + event_id = str(uuid4()) + self.events[event_id] = (stream_id, message, time.time()) + return event_id + + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: Callable[[EventId, JSONRPCMessage], Awaitable[None]], + ) -> StreamId: + """Replays events that occurred after the specified event ID.""" + logger.debug(f"Attempting to replay events after {last_event_id}") + logger.debug(f"Total events in store: {len(self.events)}") + logger.debug(f"Event IDs in store: {list(self.events.keys())}") + + if not last_event_id or last_event_id not in self.events: + logger.warning(f"Event ID {last_event_id} not found in store") + return "" + + # Get the stream ID and timestamp from the last event + stream_id, _, last_timestamp = self.events[last_event_id] + + # Find all events for this stream after the last event + events_to_replay = [ + (event_id, message) + for event_id, (sid, message, timestamp) in self.events.items() + if sid == stream_id and timestamp > last_timestamp + ] + + # Sort by timestamp to ensure chronological order + events_to_replay.sort(key=lambda x: self.events[x[0]][2]) + + logger.debug(f"Found {len(events_to_replay)} events to replay") + logger.debug( + f"Events to replay: {[event_id for event_id, _ in events_to_replay]}" + ) + + # Send all events in order + for event_id, message in events_to_replay: + await send_callback(event_id, message) + + return stream_id diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index b5faffed..fb444c84 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -17,12 +17,24 @@ from starlette.responses import Response from starlette.routing import Mount +from .event_store import InMemoryEventStore + # Configure logging logger = logging.getLogger(__name__) # Global task group that will be initialized in the lifespan task_group = None +# Event store for resumability +# The InMemoryEventStore enables resumability support for StreamableHTTP transport. +# It stores SSE events with unique IDs, allowing clients to: +# 1. Receive event IDs for each SSE message +# 2. Resume streams by sending Last-Event-ID in GET requests +# 3. Replay missed events after reconnection +# Note: This in-memory implementation is for demonstration ONLY. +# For production, use a persistent storage solution. +event_store = InMemoryEventStore() + @contextlib.asynccontextmanager async def lifespan(app): @@ -79,9 +91,14 @@ async def call_tool( # Send the specified number of notifications with the given interval for i in range(count): + # Include more detailed message for resumability demonstration + notification_msg = ( + f"[{i+1}/{count}] Event from '{caller}' - " + f"Use Last-Event-ID to resume if disconnected" + ) await ctx.session.send_log_message( level="info", - data=f"Notification {i+1}/{count} from caller: {caller}", + data=notification_msg, logger="notification_stream", # Associates this notification with the original request # Ensures notifications are sent to the correct response stream @@ -90,6 +107,7 @@ async def call_tool( # - nowhere (if GET request isn't supported) related_request_id=ctx.request_id, ) + logger.debug(f"Sent notification {i+1}/{count} for caller: {caller}") if i < count - 1: # Don't wait after the last notification await anyio.sleep(interval) @@ -163,8 +181,10 @@ async def handle_streamable_http(scope, receive, send): http_transport = StreamableHTTPServerTransport( mcp_session_id=new_session_id, is_json_response_enabled=json_response, + event_store=event_store, # Enable resumability ) server_instances[http_transport.mcp_session_id] = http_transport + logger.info(f"Created new transport with session ID: {new_session_id}") async with http_transport.connect() as streams: read_stream, write_stream = streams diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 8faff016..c2357ba1 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -10,7 +10,8 @@ import json import logging import re -from collections.abc import AsyncGenerator +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import asynccontextmanager from http import HTTPStatus from typing import Any @@ -57,6 +58,50 @@ # Pattern ensures entire string contains only valid characters by using ^ and $ anchors SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") +# Type aliases +StreamId = str +EventId = str + + +class EventStore(ABC): + """ + Interface for resumability support via event storage. + """ + + @abstractmethod + async def store_event( + self, stream_id: StreamId, message: JSONRPCMessage + ) -> EventId: + """ + Stores an event for later retrieval. + + Args: + stream_id: ID of the stream the event belongs to + message: The JSON-RPC message to store + + Returns: + The generated event ID for the stored event + """ + pass + + @abstractmethod + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: Callable[[EventId, JSONRPCMessage], Awaitable[None]], + ) -> StreamId: + """ + Replays events that occurred after the specified event ID. + + Args: + last_event_id: The ID of the last event the client received + send_callback: A callback function to send events to the client + + Returns: + The stream ID of the replayed events + """ + pass + class StreamableHTTPServerTransport: """ @@ -76,6 +121,7 @@ def __init__( self, mcp_session_id: str | None, is_json_response_enabled: bool = False, + event_store: EventStore | None = None, ) -> None: """ Initialize a new StreamableHTTP server transport. @@ -85,6 +131,9 @@ def __init__( Must contain only visible ASCII characters (0x21-0x7E). is_json_response_enabled: If True, return JSON responses for requests instead of SSE streams. Default is False. + event_store: Event store for resumability support. If provided, + resumability will be enabled, allowing clients to + reconnect and resume messages. Raises: ValueError: If the session ID contains invalid characters. @@ -98,8 +147,9 @@ def __init__( self.mcp_session_id = mcp_session_id self.is_json_response_enabled = is_json_response_enabled + self._event_store = event_store self._request_streams: dict[ - RequestId, MemoryObjectSendStream[JSONRPCMessage] + RequestId, MemoryObjectSendStream[tuple[JSONRPCMessage, str | None]] ] = {} self._terminated = False @@ -308,7 +358,7 @@ async def _handle_post_request( request_id = str(message.root.id) # Create promise stream for getting response request_stream_writer, request_stream_reader = ( - anyio.create_memory_object_stream[JSONRPCMessage](0) + anyio.create_memory_object_stream[tuple[JSONRPCMessage, str | None]](0) ) # Register this stream for the request ID @@ -323,7 +373,8 @@ async def _handle_post_request( response_message = None # Use similar approach to SSE writer for consistency - async for received_message in request_stream_reader: + async for item in request_stream_reader: + received_message, _ = item # Extract message, ignore event_id # If it's a response, this is what we're waiting for if isinstance( received_message.root, JSONRPCResponse | JSONRPCError @@ -374,7 +425,10 @@ async def sse_writer(): try: async with sse_stream_writer, request_stream_reader: # Process messages from the request-specific stream - async for received_message in request_stream_reader: + async for item in request_stream_reader: + # The message router always sends a tuple of (message, event_id) + received_message, event_id = item + # Build the event data event_data = { "event": "message", @@ -383,6 +437,10 @@ async def sse_writer(): ), } + # If an event ID was provided, include it in the SSE stream + if event_id: + event_data["id"] = event_id + await sse_stream_writer.send(event_data) # If response, remove from pending streams and close @@ -472,6 +530,12 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: if not await self._validate_session(request, send): return + # Handle resumability: check for Last-Event-ID header + if self._event_store: + last_event_id = request.headers.get(LAST_EVENT_ID_HEADER) + if last_event_id: + await self._replay_events(last_event_id, request, send) + return headers = { "Cache-Control": "no-cache, no-transform", @@ -500,7 +564,9 @@ async def standalone_sse_writer(): try: # Create a standalone message stream for server-initiated messages standalone_stream_writer, standalone_stream_reader = ( - anyio.create_memory_object_stream[JSONRPCMessage](0) + anyio.create_memory_object_stream[ + tuple[JSONRPCMessage, str | None] + ](0) ) # Register this stream using the special key @@ -508,7 +574,10 @@ async def standalone_sse_writer(): async with sse_stream_writer, standalone_stream_reader: # Process messages from the standalone stream - async for received_message in standalone_stream_reader: + async for item in standalone_stream_reader: + # The message router always sends a tuple of (message, event_id) + received_message, event_id = item + # For the standalone stream, we handle: # - JSONRPCNotification (server sends notifications to client) # - JSONRPCRequest (server sends requests to client) @@ -522,6 +591,10 @@ async def standalone_sse_writer(): ), } + # If an event ID was provided, include it in the SSE stream + if event_id: + event_data["id"] = event_id + await sse_stream_writer.send(event_data) except Exception as e: logger.exception(f"Error in standalone SSE writer: {e}") @@ -639,6 +712,102 @@ async def _validate_session(self, request: Request, send: Send) -> bool: return True + async def _replay_events( + self, last_event_id: str, request: Request, send: Send + ) -> None: + """ + Replays events that would have been sent after the specified event ID. + Only used when resumability is enabled. + """ + event_store = self._event_store + if not event_store: + return + + try: + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Create SSE stream for replay + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ + dict[str, Any] + ](0) + + async def replay_sender(): + try: + async with sse_stream_writer: + # Define an async callback for sending events + async def send_event( + event_id: EventId, message: JSONRPCMessage + ) -> None: + print( + "------ REPLAYING EVENT ----------", event_id, message + ) + await sse_stream_writer.send( + { + "event": "message", + "id": event_id, + "data": message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + ) + + # Replay past events and get the stream ID + stream_id = await event_store.replay_events_after( + last_event_id, send_event + ) + + # If stream ID not in mapping, create it + if stream_id and stream_id not in self._request_streams: + msg_writer, msg_reader = anyio.create_memory_object_stream[ + tuple[JSONRPCMessage, str | None] + ](0) + self._request_streams[stream_id] = msg_writer + + # Forward messages to SSE + async with msg_reader: + async for item in msg_reader: + message, event_id = item + + await sse_stream_writer.send( + { + "event": "message", + "id": event_id, + "data": message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + ) + except Exception as e: + logger.exception(f"Error in replay sender: {e}") + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=replay_sender, + headers=headers, + ) + + try: + await response(request.scope, request.receive, send) + except Exception as e: + logger.exception(f"Error in replay response: {e}") + + except Exception as e: + logger.exception(f"Error replaying events: {e}") + response = self._create_error_response( + f"Error replaying events: {str(e)}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(request.scope, request.receive, send) + @asynccontextmanager async def connect( self, @@ -691,10 +860,24 @@ async def message_router(): target_request_id = str(message.root.id) request_stream_id = target_request_id or GET_STREAM_KEY + + # Store the event if we have an event store, + # regardless of whether a client is connected + # messages will be replayed on the re-connect + event_id = None + if self._event_store: + event_id = await self._event_store.store_event( + request_stream_id, message + ) + logger.debug( + f"Stored event {event_id} for stream {request_stream_id} in message router" + ) + if request_stream_id in self._request_streams: try: + # Send both the message and the event ID await self._request_streams[request_stream_id].send( - message + (message, event_id) ) except ( anyio.BrokenResourceError, From b1932422b4bcd6c54f80bdb88c322a6525356955 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 25 Apr 2025 10:53:54 +0100 Subject: [PATCH 34/45] format --- src/mcp/server/streamableHttp.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index c2357ba1..5a2d2fa2 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -426,7 +426,6 @@ async def sse_writer(): async with sse_stream_writer, request_stream_reader: # Process messages from the request-specific stream async for item in request_stream_reader: - # The message router always sends a tuple of (message, event_id) received_message, event_id = item # Build the event data @@ -437,7 +436,7 @@ async def sse_writer(): ), } - # If an event ID was provided, include it in the SSE stream + # If an event ID was provided, include it if event_id: event_data["id"] = event_id @@ -869,9 +868,7 @@ async def message_router(): event_id = await self._event_store.store_event( request_stream_id, message ) - logger.debug( - f"Stored event {event_id} for stream {request_stream_id} in message router" - ) + logger.debug(f"Stored {event_id} from {request_stream_id}") if request_stream_id in self._request_streams: try: From 611043558716caf1b99b28d5a6a5adbd23d92f25 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 25 Apr 2025 11:03:56 +0100 Subject: [PATCH 35/45] remove print --- src/mcp/server/streamableHttp.py | 3 --- tests/shared/test_streamableHttp.py | 30 ++++++----------------------- 2 files changed, 6 insertions(+), 27 deletions(-) diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 5a2d2fa2..abed48bb 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -744,9 +744,6 @@ async def replay_sender(): async def send_event( event_id: EventId, message: JSONRPCMessage ) -> None: - print( - "------ REPLAYING EVENT ----------", event_id, message - ) await sse_stream_writer.send( { "event": "message", diff --git a/tests/shared/test_streamableHttp.py b/tests/shared/test_streamableHttp.py index aef0bc96..43bdf388 100644 --- a/tests/shared/test_streamableHttp.py +++ b/tests/shared/test_streamableHttp.py @@ -113,15 +113,12 @@ async def lifespan(app): async with anyio.create_task_group() as tg: task_group = tg - print("Application started, task group initialized!") try: yield finally: - print("Application shutting down, cleaning up resources...") if task_group: tg.cancel_scope.cancel() task_group = None - print("Resources cleaned up successfully.") async def handle_streamable_http(scope, receive, send): request = Request(scope, receive) @@ -148,14 +145,11 @@ async def handle_streamable_http(scope, receive, send): read_stream, write_stream = streams async def run_server(): - try: - await server.run( - read_stream, - write_stream, - server.create_initialization_options(), - ) - except Exception as e: - print(f"Server exception: {e}") + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) if task_group is None: response = Response( @@ -196,10 +190,6 @@ def run_server(port: int, is_json_response_enabled=False) -> None: port: Port to listen on. is_json_response_enabled: If True, use JSON responses instead of SSE streams. """ - print( - f"Starting test server on port {port} with " - f"json_enabled={is_json_response_enabled}" - ) app = create_app(is_json_response_enabled) # Configure server @@ -218,16 +208,12 @@ def run_server(port: int, is_json_response_enabled=False) -> None: # This is important to catch exceptions and prevent test hangs try: - print("Server starting...") server.run() - except Exception as e: - print(f"ERROR: Server failed to run: {e}") + except Exception: import traceback traceback.print_exc() - print("Server shutdown") - # Test fixtures - using same approach as SSE tests @pytest.fixture @@ -273,8 +259,6 @@ def basic_server(basic_server_port: int) -> Generator[None, None, None]: # Clean up proc.kill() proc.join(timeout=2) - if proc.is_alive(): - print("server process failed to terminate") @pytest.fixture @@ -306,8 +290,6 @@ def json_response_server(json_server_port: int) -> Generator[None, None, None]: # Clean up proc.kill() proc.join(timeout=2) - if proc.is_alive(): - print("server process failed to terminate") @pytest.fixture From e0872835cf08c894a20938144682aa1237832ded Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 25 Apr 2025 12:31:44 +0100 Subject: [PATCH 36/45] rename files to follow python naming --- .../simple-streamablehttp/mcp_simple_streamablehttp/server.py | 2 +- src/mcp/client/{streamableHttp.py => streamable_http.py} | 0 src/mcp/server/{streamableHttp.py => streamable_http.py} | 0 .../{test_streamableHttp.py => test_streamable_http.py} | 4 ++-- 4 files changed, 3 insertions(+), 3 deletions(-) rename src/mcp/client/{streamableHttp.py => streamable_http.py} (100%) rename src/mcp/server/{streamableHttp.py => streamable_http.py} (100%) rename tests/shared/{test_streamableHttp.py => test_streamable_http.py} (99%) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index b5faffed..71d4e5a3 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -7,7 +7,7 @@ import click import mcp.types as types from mcp.server.lowlevel import Server -from mcp.server.streamableHttp import ( +from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport, ) diff --git a/src/mcp/client/streamableHttp.py b/src/mcp/client/streamable_http.py similarity index 100% rename from src/mcp/client/streamableHttp.py rename to src/mcp/client/streamable_http.py diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamable_http.py similarity index 100% rename from src/mcp/server/streamableHttp.py rename to src/mcp/server/streamable_http.py diff --git a/tests/shared/test_streamableHttp.py b/tests/shared/test_streamable_http.py similarity index 99% rename from tests/shared/test_streamableHttp.py rename to tests/shared/test_streamable_http.py index aef0bc96..48af0953 100644 --- a/tests/shared/test_streamableHttp.py +++ b/tests/shared/test_streamable_http.py @@ -24,9 +24,9 @@ from starlette.routing import Mount from mcp.client.session import ClientSession -from mcp.client.streamableHttp import streamablehttp_client +from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server -from mcp.server.streamableHttp import ( +from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, SESSION_ID_PATTERN, StreamableHTTPServerTransport, From 08247c420d12a4cc1723fd64cd1adf7051f7d233 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 25 Apr 2025 12:39:32 +0100 Subject: [PATCH 37/45] update to use time delta in client --- src/mcp/client/streamable_http.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 14237d2d..b3e65fb9 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -8,6 +8,7 @@ import logging from contextlib import asynccontextmanager +from datetime import timedelta from typing import Any import anyio @@ -37,8 +38,8 @@ async def streamablehttp_client( url: str, headers: dict[str, Any] | None = None, - timeout: float = 30, - sse_read_timeout: float = 60 * 5, + timeout: timedelta = timedelta(seconds=30), + sse_read_timeout: timedelta = timedelta(seconds=60 * 5), ): """ Client transport for StreamableHTTP. @@ -71,7 +72,9 @@ async def streamablehttp_client( session_id: str | None = None async with httpx.AsyncClient( - headers=request_headers, timeout=timeout, follow_redirects=True + headers=request_headers, + timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds), + follow_redirects=True, ) as client: async def post_writer(): @@ -225,7 +228,9 @@ async def get_stream(): "GET", url, headers=get_headers, - timeout=httpx.Timeout(timeout, read=sse_read_timeout), + timeout=httpx.Timeout( + timeout.seconds, read=sse_read_timeout.seconds + ), ) as event_source: event_source.response.raise_for_status() logger.debug("GET SSE connection established") From 0484dfbfbd8f8fe1784b4db9d30776f356d81fe2 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 25 Apr 2025 13:21:05 +0100 Subject: [PATCH 38/45] refactor --- src/mcp/client/streamable_http.py | 384 ++++++++++++++---------------- 1 file changed, 174 insertions(+), 210 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index b3e65fb9..1e004242 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -58,6 +58,179 @@ async def streamablehttp_client( JSONRPCMessage ](0) + async def get_stream(): + """ + Optional GET stream for server-initiated messages + """ + nonlocal session_id + try: + # Only attempt GET if we have a session ID + if not session_id: + return + + get_headers = request_headers.copy() + get_headers[MCP_SESSION_ID_HEADER] = session_id + + async with aconnect_sse( + client, + "GET", + url, + headers=get_headers, + timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds), + ) as event_source: + event_source.response.raise_for_status() + logger.debug("GET SSE connection established") + + async for sse in event_source.aiter_sse(): + if sse.event == "message": + try: + message = JSONRPCMessage.model_validate_json(sse.data) + logger.debug(f"GET message: {message}") + await read_stream_writer.send(message) + except Exception as exc: + logger.error(f"Error parsing GET message: {exc}") + await read_stream_writer.send(exc) + else: + logger.warning(f"Unknown SSE event from GET: {sse.event}") + except Exception as exc: + # GET stream is optional, so don't propagate errors + logger.debug(f"GET stream error (non-fatal): {exc}") + + async def post_writer(client: httpx.AsyncClient): + nonlocal session_id + try: + async with write_stream_reader: + async for message in write_stream_reader: + # Add session ID to headers if we have one + post_headers = request_headers.copy() + if session_id: + post_headers[MCP_SESSION_ID_HEADER] = session_id + + logger.debug(f"Sending client message: {message}") + + # Handle initial initialization request + is_initialization = ( + isinstance(message.root, JSONRPCRequest) + and message.root.method == "initialize" + ) + if ( + isinstance(message.root, JSONRPCNotification) + and message.root.method == "notifications/initialized" + ): + tg.start_soon(get_stream) + + async with client.stream( + "POST", + url, + json=message.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + headers=post_headers, + ) as response: + if response.status_code == 202: + logger.debug("Received 202 Accepted") + continue + # Check for 404 (session expired/invalid) + if response.status_code == 404: + if isinstance(message.root, JSONRPCRequest): + jsonrpc_error = JSONRPCError( + jsonrpc="2.0", + id=message.root.id, + error=ErrorData( + code=32600, + message="Session terminated", + ), + ) + await read_stream_writer.send( + JSONRPCMessage(jsonrpc_error) + ) + continue + response.raise_for_status() + + # Extract session ID from response headers + if is_initialization: + new_session_id = response.headers.get(MCP_SESSION_ID_HEADER) + if new_session_id: + session_id = new_session_id + logger.info(f"Received session ID: {session_id}") + + # Handle different response types + content_type = response.headers.get("content-type", "").lower() + + if content_type.startswith(CONTENT_TYPE_JSON): + try: + content = await response.aread() + json_message = JSONRPCMessage.model_validate_json( + content + ) + await read_stream_writer.send(json_message) + except Exception as exc: + logger.error(f"Error parsing JSON response: {exc}") + await read_stream_writer.send(exc) + + elif content_type.startswith(CONTENT_TYPE_SSE): + # Parse SSE events from the response + try: + event_source = EventSource(response) + async for sse in event_source.aiter_sse(): + if sse.event == "message": + try: + await read_stream_writer.send( + JSONRPCMessage.model_validate_json( + sse.data + ) + ) + except Exception as exc: + logger.exception("Error parsing message") + await read_stream_writer.send(exc) + else: + logger.warning(f"Unknown event: {sse.event}") + + except Exception as e: + logger.exception("Error reading SSE stream:") + await read_stream_writer.send(e) + + else: + # For 202 Accepted with no body + if response.status_code == 202: + logger.debug("Received 202 Accepted") + continue + + error_msg = f"Unexpected content type: {content_type}" + logger.error(error_msg) + await read_stream_writer.send(ValueError(error_msg)) + + except Exception as exc: + logger.error(f"Error in post_writer: {exc}") + finally: + await read_stream_writer.aclose() + await write_stream.aclose() + + async def terminate_session(): + """ + Terminate the session by sending a DELETE request. + """ + nonlocal session_id + if not session_id: + return # No session to terminate + + try: + delete_headers = request_headers.copy() + delete_headers[MCP_SESSION_ID_HEADER] = session_id + + response = await client.delete( + url, + headers=delete_headers, + ) + + if response.status_code == 405: + # Server doesn't allow client-initiated termination + logger.debug("Server does not allow session termination") + elif response.status_code != 200: + logger.warning(f"Session termination failed: {response.status_code}") + except Exception as exc: + logger.warning(f"Session termination failed: {exc}") + async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to StreamableHTTP endpoint: {url}") @@ -67,7 +240,6 @@ async def streamablehttp_client( "Content-Type": CONTENT_TYPE_JSON, **(headers or {}), } - # Track session ID if provided by server session_id: str | None = None @@ -76,215 +248,7 @@ async def streamablehttp_client( timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds), follow_redirects=True, ) as client: - - async def post_writer(): - nonlocal session_id - try: - async with write_stream_reader: - async for message in write_stream_reader: - # Add session ID to headers if we have one - post_headers = request_headers.copy() - if session_id: - post_headers[MCP_SESSION_ID_HEADER] = session_id - - logger.debug(f"Sending client message: {message}") - - # Handle initial initialization request - is_initialization = ( - isinstance(message.root, JSONRPCRequest) - and message.root.method == "initialize" - ) - if ( - isinstance(message.root, JSONRPCNotification) - and message.root.method - == "notifications/initialized" - ): - tg.start_soon(get_stream) - - async with client.stream( - "POST", - url, - json=message.model_dump( - by_alias=True, mode="json", exclude_none=True - ), - headers=post_headers, - ) as response: - if response.status_code == 202: - logger.debug("Received 202 Accepted") - continue - # Check for 404 (session expired/invalid) - if response.status_code == 404: - if isinstance(message.root, JSONRPCRequest): - jsonrpc_error = JSONRPCError( - jsonrpc="2.0", - id=message.root.id, - error=ErrorData( - code=32600, - message="Session terminated", - ), - ) - await read_stream_writer.send( - JSONRPCMessage(jsonrpc_error) - ) - continue - response.raise_for_status() - - # Extract session ID from response headers - if is_initialization: - new_session_id = response.headers.get( - MCP_SESSION_ID_HEADER - ) - if new_session_id: - session_id = new_session_id - logger.info( - f"Received session ID: {session_id}" - ) - - # Handle different response types - content_type = response.headers.get( - "content-type", "" - ).lower() - - if content_type.startswith(CONTENT_TYPE_JSON): - try: - content = await response.aread() - json_message = ( - JSONRPCMessage.model_validate_json( - content - ) - ) - await read_stream_writer.send(json_message) - except Exception as exc: - logger.error( - f"Error parsing JSON response: {exc}" - ) - await read_stream_writer.send(exc) - - elif content_type.startswith(CONTENT_TYPE_SSE): - # Parse SSE events from the response - try: - event_source = EventSource(response) - async for sse in event_source.aiter_sse(): - if sse.event == "message": - try: - await read_stream_writer.send( - JSONRPCMessage.model_validate_json( - sse.data - ) - ) - except Exception as exc: - logger.exception( - "Error parsing message" - ) - await read_stream_writer.send( - exc - ) - else: - logger.warning( - f"Unknown event: {sse.event}" - ) - - except Exception as e: - logger.exception( - "Error reading SSE stream:" - ) - await read_stream_writer.send(e) - - else: - # For 202 Accepted with no body - if response.status_code == 202: - logger.debug("Received 202 Accepted") - continue - - error_msg = ( - f"Unexpected content type: {content_type}" - ) - logger.error(error_msg) - await read_stream_writer.send( - ValueError(error_msg) - ) - - except Exception as exc: - logger.error(f"Error in post_writer: {exc}") - finally: - await read_stream_writer.aclose() - await write_stream.aclose() - - async def get_stream(): - """ - Optional GET stream for server-initiated messages - """ - nonlocal session_id - try: - # Only attempt GET if we have a session ID - if not session_id: - return - - get_headers = request_headers.copy() - get_headers[MCP_SESSION_ID_HEADER] = session_id - - async with aconnect_sse( - client, - "GET", - url, - headers=get_headers, - timeout=httpx.Timeout( - timeout.seconds, read=sse_read_timeout.seconds - ), - ) as event_source: - event_source.response.raise_for_status() - logger.debug("GET SSE connection established") - - async for sse in event_source.aiter_sse(): - if sse.event == "message": - try: - message = JSONRPCMessage.model_validate_json( - sse.data - ) - logger.debug(f"GET message: {message}") - await read_stream_writer.send(message) - except Exception as exc: - logger.error( - f"Error parsing GET message: {exc}" - ) - await read_stream_writer.send(exc) - else: - logger.warning( - f"Unknown SSE event from GET: {sse.event}" - ) - except Exception as exc: - # GET stream is optional, so don't propagate errors - logger.debug(f"GET stream error (non-fatal): {exc}") - - tg.start_soon(post_writer) - - async def terminate_session(): - """ - Terminate the session by sending a DELETE request. - """ - nonlocal session_id - if not session_id: - return # No session to terminate - - try: - delete_headers = request_headers.copy() - delete_headers[MCP_SESSION_ID_HEADER] = session_id - - response = await client.delete( - url, - headers=delete_headers, - ) - - if response.status_code == 405: - # Server doesn't allow client-initiated termination - logger.debug("Server does not allow session termination") - elif response.status_code != 200: - logger.warning( - f"Session termination failed: {response.status_code}" - ) - except Exception as exc: - logger.warning(f"Session termination failed: {exc}") - + tg.start_soon(post_writer, client) try: yield read_stream, write_stream, terminate_session finally: From 5757f6cbcfc9798ef805e2992e4d252f16ea0910 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 25 Apr 2025 14:47:49 +0100 Subject: [PATCH 39/45] small fixes --- .../mcp_simple_streamablehttp/event_store.py | 19 ++++++++++++------- src/mcp/server/streamable_http.py | 10 +++++----- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py index a887b97a..3286d4ee 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py @@ -8,9 +8,10 @@ import logging import time from collections.abc import Awaitable, Callable +from operator import itemgetter from uuid import uuid4 -from mcp.server.streamableHttp import EventId, EventStore, StreamId +from mcp.server.streamable_http import EventId, EventStore, StreamId from mcp.types import JSONRPCMessage logger = logging.getLogger(__name__) @@ -54,15 +55,19 @@ async def replay_events_after( stream_id, _, last_timestamp = self.events[last_event_id] # Find all events for this stream after the last event + events_sorted = sorted( + [ + (event_id, message, timestamp) + for event_id, (sid, message, timestamp) in self.events.items() + if sid == stream_id and timestamp > last_timestamp + ], + key=itemgetter(2), + ) + events_to_replay = [ - (event_id, message) - for event_id, (sid, message, timestamp) in self.events.items() - if sid == stream_id and timestamp > last_timestamp + (event_id, message) for event_id, message, _ in events_sorted ] - # Sort by timestamp to ensure chronological order - events_to_replay.sort(key=lambda x: self.events[x[0]][2]) - logger.debug(f"Found {len(events_to_replay)} events to replay") logger.debug( f"Events to replay: {[event_id for event_id, _ in events_to_replay]}" diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index abed48bb..67f28ae2 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -373,8 +373,7 @@ async def _handle_post_request( response_message = None # Use similar approach to SSE writer for consistency - async for item in request_stream_reader: - received_message, _ = item # Extract message, ignore event_id + async for received_message, _ in request_stream_reader: # If it's a response, this is what we're waiting for if isinstance( received_message.root, JSONRPCResponse | JSONRPCError @@ -425,9 +424,10 @@ async def sse_writer(): try: async with sse_stream_writer, request_stream_reader: # Process messages from the request-specific stream - async for item in request_stream_reader: - received_message, event_id = item - + async for ( + received_message, + event_id, + ) in request_stream_reader: # Build the event data event_data = { "event": "message", From ee28ad83fdf15e547c07ec2f5b75fcd98f995cd9 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 25 Apr 2025 16:31:22 +0100 Subject: [PATCH 40/45] improve event store example implementation --- .../mcp_simple_streamablehttp/event_store.py | 111 +++++++++++------- src/mcp/server/streamable_http.py | 87 +++++++------- 2 files changed, 115 insertions(+), 83 deletions(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py index 3286d4ee..28c58149 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py @@ -6,75 +6,100 @@ """ import logging -import time -from collections.abc import Awaitable, Callable -from operator import itemgetter +from collections import deque +from dataclasses import dataclass from uuid import uuid4 -from mcp.server.streamable_http import EventId, EventStore, StreamId +from mcp.server.streamable_http import ( + EventCallback, + EventId, + EventMessage, + EventStore, + StreamId, +) from mcp.types import JSONRPCMessage logger = logging.getLogger(__name__) +@dataclass +class EventEntry: + """ + Represents an event entry in the event store. + """ + + event_id: EventId + stream_id: StreamId + message: JSONRPCMessage + + class InMemoryEventStore(EventStore): """ Simple in-memory implementation of the EventStore interface for resumability. This is primarily intended for examples and testing, not for production use where a persistent storage solution would be more appropriate. + + This implementation keeps only the last N events per stream for memory efficiency. """ - def __init__(self): - self.events: dict[ - str, tuple[str, JSONRPCMessage, float] - ] = {} # event_id -> (stream_id, message, timestamp) + def __init__(self, max_events_per_stream: int = 100): + """Initialize the event store. + + Args: + max_events_per_stream: Maximum number of events to keep per stream + """ + self.max_events_per_stream = max_events_per_stream + # for maintaining last N events per stream + self.streams: dict[StreamId, deque[EventEntry]] = {} + # event_id -> EventEntry for quick lookup + self.event_index: dict[EventId, EventEntry] = {} async def store_event( self, stream_id: StreamId, message: JSONRPCMessage ) -> EventId: """Stores an event with a generated event ID.""" event_id = str(uuid4()) - self.events[event_id] = (stream_id, message, time.time()) + event_entry = EventEntry( + event_id=event_id, stream_id=stream_id, message=message + ) + + # Get or create deque for this stream + if stream_id not in self.streams: + self.streams[stream_id] = deque(maxlen=self.max_events_per_stream) + + # If deque is full, the oldest event will be automatically removed + # We need to remove it from the event_index as well + if len(self.streams[stream_id]) == self.max_events_per_stream: + oldest_event = self.streams[stream_id][0] + self.event_index.pop(oldest_event.event_id, None) + + # Add new event + self.streams[stream_id].append(event_entry) + self.event_index[event_id] = event_entry + return event_id async def replay_events_after( self, last_event_id: EventId, - send_callback: Callable[[EventId, JSONRPCMessage], Awaitable[None]], - ) -> StreamId: + send_callback: EventCallback, + ) -> StreamId | None: """Replays events that occurred after the specified event ID.""" - logger.debug(f"Attempting to replay events after {last_event_id}") - logger.debug(f"Total events in store: {len(self.events)}") - logger.debug(f"Event IDs in store: {list(self.events.keys())}") - - if not last_event_id or last_event_id not in self.events: + if last_event_id not in self.event_index: logger.warning(f"Event ID {last_event_id} not found in store") - return "" - - # Get the stream ID and timestamp from the last event - stream_id, _, last_timestamp = self.events[last_event_id] - - # Find all events for this stream after the last event - events_sorted = sorted( - [ - (event_id, message, timestamp) - for event_id, (sid, message, timestamp) in self.events.items() - if sid == stream_id and timestamp > last_timestamp - ], - key=itemgetter(2), - ) - - events_to_replay = [ - (event_id, message) for event_id, message, _ in events_sorted - ] - - logger.debug(f"Found {len(events_to_replay)} events to replay") - logger.debug( - f"Events to replay: {[event_id for event_id, _ in events_to_replay]}" - ) - - # Send all events in order - for event_id, message in events_to_replay: - await send_callback(event_id, message) + return None + + # Get the stream and find events after the last one + last_event = self.event_index[last_event_id] + stream_id = last_event.stream_id + stream_events = self.streams.get(last_event.stream_id, deque()) + + # Events in deque are already in chronological order + found_last = False + for event in stream_events: + if found_last: + await send_callback(EventMessage(event.message, event.event_id)) + elif event.event_id == last_event_id: + found_last = True return stream_id diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 67f28ae2..cf1ddd34 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -63,6 +63,22 @@ EventId = str +class EventMessage: + """ + A JSONRPCMessage with an optional event ID for stream resumability. + """ + + message: JSONRPCMessage + event_id: str | None + + def __init__(self, message: JSONRPCMessage, event_id: str | None = None): + self.message = message + self.event_id = event_id + + +EventCallback = Callable[[EventMessage], Awaitable[None]] + + class EventStore(ABC): """ Interface for resumability support via event storage. @@ -88,8 +104,8 @@ async def store_event( async def replay_events_after( self, last_event_id: EventId, - send_callback: Callable[[EventId, JSONRPCMessage], Awaitable[None]], - ) -> StreamId: + send_callback: EventCallback, + ) -> StreamId | None: """ Replays events that occurred after the specified event ID. @@ -149,7 +165,7 @@ def __init__( self.is_json_response_enabled = is_json_response_enabled self._event_store = event_store self._request_streams: dict[ - RequestId, MemoryObjectSendStream[tuple[JSONRPCMessage, str | None]] + RequestId, MemoryObjectSendStream[EventMessage] ] = {} self._terminated = False @@ -358,7 +374,7 @@ async def _handle_post_request( request_id = str(message.root.id) # Create promise stream for getting response request_stream_writer, request_stream_reader = ( - anyio.create_memory_object_stream[tuple[JSONRPCMessage, str | None]](0) + anyio.create_memory_object_stream[EventMessage](0) ) # Register this stream for the request ID @@ -373,16 +389,18 @@ async def _handle_post_request( response_message = None # Use similar approach to SSE writer for consistency - async for received_message, _ in request_stream_reader: + async for event_message in request_stream_reader: # If it's a response, this is what we're waiting for if isinstance( - received_message.root, JSONRPCResponse | JSONRPCError + event_message.message.root, JSONRPCResponse | JSONRPCError ): - response_message = received_message + response_message = event_message.message break # For notifications and request, keep waiting else: - logger.debug(f"received: {received_message.root.method}") + logger.debug( + f"received: {event_message.message.root.method}" + ) # At this point we should have a response if response_message: @@ -424,27 +442,24 @@ async def sse_writer(): try: async with sse_stream_writer, request_stream_reader: # Process messages from the request-specific stream - async for ( - received_message, - event_id, - ) in request_stream_reader: + async for event_message in request_stream_reader: # Build the event data event_data = { "event": "message", - "data": received_message.model_dump_json( + "data": event_message.message.model_dump_json( by_alias=True, exclude_none=True ), } # If an event ID was provided, include it - if event_id: - event_data["id"] = event_id + if event_message.event_id: + event_data["id"] = event_message.event_id await sse_stream_writer.send(event_data) # If response, remove from pending streams and close if isinstance( - received_message.root, + event_message.message.root, JSONRPCResponse | JSONRPCError, ): if request_id: @@ -563,9 +578,7 @@ async def standalone_sse_writer(): try: # Create a standalone message stream for server-initiated messages standalone_stream_writer, standalone_stream_reader = ( - anyio.create_memory_object_stream[ - tuple[JSONRPCMessage, str | None] - ](0) + anyio.create_memory_object_stream[EventMessage](0) ) # Register this stream using the special key @@ -573,10 +586,7 @@ async def standalone_sse_writer(): async with sse_stream_writer, standalone_stream_reader: # Process messages from the standalone stream - async for item in standalone_stream_reader: - # The message router always sends a tuple of (message, event_id) - received_message, event_id = item - + async for event_message in standalone_stream_reader: # For the standalone stream, we handle: # - JSONRPCNotification (server sends notifications to client) # - JSONRPCRequest (server sends requests to client) @@ -585,14 +595,14 @@ async def standalone_sse_writer(): # Send the message via SSE event_data = { "event": "message", - "data": received_message.model_dump_json( + "data": event_message.message.model_dump_json( by_alias=True, exclude_none=True ), } # If an event ID was provided, include it in the SSE stream - if event_id: - event_data["id"] = event_id + if event_message.event_id: + event_data["id"] = event_message.event_id await sse_stream_writer.send(event_data) except Exception as e: @@ -741,14 +751,12 @@ async def replay_sender(): try: async with sse_stream_writer: # Define an async callback for sending events - async def send_event( - event_id: EventId, message: JSONRPCMessage - ) -> None: + async def send_event(event_message: EventMessage) -> None: await sse_stream_writer.send( { "event": "message", - "id": event_id, - "data": message.model_dump_json( + "id": event_message.event_id, + "data": event_message.message.model_dump_json( by_alias=True, exclude_none=True ), } @@ -762,22 +770,21 @@ async def send_event( # If stream ID not in mapping, create it if stream_id and stream_id not in self._request_streams: msg_writer, msg_reader = anyio.create_memory_object_stream[ - tuple[JSONRPCMessage, str | None] + EventMessage ](0) self._request_streams[stream_id] = msg_writer # Forward messages to SSE async with msg_reader: - async for item in msg_reader: - message, event_id = item - + async for event_message in msg_reader: + event_data = event_message.message.model_dump_json( + by_alias=True, exclude_none=True + ) await sse_stream_writer.send( { "event": "message", - "id": event_id, - "data": message.model_dump_json( - by_alias=True, exclude_none=True - ), + "id": event_message.event_id, + "data": event_data, } ) except Exception as e: @@ -871,7 +878,7 @@ async def message_router(): try: # Send both the message and the event ID await self._request_streams[request_stream_id].send( - (message, event_id) + EventMessage(message, event_id) ) except ( anyio.BrokenResourceError, From 5dbddeb16c4935deec61e8e955ada5f5a3a5f6e6 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 25 Apr 2025 17:04:20 +0100 Subject: [PATCH 41/45] refactor _create_event_data --- src/mcp/server/streamable_http.py | 78 +++++++++++-------------------- 1 file changed, 28 insertions(+), 50 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index cf1ddd34..36c4b636 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -13,8 +13,8 @@ from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import asynccontextmanager +from dataclasses import dataclass from http import HTTPStatus -from typing import Any import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -63,17 +63,14 @@ EventId = str +@dataclass class EventMessage: """ A JSONRPCMessage with an optional event ID for stream resumability. """ message: JSONRPCMessage - event_id: str | None - - def __init__(self, message: JSONRPCMessage, event_id: str | None = None): - self.message = message - self.event_id = event_id + event_id: str | None = None EventCallback = Callable[[EventMessage], Awaitable[None]] @@ -226,6 +223,21 @@ def _get_session_id(self, request: Request) -> str | None: """Extract the session ID from request headers.""" return request.headers.get(MCP_SESSION_ID_HEADER) + def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: + """Create event data dictionary from an EventMessage.""" + event_data = { + "event": "message", + "data": event_message.message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + # If an event ID was provided, include it + if event_message.event_id: + event_data["id"] = event_message.event_id + + return event_data + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Application entry point that handles all HTTP requests""" request = Request(scope, receive) @@ -434,7 +446,7 @@ async def _handle_post_request( else: # Create SSE stream sse_stream_writer, sse_stream_reader = ( - anyio.create_memory_object_stream[dict[str, Any]](0) + anyio.create_memory_object_stream[dict[str, str]](0) ) async def sse_writer(): @@ -444,17 +456,7 @@ async def sse_writer(): # Process messages from the request-specific stream async for event_message in request_stream_reader: # Build the event data - event_data = { - "event": "message", - "data": event_message.message.model_dump_json( - by_alias=True, exclude_none=True - ), - } - - # If an event ID was provided, include it - if event_message.event_id: - event_data["id"] = event_message.event_id - + event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) # If response, remove from pending streams and close @@ -571,7 +573,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Create SSE stream sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ - dict[str, Any] + dict[str, str] ](0) async def standalone_sse_writer(): @@ -593,17 +595,7 @@ async def standalone_sse_writer(): # We should NOT receive JSONRPCResponse # Send the message via SSE - event_data = { - "event": "message", - "data": event_message.message.model_dump_json( - by_alias=True, exclude_none=True - ), - } - - # If an event ID was provided, include it in the SSE stream - if event_message.event_id: - event_data["id"] = event_message.event_id - + event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) except Exception as e: logger.exception(f"Error in standalone SSE writer: {e}") @@ -744,7 +736,7 @@ async def _replay_events( # Create SSE stream for replay sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ - dict[str, Any] + dict[str, str] ](0) async def replay_sender(): @@ -752,15 +744,8 @@ async def replay_sender(): async with sse_stream_writer: # Define an async callback for sending events async def send_event(event_message: EventMessage) -> None: - await sse_stream_writer.send( - { - "event": "message", - "id": event_message.event_id, - "data": event_message.message.model_dump_json( - by_alias=True, exclude_none=True - ), - } - ) + event_data = self._create_event_data(event_message) + await sse_stream_writer.send(event_data) # Replay past events and get the stream ID stream_id = await event_store.replay_events_after( @@ -777,16 +762,9 @@ async def send_event(event_message: EventMessage) -> None: # Forward messages to SSE async with msg_reader: async for event_message in msg_reader: - event_data = event_message.message.model_dump_json( - by_alias=True, exclude_none=True - ) - await sse_stream_writer.send( - { - "event": "message", - "id": event_message.event_id, - "data": event_data, - } - ) + event_data = self._create_event_data(event_message) + + await sse_stream_writer.send(event_data) except Exception as e: logger.exception(f"Error in replay sender: {e}") From 8650c77269e09adb04a894df378430db8bb924b0 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 27 Apr 2025 20:56:29 +0100 Subject: [PATCH 42/45] add session message --- src/mcp/client/__main__.py | 6 +-- src/mcp/client/session.py | 5 +- src/mcp/client/sse.py | 20 ++++--- src/mcp/client/stdio/__init__.py | 18 ++++--- src/mcp/client/streamable_http.py | 23 +++++--- src/mcp/client/websocket.py | 20 +++---- src/mcp/server/lowlevel/server.py | 5 +- src/mcp/server/session.py | 5 +- src/mcp/server/sse.py | 24 ++++----- src/mcp/server/stdio.py | 18 ++++--- src/mcp/server/streamable_http.py | 27 ++++++---- src/mcp/server/websocket.py | 18 ++++--- src/mcp/shared/memory.py | 10 ++-- src/mcp/shared/message.py | 35 +++++++++++++ src/mcp/shared/session.py | 33 +++++++----- tests/client/test_session.py | 73 +++++++++++++++----------- tests/client/test_stdio.py | 6 ++- tests/issues/test_192_request_id.py | 15 ++++-- tests/server/test_lifespan.py | 81 +++++++++++++++++------------ tests/server/test_session.py | 6 +-- tests/server/test_stdio.py | 6 ++- 21 files changed, 283 insertions(+), 171 deletions(-) create mode 100644 src/mcp/shared/message.py diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index 84e15bd5..2ec68e56 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -11,8 +11,8 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import JSONRPCMessage if not sys.warnoptions: import warnings @@ -36,8 +36,8 @@ async def message_handler( async def run_session( - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], client_info: types.Implementation | None = None, ): async with ClientSession( diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index e29797d1..32fb4cbe 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -7,6 +7,7 @@ import mcp.types as types from mcp.shared.context import RequestContext +from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -92,8 +93,8 @@ class ClientSession( ): def __init__( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 4f6241a7..ff04d2f9 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -10,6 +10,7 @@ from httpx_sse import aconnect_sse import mcp.types as types +from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -31,11 +32,11 @@ async def sse_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. """ - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -97,7 +98,8 @@ async def sse_reader( await read_stream_writer.send(exc) continue - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) case _: logger.warning( f"Unknown SSE event: {sse.event}" @@ -111,11 +113,13 @@ async def sse_reader( async def post_writer(endpoint_url: str): try: async with write_stream_reader: - async for message in write_stream_reader: - logger.debug(f"Sending client message: {message}") + async for session_message in write_stream_reader: + logger.debug( + f"Sending client message: {session_message}" + ) response = await client.post( endpoint_url, - json=message.model_dump( + json=session_message.message.model_dump( by_alias=True, mode="json", exclude_none=True, diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 83de57a2..e8be5aff 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, Field import mcp.types as types +from mcp.shared.message import SessionMessage from .win32 import ( create_windows_process, @@ -98,11 +99,11 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout. """ - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -143,7 +144,8 @@ async def stdout_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() @@ -152,8 +154,10 @@ async def stdin_writer(): try: async with write_stream_reader: - async for message in write_stream_reader: - json = message.model_dump_json(by_alias=True, exclude_none=True) + async for session_message in write_stream_reader: + json = session_message.message.model_dump_json( + by_alias=True, exclude_none=True + ) await process.stdin.send( (json + "\n").encode( encoding=server.encoding, diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 1e004242..7a8887cd 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -15,6 +15,7 @@ import httpx from httpx_sse import EventSource, aconnect_sse +from mcp.shared.message import SessionMessage from mcp.types import ( ErrorData, JSONRPCError, @@ -52,10 +53,10 @@ async def streamablehttp_client( """ read_stream_writer, read_stream = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + SessionMessage | Exception ](0) write_stream, write_stream_reader = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](0) async def get_stream(): @@ -86,7 +87,8 @@ async def get_stream(): try: message = JSONRPCMessage.model_validate_json(sse.data) logger.debug(f"GET message: {message}") - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) except Exception as exc: logger.error(f"Error parsing GET message: {exc}") await read_stream_writer.send(exc) @@ -100,7 +102,8 @@ async def post_writer(client: httpx.AsyncClient): nonlocal session_id try: async with write_stream_reader: - async for message in write_stream_reader: + async for session_message in write_stream_reader: + message = session_message.message # Add session ID to headers if we have one post_headers = request_headers.copy() if session_id: @@ -141,9 +144,10 @@ async def post_writer(client: httpx.AsyncClient): message="Session terminated", ), ) - await read_stream_writer.send( + session_message = SessionMessage( JSONRPCMessage(jsonrpc_error) ) + await read_stream_writer.send(session_message) continue response.raise_for_status() @@ -163,7 +167,8 @@ async def post_writer(client: httpx.AsyncClient): json_message = JSONRPCMessage.model_validate_json( content ) - await read_stream_writer.send(json_message) + session_message = SessionMessage(json_message) + await read_stream_writer.send(session_message) except Exception as exc: logger.error(f"Error parsing JSON response: {exc}") await read_stream_writer.send(exc) @@ -175,11 +180,15 @@ async def post_writer(client: httpx.AsyncClient): async for sse in event_source.aiter_sse(): if sse.event == "message": try: - await read_stream_writer.send( + message = ( JSONRPCMessage.model_validate_json( sse.data ) ) + session_message = SessionMessage(message) + await read_stream_writer.send( + session_message + ) except Exception as exc: logger.exception("Error parsing message") await read_stream_writer.send(exc) diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 2c2ed38b..ac542fb3 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -10,6 +10,7 @@ from websockets.typing import Subprotocol import mcp.types as types +from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -19,8 +20,8 @@ async def websocket_client( url: str, ) -> AsyncGenerator[ tuple[ - MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - MemoryObjectSendStream[types.JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], ], None, ]: @@ -39,10 +40,10 @@ async def websocket_client( # Create two in-memory streams: # - One for incoming messages (read_stream, written by ws_reader) # - One for outgoing messages (write_stream, read by ws_writer) - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -59,7 +60,8 @@ async def ws_reader(): async for raw_text in ws: try: message = types.JSONRPCMessage.model_validate_json(raw_text) - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) except ValidationError as exc: # If JSON parse or model validation fails, send the exception await read_stream_writer.send(exc) @@ -70,9 +72,9 @@ async def ws_writer(): sends them to the server. """ async with write_stream_reader: - async for message in write_stream_reader: + async for session_message in write_stream_reader: # Convert to a dict, then to JSON - msg_dict = message.model_dump( + msg_dict = session_message.message.model_dump( by_alias=True, mode="json", exclude_none=True ) await ws.send(json.dumps(msg_dict)) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index b47f5305..a31d52a6 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -84,6 +84,7 @@ async def main(): from mcp.server.stdio import stdio_server as stdio_server from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder logger = logging.getLogger(__name__) @@ -471,8 +472,8 @@ async def handler(req: types.CompleteRequest): async def run( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], initialization_options: InitializationOptions, # When False, exceptions are returned as messages to the client. # When True, exceptions are raised, which will cause the server to shut down diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 07e5a315..7a5ace5e 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -47,6 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.models import InitializationOptions +from mcp.shared.message import SessionMessage from mcp.shared.session import ( BaseSession, RequestResponder, @@ -82,8 +83,8 @@ class ServerSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], init_options: InitializationOptions, standalone_mode: bool = False, ) -> None: diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index d051c25b..c781c64d 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -46,6 +46,7 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types +from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -63,9 +64,7 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[ - UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception] - ] + _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] def __init__(self, endpoint: str) -> None: """ @@ -85,11 +84,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): raise ValueError("connect_sse can only handle HTTP requests") logger.debug("Setting up SSE connection") - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -109,12 +108,12 @@ async def sse_writer(): await sse_stream_writer.send({"event": "endpoint", "data": session_uri}) logger.debug(f"Sent endpoint event: {session_uri}") - async for message in write_stream_reader: - logger.debug(f"Sending message via SSE: {message}") + async for session_message in write_stream_reader: + logger.debug(f"Sending message via SSE: {session_message}") await sse_stream_writer.send( { "event": "message", - "data": message.model_dump_json( + "data": session_message.message.model_dump_json( by_alias=True, exclude_none=True ), } @@ -169,7 +168,8 @@ async def handle_post_message( await writer.send(err) return - logger.debug(f"Sending message to writer: {message}") + session_message = SessionMessage(message) + logger.debug(f"Sending session message to writer: {session_message}") response = Response("Accepted", status_code=202) await response(scope, receive, send) - await writer.send(message) + await writer.send(session_message) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 0e0e4912..f0bbe5a3 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -27,6 +27,7 @@ async def run_server(): from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.types as types +from mcp.shared.message import SessionMessage @asynccontextmanager @@ -47,11 +48,11 @@ async def stdio_server( if not stdout: stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -66,15 +67,18 @@ async def stdin_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() async def stdout_writer(): try: async with write_stream_reader: - async for message in write_stream_reader: - json = message.model_dump_json(by_alias=True, exclude_none=True) + async for session_message in write_stream_reader: + json = session_message.message.model_dump_json( + by_alias=True, exclude_none=True + ) await stdout.write(json + "\n") await stdout.flush() except anyio.ClosedResourceError: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 36c4b636..8929a50f 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -24,6 +24,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.shared.message import SessionMessage from mcp.types import ( INTERNAL_ERROR, INVALID_PARAMS, @@ -125,10 +126,10 @@ class StreamableHTTPServerTransport: """ # Server notification streams for POST requests as well as standalone SSE stream - _read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] | None = ( + _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = ( None ) - _write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] | None = None + _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None def __init__( self, @@ -378,7 +379,8 @@ async def _handle_post_request( await response(scope, receive, send) # Process the message after sending the response - await writer.send(message) + session_message = SessionMessage(message) + await writer.send(session_message) return @@ -394,7 +396,8 @@ async def _handle_post_request( if self.is_json_response_enabled: # Process the message - await writer.send(message) + session_message = SessionMessage(message) + await writer.send(session_message) try: # Process messages from the request-specific stream # We need to collect all messages until we get a response @@ -500,7 +503,8 @@ async def sse_writer(): async with anyio.create_task_group() as tg: tg.start_soon(response, scope, receive, send) # Then send the message to be processed by the server - await writer.send(message) + session_message = SessionMessage(message) + await writer.send(session_message) except Exception: logger.exception("SSE response error") # Clean up the request stream if something goes wrong @@ -516,7 +520,7 @@ async def sse_writer(): ) await response(scope, receive, send) if writer: - await writer.send(err) + await writer.send(Exception(err)) return async def _handle_get_request(self, request: Request, send: Send) -> None: @@ -794,8 +798,8 @@ async def connect( self, ) -> AsyncGenerator[ tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], ], None, ]: @@ -808,10 +812,10 @@ async def connect( # Create the memory streams for this connection read_stream_writer, read_stream = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + SessionMessage | Exception ](0) write_stream, write_stream_reader = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](0) # Store the streams @@ -823,8 +827,9 @@ async def connect( # Create a message router that distributes messages to request streams async def message_router(): try: - async for message in write_stream_reader: + async for session_message in write_stream_reader: # Determine which request stream(s) should receive this message + message = session_message.message target_request_id = None if isinstance( message.root, JSONRPCNotification | JSONRPCRequest diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index aee855cf..9dc3f2a2 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -8,6 +8,7 @@ from starlette.websockets import WebSocket import mcp.types as types +from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -22,11 +23,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): websocket = WebSocket(scope, receive, send) await websocket.accept(subprotocol="mcp") - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -41,15 +42,18 @@ async def ws_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(client_message) + session_message = SessionMessage(client_message) + await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await websocket.close() async def ws_writer(): try: async with write_stream_reader: - async for message in write_stream_reader: - obj = message.model_dump_json(by_alias=True, exclude_none=True) + async for session_message in write_stream_reader: + obj = session_message.message.model_dump_json( + by_alias=True, exclude_none=True + ) await websocket.send_text(obj) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index abf87a3a..b53f8dd6 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -19,11 +19,11 @@ SamplingFnT, ) from mcp.server import Server -from mcp.types import JSONRPCMessage +from mcp.shared.message import SessionMessage MessageStream = tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], ] @@ -40,10 +40,10 @@ async def create_client_server_memory_streams() -> ( """ # Create streams for both directions server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + SessionMessage | Exception ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + SessionMessage | Exception ](1) client_streams = (server_to_client_receive, client_to_server_send) diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py new file mode 100644 index 00000000..c9341c36 --- /dev/null +++ b/src/mcp/shared/message.py @@ -0,0 +1,35 @@ +""" +Message wrapper with metadata support. + +This module defines a wrapper type that combines JSONRPCMessage with metadata +to support transport-specific features like resumability. +""" + +from dataclasses import dataclass + +from mcp.types import JSONRPCMessage, RequestId + + +@dataclass +class ClientMessageMetadata: + """Metadata specific to client messages.""" + + resumption_token: str | None = None + + +@dataclass +class ServerMessageMetadata: + """Metadata specific to server messages.""" + + related_request_id: RequestId | None = None + + +MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None + + +@dataclass +class SessionMessage: + """A message with specific metadata for transport-specific features.""" + + message: JSONRPCMessage + metadata: MessageMetadata = None diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 3a01cb04..d6d7ee56 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -12,6 +12,7 @@ from typing_extensions import Self from mcp.shared.exceptions import McpError +from mcp.shared.message import SessionMessage from mcp.types import ( CancelledNotification, ClientNotification, @@ -172,8 +173,8 @@ class BaseSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], receive_request_type: type[ReceiveRequestT], receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out @@ -241,7 +242,8 @@ async def send_request( # TODO: Support progress callbacks - await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_request)) + await self._write_stream.send(session_message) try: with anyio.fail_after( @@ -293,14 +295,16 @@ async def send_notification( jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_notification)) + await self._write_stream.send(session_message) async def _send_response( self, request_id: RequestId, response: SendResultT | ErrorData ) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - await self._write_stream.send(JSONRPCMessage(jsonrpc_error)) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) + await self._write_stream.send(session_message) else: jsonrpc_response = JSONRPCResponse( jsonrpc="2.0", @@ -309,7 +313,8 @@ async def _send_response( by_alias=True, mode="json", exclude_none=True ), ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_response)) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) + await self._write_stream.send(session_message) async def _receive_loop(self) -> None: async with ( @@ -319,15 +324,15 @@ async def _receive_loop(self) -> None: async for message in self._read_stream: if isinstance(message, Exception): await self._handle_incoming(message) - elif isinstance(message.root, JSONRPCRequest): + elif isinstance(message.message.root, JSONRPCRequest): validated_request = self._receive_request_type.model_validate( - message.root.model_dump( + message.message.root.model_dump( by_alias=True, mode="json", exclude_none=True ) ) responder = RequestResponder( - request_id=message.root.id, + request_id=message.message.root.id, request_meta=validated_request.root.params.meta if validated_request.root.params else None, @@ -342,10 +347,10 @@ async def _receive_loop(self) -> None: if not responder._completed: # type: ignore[reportPrivateUsage] await self._handle_incoming(responder) - elif isinstance(message.root, JSONRPCNotification): + elif isinstance(message.message.root, JSONRPCNotification): try: notification = self._receive_notification_type.model_validate( - message.root.model_dump( + message.message.root.model_dump( by_alias=True, mode="json", exclude_none=True ) ) @@ -361,12 +366,12 @@ async def _receive_loop(self) -> None: # For other validation errors, log and continue logging.warning( f"Failed to validate notification: {e}. " - f"Message was: {message.root}" + f"Message was: {message.message.root}" ) else: # Response or error - stream = self._response_streams.pop(message.root.id, None) + stream = self._response_streams.pop(message.message.root.id, None) if stream: - await stream.send(message.root) + await stream.send(message.message.root) else: await self._handle_incoming( RuntimeError( diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 543ebb2f..6abcf70c 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -3,6 +3,7 @@ import mcp.types as types from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( LATEST_PROTOCOL_VERSION, @@ -24,10 +25,10 @@ @pytest.mark.anyio async def test_client_session_initialize(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) initialized_notification = None @@ -35,7 +36,8 @@ async def test_client_session_initialize(): async def mock_server(): nonlocal initialized_notification - jsonrpc_request = await client_to_server_receive.receive() + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message assert isinstance(jsonrpc_request.root, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) @@ -59,17 +61,20 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) ) ) ) - jsonrpc_notification = await client_to_server_receive.receive() + session_notification = await client_to_server_receive.receive() + jsonrpc_notification = session_notification.message assert isinstance(jsonrpc_notification.root, JSONRPCNotification) initialized_notification = ClientNotification.model_validate( jsonrpc_notification.model_dump( @@ -116,10 +121,10 @@ async def message_handler( @pytest.mark.anyio async def test_client_session_custom_client_info(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) custom_client_info = Implementation(name="test-client", version="1.2.3") @@ -128,7 +133,8 @@ async def test_client_session_custom_client_info(): async def mock_server(): nonlocal received_client_info - jsonrpc_request = await client_to_server_receive.receive() + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message assert isinstance(jsonrpc_request.root, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) @@ -146,13 +152,15 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) ) ) ) @@ -181,10 +189,10 @@ async def mock_server(): @pytest.mark.anyio async def test_client_session_default_client_info(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) received_client_info = None @@ -192,7 +200,8 @@ async def test_client_session_default_client_info(): async def mock_server(): nonlocal received_client_info - jsonrpc_request = await client_to_server_receive.receive() + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message assert isinstance(jsonrpc_request.root, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) @@ -210,13 +219,15 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) ) ) ) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 95747ffd..523ba199 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -3,6 +3,7 @@ import pytest from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.shared.message import SessionMessage from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore @@ -22,7 +23,8 @@ async def test_stdio_client(): async with write_stream: for message in messages: - await write_stream.send(message) + session_message = SessionMessage(message) + await write_stream.send(session_message) read_messages = [] async with read_stream: @@ -30,7 +32,7 @@ async def test_stdio_client(): if isinstance(message, Exception): raise message - read_messages.append(message) + read_messages.append(message.message) if len(read_messages) == 2: break diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index 00e18789..cf5eb608 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -3,6 +3,7 @@ from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions +from mcp.shared.message import SessionMessage from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, @@ -64,8 +65,10 @@ async def run_server(): jsonrpc="2.0", ) - await client_writer.send(JSONRPCMessage(root=init_req)) - await server_reader.receive() # Get init response but don't need to check it + await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req))) + response = ( + await server_reader.receive() + ) # Get init response but don't need to check it # Send initialized notification initialized_notification = JSONRPCNotification( @@ -73,21 +76,23 @@ async def run_server(): params=NotificationParams().model_dump(by_alias=True, exclude_none=True), jsonrpc="2.0", ) - await client_writer.send(JSONRPCMessage(root=initialized_notification)) + await client_writer.send( + SessionMessage(JSONRPCMessage(root=initialized_notification)) + ) # Send ping request with custom ID ping_request = JSONRPCRequest( id=custom_request_id, method="ping", params={}, jsonrpc="2.0" ) - await client_writer.send(JSONRPCMessage(root=ping_request)) + await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request))) # Read response response = await server_reader.receive() # Verify response ID matches request ID assert ( - response.root.id == custom_request_id + response.message.root.id == custom_request_id ), "Response ID should match request ID" # Cancel server task diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 309a44b8..a3ff59bc 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -10,6 +10,7 @@ from mcp.server.fastmcp import Context, FastMCP from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.models import InitializationOptions +from mcp.shared.message import SessionMessage from mcp.types import ( ClientCapabilities, Implementation, @@ -82,41 +83,49 @@ async def run_server(): clientInfo=Implementation(name="test-client", version="0.1.0"), ) await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), + SessionMessage( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) ) ) ) response = await receive_stream2.receive() + response = response.message # Send initialized notification await send_stream1.send( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", + SessionMessage( + JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) ) ) ) # Call the tool to verify lifespan context await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, + SessionMessage( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) ) ) ) # Get response and verify response = await receive_stream2.receive() + response = response.message assert response.root.result["content"][0]["text"] == "true" # Cancel server task @@ -178,41 +187,49 @@ async def run_server(): clientInfo=Implementation(name="test-client", version="0.1.0"), ) await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), + SessionMessage( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) ) ) ) response = await receive_stream2.receive() + response = response.message # Send initialized notification await send_stream1.send( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", + SessionMessage( + JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) ) ) ) # Call the tool to verify lifespan context await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, + SessionMessage( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) ) ) ) # Get response and verify response = await receive_stream2.receive() + response = response.message assert response.root.result["content"][0]["text"] == "true" # Cancel server task diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 561a94b6..f2f03358 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -7,11 +7,11 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( ClientNotification, InitializedNotification, - JSONRPCMessage, PromptsCapability, ResourcesCapability, ServerCapabilities, @@ -21,10 +21,10 @@ @pytest.mark.anyio async def test_server_session_initialize(): server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) # Create a message handler to catch exceptions diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 85c5bf21..c546a716 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -4,6 +4,7 @@ import pytest from mcp.server.stdio import stdio_server +from mcp.shared.message import SessionMessage from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse @@ -29,7 +30,7 @@ async def test_stdio_server(): async for message in read_stream: if isinstance(message, Exception): raise message - received_messages.append(message) + received_messages.append(message.message) if len(received_messages) == 2: break @@ -50,7 +51,8 @@ async def test_stdio_server(): async with write_stream: for response in responses: - await write_stream.send(response) + session_message = SessionMessage(response) + await write_stream.send(session_message) stdout.seek(0) output_lines = stdout.readlines() From 80780dc3de3f7c3ab24951fe45d315528e93124e Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 27 Apr 2025 22:27:27 +0100 Subject: [PATCH 43/45] use metadata from SessionMessage to propagate related_request_id --- src/mcp/server/streamable_http.py | 15 ++++++++++----- src/mcp/shared/session.py | 19 +++++-------------- tests/client/test_logging_callback.py | 4 ---- 3 files changed, 15 insertions(+), 23 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 8929a50f..5b18ba37 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -24,7 +24,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send -from mcp.shared.message import SessionMessage +from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.types import ( INTERNAL_ERROR, INVALID_PARAMS, @@ -836,12 +836,17 @@ async def message_router(): ): # Extract related_request_id from meta if it exists if ( - (params := getattr(message.root, "params", None)) - and (meta := params.get("_meta")) - and (related_id := meta.get("related_request_id")) + session_message.metadata is not None + and isinstance( + session_message.metadata, + ServerMessageMetadata, + ) + and session_message.metadata.related_request_id is not None ): - target_request_id = str(related_id) + target_request_id = str( + session_message.metadata.related_request_id + ) else: target_request_id = str(message.root.id) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index d6d7ee56..2e21df22 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -12,7 +12,7 @@ from typing_extensions import Self from mcp.shared.exceptions import McpError -from mcp.shared.message import SessionMessage +from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.types import ( CancelledNotification, ClientNotification, @@ -24,7 +24,6 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, - NotificationParams, RequestParams, ServerNotification, ServerRequest, @@ -280,22 +279,14 @@ async def send_notification( """ # Some transport implementations may need to set the related_request_id # to attribute to the notifications to the request that triggered them. - if related_request_id is not None and notification.root.params is not None: - # Create meta if it doesn't exist - if notification.root.params.meta is None: - meta_dict = {"related_request_id": related_request_id} - - else: - meta_dict = notification.root.params.meta.model_dump( - by_alias=True, mode="json", exclude_none=True - ) - meta_dict["related_request_id"] = related_request_id - notification.root.params.meta = NotificationParams.Meta(**meta_dict) jsonrpc_notification = JSONRPCNotification( jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) - session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_notification)) + session_message = SessionMessage( + message=JSONRPCMessage(jsonrpc_notification), + metadata=ServerMessageMetadata(related_request_id=related_request_id), + ) await self._write_stream.send(session_message) async def _send_response( diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 588fa649..0c9eeb39 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -9,7 +9,6 @@ from mcp.shared.session import RequestResponder from mcp.types import ( LoggingMessageNotificationParams, - NotificationParams, TextContent, ) @@ -80,10 +79,7 @@ async def message_handler( assert log_result.isError is False assert len(logging_collector.log_messages) == 1 # Create meta object with related_request_id added dynamically - meta = NotificationParams.Meta() - setattr(meta, "related_request_id", "2") log = logging_collector.log_messages[0] assert log.level == "info" assert log.logger == "test_logger" assert log.data == "Test log message" - assert log.meta == meta From 901dc988dad40842298704877473910b7b815fed Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 28 Apr 2025 09:51:30 +0100 Subject: [PATCH 44/45] assign server message only when related_request_id is not none --- src/mcp/shared/session.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 2e21df22..033c8deb 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -285,7 +285,9 @@ async def send_notification( ) session_message = SessionMessage( message=JSONRPCMessage(jsonrpc_notification), - metadata=ServerMessageMetadata(related_request_id=related_request_id), + metadata=ServerMessageMetadata(related_request_id=related_request_id) + if related_request_id + else None, ) await self._write_stream.send(session_message) From e06e3a596c45a389707951dcf73798b311b260df Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 2 May 2025 09:43:07 +0100 Subject: [PATCH 45/45] fixes after merge --- src/mcp/shared/session.py | 4 +++- tests/server/test_lowlevel_tool_annotations.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b2b8255e..cbf47be5 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -241,7 +241,9 @@ async def send_request( # TODO: Support progress callbacks - await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request))) + await self._write_stream.send( + SessionMessage(message=JSONRPCMessage(jsonrpc_request)) + ) # request read timeout takes precedence over session read timeout timeout = None diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index 47d03ad2..e9eff9ed 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -8,10 +8,10 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( ClientResult, - JSONRPCMessage, ServerNotification, ServerRequest, Tool, @@ -46,10 +46,10 @@ async def list_tools(): ] server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](10) # Message handler for client