Streamable Http - clean up server memory streams #604


Merged · 76 commits · May 2, 2025

Commits:
2b95598
initial streamable http server
ihrpr Apr 20, 2025
3d790f8
add request validation and tests
ihrpr Apr 20, 2025
27bc01e
session management
ihrpr Apr 20, 2025
3c4cf10
terminations of a session
ihrpr Apr 20, 2025
bce74b3
fix cleaning up
ihrpr Apr 20, 2025
2011579
add happy path test
ihrpr Apr 20, 2025
2cebf08
tests
ihrpr Apr 20, 2025
6c9c320
json mode
ihrpr Apr 20, 2025
ede8cde
clean up
ihrpr Apr 21, 2025
2a3bed8
fix example server
ihrpr Apr 21, 2025
0456b1b
return 405 for get stream
ihrpr Apr 21, 2025
97ca48d
speed up tests
ihrpr Apr 21, 2025
f738cbf
stateless implementation
ihrpr Apr 21, 2025
92d4287
format
ihrpr Apr 21, 2025
aa9f6e5
uv lock
ihrpr Apr 21, 2025
2fba7f3
Merge branch 'ihrpr/streamablehttp-server' into ihrpr/streamablehttp-…
ihrpr Apr 21, 2025
45723ea
simplify readme
ihrpr Apr 21, 2025
6b7a616
clean up
ihrpr Apr 21, 2025
b1be691
get sse
ihrpr Apr 22, 2025
201ec99
uv lock
ihrpr Apr 22, 2025
46ec72d
clean up
ihrpr Apr 22, 2025
1902abb
Merge branch 'ihrpr/streamablehttp-server' into ihrpr/streamablehttp-…
ihrpr Apr 22, 2025
da1df74
Merge branch 'ihrpr/streamablehttp-stateless' into ihrpr/get-sse
ihrpr Apr 22, 2025
c2be5af
streamable http client
ihrpr Apr 23, 2025
9b096dc
add comments to server example where we use related_request_id
ihrpr Apr 23, 2025
bbe79c2
Merge branch 'main' into ihrpr/streamablehttp-server
ihrpr Apr 23, 2025
a0a9c5b
small fixes
ihrpr Apr 23, 2025
a5ac2e0
use mta field for related_request_id
ihrpr Apr 23, 2025
2e615f3
unrelated test and format
ihrpr Apr 23, 2025
110526d
clean up
ihrpr Apr 23, 2025
7ffd5ba
terminate session
ihrpr Apr 23, 2025
029ec56
format
ihrpr Apr 23, 2025
cae32e2
Merge branch 'ihrpr/streamablehttp-server' into ihrpr/streamablehttp-…
ihrpr Apr 25, 2025
58745c7
remove useless sleep
ihrpr Apr 25, 2025
1387929
rename require_initialization to standalone_mode
ihrpr Apr 25, 2025
bccff75
Merge branch 'ihrpr/streamablehttp-stateless' into ihrpr/get-sse
ihrpr Apr 25, 2025
dd007d7
Merge branch 'ihrpr/get-sse' into ihrpr/client
ihrpr Apr 25, 2025
6482120
remove redundant check for initialize and session
ihrpr Apr 25, 2025
9a6da2e
ruff check
ihrpr Apr 25, 2025
b957fad
Merge branch 'ihrpr/get-sse' into ihrpr/client
ihrpr Apr 25, 2025
3f5fd7e
support for resumability - server
ihrpr Apr 25, 2025
b193242
format
ihrpr Apr 25, 2025
6110435
remove print
ihrpr Apr 25, 2025
e087283
rename files to follow python naming
ihrpr Apr 25, 2025
08247c4
update to use time delta in client
ihrpr Apr 25, 2025
0484dfb
refactor
ihrpr Apr 25, 2025
88ff2ba
Merge branch 'ihrpr/client' into ihrpr/resumability-server
ihrpr Apr 25, 2025
5757f6c
small fixes
ihrpr Apr 25, 2025
ee28ad8
improve event store example implementation
ihrpr Apr 25, 2025
5dbddeb
refactor _create_event_data
ihrpr Apr 25, 2025
8650c77
add session message
ihrpr Apr 27, 2025
80780dc
use metadata from SessionMessage to propagate related_request_id
ihrpr Apr 27, 2025
901dc98
assign server message only when related_request_id is not none
ihrpr Apr 28, 2025
6c2f7de
client resumability
ihrpr Apr 28, 2025
a346d6c
refactor client
ihrpr Apr 28, 2025
02f00c4
remove resume_tool
ihrpr Apr 28, 2025
dde8cd5
clean up server memory streams
ihrpr Apr 29, 2025
db24790
remove commented out code
ihrpr Apr 29, 2025
ff70bd6
Merge branch 'main' into ihrpr/streamablehttp-server
ihrpr May 2, 2025
179fbc8
Merge branch 'ihrpr/streamablehttp-server' into ihrpr/streamablehttp-…
ihrpr May 2, 2025
a979864
Merge branch 'ihrpr/streamablehttp-stateless' into ihrpr/get-sse
ihrpr May 2, 2025
11b7dd9
Merge branch 'ihrpr/get-sse' into ihrpr/client
ihrpr May 2, 2025
67a899c
Merge branch 'ihrpr/client' into ihrpr/resumability-server
ihrpr May 2, 2025
c3e0ff3
Merge branch 'ihrpr/resumability-server' into ihrpr/memory-stream-ite…
ihrpr May 2, 2025
83503a0
Merge branch 'ihrpr/memory-stream-item-type' into ihrpr/use-session-m…
ihrpr May 2, 2025
5090989
Merge branch 'ihrpr/use-session-message-for-related-request' into ihr…
ihrpr May 2, 2025
3f8303a
suggested changes
ihrpr May 2, 2025
e06e3a5
fixes after merge
ihrpr May 2, 2025
f6cea03
Merge branch 'ihrpr/memory-stream-item-type' into ihrpr/use-session-m…
ihrpr May 2, 2025
2bd0a27
Merge branch 'ihrpr/use-session-message-for-related-request' into ihr…
ihrpr May 2, 2025
e73677a
remove resumption from call_tool API
ihrpr May 2, 2025
dacd294
return type for streamablehttp_client
ihrpr May 2, 2025
d1bd44b
terminate on close instead of callback to terminate
ihrpr May 2, 2025
526ddff
Merge branch 'ihrpr/client-resumability' into ihrpr/server-closing-st…
ihrpr May 2, 2025
f83df9d
Merge branch 'main' into ihrpr/server-closing-streams
ihrpr May 2, 2025
28bd038
suggested changes
ihrpr May 2, 2025
@@ -185,20 +185,22 @@ async def handle_streamable_http(scope, receive, send):
)
server_instances[http_transport.mcp_session_id] = http_transport
logger.info(f"Created new transport with session ID: {new_session_id}")
async with http_transport.connect() as streams:
read_stream, write_stream = streams

async def run_server():
await app.run(
read_stream,
write_stream,
app.create_initialization_options(),
)
async def run_server(task_status=None):
async with http_transport.connect() as streams:
read_stream, write_stream = streams
if task_status:
task_status.started()
await app.run(
read_stream,
write_stream,
app.create_initialization_options(),
)

if not task_group:
raise RuntimeError("Task group is not initialized")

task_group.start_soon(run_server)
await task_group.start(run_server)

# Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send)
2 changes: 1 addition & 1 deletion src/mcp/server/lowlevel/server.py
@@ -480,7 +480,7 @@ async def run(
# but also make tracing exceptions much easier during testing and when using
# in-process servers.
raise_exceptions: bool = False,
# When True, the server as stateless deployments where
# When True, the server is stateless and
# clients can perform initialization with any node. The client must still follow
# the initialization lifecycle, but can do so with any available node
# rather than requiring initialization for each connection.
124 changes: 80 additions & 44 deletions src/mcp/server/streamable_http.py
@@ -129,6 +129,8 @@ class StreamableHTTPServerTransport:
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = (
None
)
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None

def __init__(
@@ -163,7 +165,11 @@ def __init__(
self.is_json_response_enabled = is_json_response_enabled
self._event_store = event_store
self._request_streams: dict[
RequestId, MemoryObjectSendStream[EventMessage]
RequestId,
tuple[
MemoryObjectSendStream[EventMessage],
MemoryObjectReceiveStream[EventMessage],
],
] = {}
self._terminated = False

@@ -239,6 +245,19 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:

return event_data

async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
"""Clean up memory streams for a given request ID."""
if request_id in self._request_streams:
try:
# Close the request stream
await self._request_streams[request_id][0].aclose()
await self._request_streams[request_id][1].aclose()
except Exception as e:
logger.debug(f"Error closing memory streams: {e}")
[Review comment — Contributor]: What exceptions would we expect to see here? I'm slightly worried about swallowing everything, since (I think) aclose() itself shouldn't normally throw anything

finally:
# Remove the request stream from the mapping
self._request_streams.pop(request_id, None)
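For readers following outside the diff, the helper's behavior can be sketched with plain anyio (registry and names here are illustrative, not the transport's own): close both halves of the request's stream pair, then drop the registry entry. Note that a second close is harmless, which bears on the review question about what the `except` clause can realistically catch:

```python
import anyio

async def main() -> dict:
    # Request-stream registry as in the transport: request ID -> (send, receive) pair.
    streams: dict = {}
    send, recv = anyio.create_memory_object_stream(0)
    streams["req-1"] = (send, recv)

    async def clean_up(request_id: str) -> None:
        """Close both ends of the pair, then drop the registry entry."""
        if request_id in streams:
            try:
                await streams[request_id][0].aclose()
                await streams[request_id][1].aclose()
            finally:
                streams.pop(request_id, None)

    await clean_up("req-1")
    # aclose() is idempotent in anyio: closing an already-closed end is a no-op,
    # so the broad except in the helper rarely has anything real to catch.
    await send.aclose()
    return streams

remaining = anyio.run(main)
```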

async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Application entry point that handles all HTTP requests"""
request = Request(scope, receive)
@@ -386,13 +405,11 @@ async def _handle_post_request(

# Extract the request ID outside the try block for proper scope
request_id = str(message.root.id)
# Create promise stream for getting response
request_stream_writer, request_stream_reader = (
anyio.create_memory_object_stream[EventMessage](0)
)

# Register this stream for the request ID
self._request_streams[request_id] = request_stream_writer
self._request_streams[request_id] = anyio.create_memory_object_stream[
EventMessage
](0)
request_stream_reader = self._request_streams[request_id][1]

if self.is_json_response_enabled:
# Process the message
Expand Down Expand Up @@ -441,11 +458,7 @@ async def _handle_post_request(
)
await response(scope, receive, send)
finally:
# Clean up the request stream
if request_id in self._request_streams:
self._request_streams.pop(request_id, None)
await request_stream_reader.aclose()
await request_stream_writer.aclose()
await self._clean_up_memory_streams(request_id)
else:
# Create SSE stream
sse_stream_writer, sse_stream_reader = (
@@ -467,16 +480,12 @@ async def sse_writer():
event_message.message.root,
JSONRPCResponse | JSONRPCError,
):
if request_id:
self._request_streams.pop(request_id, None)
break
except Exception as e:
logger.exception(f"Error in SSE writer: {e}")
finally:
logger.debug("Closing SSE writer")
# Clean up the request-specific streams
if request_id and request_id in self._request_streams:
self._request_streams.pop(request_id, None)
await self._clean_up_memory_streams(request_id)

# Create and start EventSourceResponse
# SSE stream mode (original behavior)
@@ -507,9 +516,9 @@ async def sse_writer():
await writer.send(session_message)
except Exception:
logger.exception("SSE response error")
# Clean up the request stream if something goes wrong
if request_id and request_id in self._request_streams:
self._request_streams.pop(request_id, None)
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()
await self._clean_up_memory_streams(request_id)

except Exception as err:
logger.exception("Error handling POST request")
@@ -581,12 +590,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
async def standalone_sse_writer():
try:
# Create a standalone message stream for server-initiated messages
standalone_stream_writer, standalone_stream_reader = (

self._request_streams[GET_STREAM_KEY] = (
anyio.create_memory_object_stream[EventMessage](0)
)

# Register this stream using the special key
self._request_streams[GET_STREAM_KEY] = standalone_stream_writer
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]

async with sse_stream_writer, standalone_stream_reader:
# Process messages from the standalone stream
@@ -603,8 +611,7 @@ async def standalone_sse_writer():
logger.exception(f"Error in standalone SSE writer: {e}")
finally:
logger.debug("Closing standalone SSE writer")
# Remove the stream from request_streams
self._request_streams.pop(GET_STREAM_KEY, None)
await self._clean_up_memory_streams(GET_STREAM_KEY)

# Create and start EventSourceResponse
response = EventSourceResponse(
@@ -618,8 +625,9 @@ async def standalone_sse_writer():
await response(request.scope, request.receive, send)
except Exception as e:
logger.exception(f"Error in standalone SSE response: {e}")
# Clean up the request stream
self._request_streams.pop(GET_STREAM_KEY, None)
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()
await self._clean_up_memory_streams(GET_STREAM_KEY)

async def _handle_delete_request(self, request: Request, send: Send) -> None:
"""Handle DELETE requests for explicit session termination."""
@@ -636,15 +644,15 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
if not await self._validate_session(request, send):
return

self._terminate_session()
await self._terminate_session()

response = self._create_json_response(
None,
HTTPStatus.OK,
)
await response(request.scope, request.receive, send)

def _terminate_session(self) -> None:
async def _terminate_session(self) -> None:
"""Terminate the current session, closing all streams.

Once terminated, all requests with this session ID will receive 404 Not Found.
@@ -656,19 +664,26 @@ def _terminate_session(self) -> None:
# We need a copy of the keys to avoid modification during iteration
request_stream_keys = list(self._request_streams.keys())

# Close all request streams (synchronously)
# Close all request streams asynchronously
for key in request_stream_keys:
try:
# Get the stream
stream = self._request_streams.get(key)
if stream:
# We must use close() here, not aclose() since this is a sync method
stream.close()
await self._clean_up_memory_streams(key)
except Exception as e:
logger.debug(f"Error closing stream {key} during termination: {e}")

# Clear the request streams dictionary immediately
self._request_streams.clear()
try:
if self._read_stream_writer is not None:
await self._read_stream_writer.aclose()
if self._read_stream is not None:
await self._read_stream.aclose()
if self._write_stream_reader is not None:
await self._write_stream_reader.aclose()
if self._write_stream is not None:
await self._write_stream.aclose()
except Exception as e:
[Review comment — Contributor]: Same question about possibly narrowing down the exceptions we really want to suppress here (maybe all? but just flagging)

logger.debug(f"Error closing streams: {e}")

async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
"""Handle unsupported HTTP methods."""
@@ -756,10 +771,10 @@ async def send_event(event_message: EventMessage) -> None:

# If stream ID not in mapping, create it
if stream_id and stream_id not in self._request_streams:
msg_writer, msg_reader = anyio.create_memory_object_stream[
EventMessage
](0)
self._request_streams[stream_id] = msg_writer
self._request_streams[stream_id] = (
anyio.create_memory_object_stream[EventMessage](0)
)
msg_reader = self._request_streams[stream_id][1]

# Forward messages to SSE
async with msg_reader:
@@ -781,6 +796,9 @@ async def send_event(event_message: EventMessage) -> None:
await response(request.scope, request.receive, send)
except Exception as e:
logger.exception(f"Error in replay response: {e}")
finally:
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()

except Exception as e:
logger.exception(f"Error replaying events: {e}")
@@ -818,7 +836,9 @@ async def connect(

# Store the streams
self._read_stream_writer = read_stream_writer
self._read_stream = read_stream
self._write_stream_reader = write_stream_reader
self._write_stream = write_stream

# Start a task group for message routing
async with anyio.create_task_group() as tg:
@@ -863,7 +883,7 @@ async def message_router():
if request_stream_id in self._request_streams:
try:
# Send both the message and the event ID
await self._request_streams[request_stream_id].send(
await self._request_streams[request_stream_id][0].send(
EventMessage(message, event_id)
)
except (
Expand All @@ -872,6 +892,12 @@ async def message_router():
):
# Stream might be closed, remove from registry
self._request_streams.pop(request_stream_id, None)
else:
logging.debug(
f"""Request stream {request_stream_id} not found
for message. Still processing message as the client
might reconnect and replay."""
)
except Exception as e:
logger.exception(f"Error in message router: {e}")

@@ -882,9 +908,19 @@ async def message_router():
# Yield the streams for the caller to use
yield read_stream, write_stream
finally:
for stream in list(self._request_streams.values()):
for stream_id in list(self._request_streams.keys()):
try:
await stream.aclose()
except Exception:
await self._clean_up_memory_streams(stream_id)
except Exception as e:
logger.debug(f"Error closing request stream: {e}")
pass
self._request_streams.clear()

# Clean up the read and write streams
try:
await read_stream_writer.aclose()
await read_stream.aclose()
await write_stream_reader.aclose()
await write_stream.aclose()
except Exception as e:
logger.debug(f"Error closing streams: {e}")
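The message-router hunk above catches `anyio.BrokenResourceError` and `anyio.ClosedResourceError` on `send()` because a per-request stream may already have been closed by the response path while the router still holds a reference; the dead entry is then pruned from the registry. A self-contained sketch of that failure mode (illustrative names, only anyio assumed):

```python
import anyio

async def main() -> str:
    send, recv = anyio.create_memory_object_stream(0)
    registry = {"req-1": (send, recv)}

    # The response path closes both ends of the stream first...
    await send.aclose()
    await recv.aclose()

    # ...then the router tries to deliver a message and prunes the dead entry.
    try:
        await registry["req-1"][0].send("event")
    except (anyio.BrokenResourceError, anyio.ClosedResourceError):
        registry.pop("req-1", None)
    return "pruned" if "req-1" not in registry else "still registered"

result = anyio.run(main)
```

Sending on a closed memory stream raises `ClosedResourceError` deterministically, so pruning in the `except` branch is safe without extra synchronization.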
31 changes: 16 additions & 15 deletions tests/shared/test_streamable_http.py
@@ -234,29 +234,30 @@ async def handle_streamable_http(scope, receive, send):
event_store=event_store,
)

async with http_transport.connect() as streams:
read_stream, write_stream = streams

async def run_server():
async def run_server(task_status=None):
async with http_transport.connect() as streams:
read_stream, write_stream = streams
if task_status:
task_status.started()
await server.run(
read_stream,
write_stream,
server.create_initialization_options(),
)

if task_group is None:
response = Response(
"Internal Server Error: Task group is not initialized",
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
await response(scope, receive, send)
return
if task_group is None:
response = Response(
"Internal Server Error: Task group is not initialized",
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
await response(scope, receive, send)
return

# Store the instance before starting the task to prevent races
server_instances[http_transport.mcp_session_id] = http_transport
task_group.start_soon(run_server)
# Store the instance before starting the task to prevent races
server_instances[http_transport.mcp_session_id] = http_transport
await task_group.start(run_server)

await http_transport.handle_request(scope, receive, send)
await http_transport.handle_request(scope, receive, send)
else:
response = Response(
"Bad Request: No valid session ID provided",