-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Fix handle sse disconnect event to free read_stream_writers[session_id] #582
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: v1.3.x
Are you sure you want to change the base?
Changes from all commits
f7d33ad
e8e0491
c7ef8a2
bda260a
36bb4c6
ee8d1eb
74701dd
49262a9
dbe05d6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,8 @@ | |
Example usage: | ||
``` | ||
# Create an SSE transport at an endpoint | ||
from starlette.responses import Response | ||
|
||
sse = SseServerTransport("/messages/") | ||
|
||
# Create Starlette routes for SSE and message handling | ||
|
@@ -22,6 +24,7 @@ async def handle_sse(request): | |
await app.run( | ||
streams[0], streams[1], app.create_initialization_options() | ||
) | ||
return Response("MCP SSE") | ||
|
||
# Create and run Starlette app | ||
starlette_app = Starlette(routes=routes) | ||
|
@@ -43,7 +46,7 @@ async def handle_sse(request): | |
from sse_starlette import EventSourceResponse | ||
from starlette.requests import Request | ||
from starlette.responses import Response | ||
from starlette.types import Receive, Scope, Send | ||
from starlette.types import Message, Receive, Scope, Send | ||
|
||
import mcp.types as types | ||
|
||
|
@@ -120,9 +123,19 @@ async def sse_writer(): | |
} | ||
) | ||
|
||
async def handle_see_disconnect(message: Message) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is also handled in #586. let's review that one first and then we can review this one as we have a test here -- thank you for adding it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I left my thought here And I don't think add callback to connect_sse is proper |
||
logger.debug(f"Disconnect sse {session_id}") | ||
del self._read_stream_writers[session_id] | ||
await read_stream.aclose() | ||
await read_stream_writer.aclose() | ||
await write_stream.aclose() | ||
await write_stream_reader.aclose() | ||
|
||
async with anyio.create_task_group() as tg: | ||
response = EventSourceResponse( | ||
content=sse_stream_reader, data_sender_callable=sse_writer | ||
content=sse_stream_reader, | ||
data_sender_callable=sse_writer, | ||
client_close_handler_callable=handle_see_disconnect, | ||
) | ||
logger.debug("Starting SSE response task") | ||
tg.start_soon(response, scope, receive, send) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import asyncio | ||
from uuid import UUID | ||
|
||
import pytest | ||
from starlette.types import Message, Scope | ||
|
||
from mcp.server.sse import SseServerTransport | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_sse_disconnect_handle(): | ||
transport = SseServerTransport(endpoint="/sse") | ||
# Create a minimal ASGI scope for an HTTP GET request | ||
scope: Scope = { | ||
"type": "http", | ||
"method": "GET", | ||
"path": "/sse", | ||
"headers": [], | ||
} | ||
send_disconnect = False | ||
|
||
# Dummy receive and send functions | ||
async def receive() -> dict: | ||
nonlocal send_disconnect | ||
if not send_disconnect: | ||
send_disconnect = True | ||
return {"type": "http.request"} | ||
else: | ||
return {"type": "http.disconnect"} | ||
|
||
async def send(message: Message) -> None: | ||
await asyncio.sleep(0) | ||
|
||
# Run the connect_sse context manager | ||
async with transport.connect_sse(scope, receive, send) as ( | ||
read_stream, | ||
write_stream, | ||
): | ||
# Assert that streams are provided | ||
assert read_stream is not None | ||
assert write_stream is not None | ||
|
||
# There should be exactly one session | ||
assert len(transport._read_stream_writers) == 1 | ||
# Check that the session key is a UUID | ||
session_id = next(iter(transport._read_stream_writers.keys())) | ||
assert isinstance(session_id, UUID) | ||
|
||
# Check that the writer is still open | ||
writer = transport._read_stream_writers[session_id] | ||
assert writer is not None | ||
|
||
# After context exits, session should be cleaned up | ||
assert len(transport._read_stream_writers) == 0 |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe this was already fixed by #612