diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index 93e44150..9e8efa52 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -101,6 +101,77 @@ describe("StreamableHTTPClientTransport", () => { expect(lastCall[1].headers.get("mcp-session-id")).toBe("test-session-id"); }); + it("should terminate session with DELETE request", async () => { + // First, simulate getting a session ID + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-id" + }; + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ "content-type": "text/event-stream", "mcp-session-id": "test-session-id" }), + }); + + await transport.send(message); + expect(transport.sessionId).toBe("test-session-id"); + + // Now terminate the session + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers() + }); + + await transport.terminateSession(); + + // Verify the DELETE request was sent with the session ID + const calls = (global.fetch as jest.Mock).mock.calls; + const lastCall = calls[calls.length - 1]; + expect(lastCall[1].method).toBe("DELETE"); + expect(lastCall[1].headers.get("mcp-session-id")).toBe("test-session-id"); + + // The session ID should be cleared after successful termination + expect(transport.sessionId).toBeUndefined(); + }); + + it("should handle 405 response when server doesn't support session termination", async () => { + // First, simulate getting a session ID + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-id" + }; + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ "content-type": "text/event-stream", "mcp-session-id": "test-session-id" }), + }); + + await transport.send(message); + + // Now terminate the session, but server responds with 405 + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: false, + status: 405, + statusText: "Method Not Allowed", + headers: new Headers() + }); + + await expect(transport.terminateSession()).resolves.not.toThrow(); + }); + it("should handle 404 response when session expires", async () => { const message: JSONRPCMessage = { jsonrpc: "2.0", diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 077b0f15..3462b2ab 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -1,5 +1,5 @@ import { Transport } from "../shared/transport.js"; -import { isJSONRPCNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; +import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { EventSourceParserStream } from "eventsource-parser/stream"; @@ -420,7 +420,7 @@ export class StreamableHTTPClientTransport implements Transport { if (response.status === 202) { // if the accepted notification is initialized, we start the SSE stream // if it's supported by the server - if (isJSONRPCNotification(message) && message.method === "notifications/initialized") { + if (isInitializedNotification(message)) { // Start without a lastEventId since this is a fresh connection this._startOrAuthSse({ resumptionToken: undefined }).catch(err => this.onerror?.(err)); } @@ -467,4 +467,48 @@ export class StreamableHTTPClientTransport implements Transport { get sessionId(): string | undefined { return this._sessionId; } + + /** + * Terminates the current session by sending a DELETE request to the server. + * + * Clients that no longer need a particular session + * (e.g., because the user is leaving the client application) SHOULD send an + * HTTP DELETE to the MCP endpoint with the Mcp-Session-Id header to explicitly + * terminate the session. + * + * The server MAY respond with HTTP 405 Method Not Allowed, indicating that + * the server does not allow clients to terminate sessions. + */ + async terminateSession(): Promise { + if (!this._sessionId) { + return; // No session to terminate + } + + try { + const headers = await this._commonHeaders(); + + const init = { + ...this._requestInit, + method: "DELETE", + headers, + signal: this._abortController?.signal, + }; + + const response = await fetch(this._url, init); + + // We specifically handle 405 as a valid response according to the spec, + // meaning the server does not support explicit session termination + if (!response.ok && response.status !== 405) { + throw new StreamableHTTPError( + response.status, + `Failed to terminate session: ${response.statusText}` + ); + } + + this._sessionId = undefined; + } catch (error) { + this.onerror?.(error as Error); + throw error; + } + } } diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index d0d5408b..c1501a57 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -48,6 +48,7 @@ function printHelp(): void { console.log('\nAvailable commands:'); console.log(' connect [url] - Connect to MCP server (default: http://localhost:3000/mcp)'); console.log(' disconnect - Disconnect from server'); + console.log(' terminate-session - Terminate the current session'); console.log(' reconnect - Reconnect to the server'); console.log(' list-tools - List available tools'); console.log(' call-tool [args] - Call a tool with optional JSON arguments'); @@ -76,6 +77,10 @@ function commandLoop(): void { await disconnect(); break; + case 'terminate-session': + await terminateSession(); + break; + case 'reconnect': await reconnect(); break; @@ -249,6 +254,36 @@ async function disconnect(): Promise { } } +async function terminateSession(): Promise { + if (!client || !transport) { + console.log('Not connected.'); + return; + } + + try { + console.log('Terminating session with ID:', transport.sessionId); + await transport.terminateSession(); + console.log('Session terminated successfully'); + + // Check if sessionId was cleared after termination + if (!transport.sessionId) { + console.log('Session ID has been cleared'); + sessionId = undefined; + + // Also close the transport and clear client objects + await transport.close(); + console.log('Transport closed after session termination'); + client = null; + transport = null; + } else { + console.log('Server responded with 405 Method Not Allowed (session termination not supported)'); + console.log('Session ID is still active:', transport.sessionId); + } + } catch (error) { + console.error('Error terminating session:', error); + } +} + async function reconnect(): Promise { if (client) { await disconnect(); @@ -411,13 +446,24 @@ async function listResources(): Promise { async function cleanup(): Promise { if (client && transport) { try { + // First try to terminate the session gracefully + if (transport.sessionId) { + try { + console.log('Terminating session before exit...'); + await transport.terminateSession(); + console.log('Session terminated successfully'); + } catch (error) { + console.error('Error terminating session:', error); + } + } + + // Then close the transport await transport.close(); } catch (error) { console.error('Error closing transport:', error); } } - process.stdin.setRawMode(false); readline.close(); console.log('\nGoodbye!'); diff --git a/src/examples/server/jsonResponseStreamableHttp.ts b/src/examples/server/jsonResponseStreamableHttp.ts index 34ab65d1..101a581f 100644 --- a/src/examples/server/jsonResponseStreamableHttp.ts +++ b/src/examples/server/jsonResponseStreamableHttp.ts @@ -3,7 +3,7 @@ import { randomUUID } from 'node:crypto'; import { McpServer } from '../../server/mcp.js'; import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; import { z } from 'zod'; -import { CallToolResult } from '../../types.js'; +import { CallToolResult, isInitializeRequest } from '../../types.js'; // Create an MCP server with implementation details const server = new McpServer({ @@ -95,18 +95,17 @@ app.post('/mcp', async (req: Request, res: Response) => { transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), enableJsonResponse: true, // Enable JSON response mode + onsessioninitialized: (sessionId) => { + // Store the transport by session ID when session is initialized + // This avoids race conditions where requests might come in before the session is stored + console.log(`Session initialized with ID: ${sessionId}`); + transports[sessionId] = transport; + } }); // Connect the transport to the MCP server BEFORE handling the request await server.connect(transport); - - // After handling the request, if we get a session ID back, store the transport await transport.handleRequest(req, res, req.body); - - // Store the transport by session ID for future requests - if (transport.sessionId) { - transports[transport.sessionId] = transport; - } return; // Already handled } else { // Invalid request - no session ID or not initialization request @@ -145,14 +144,6 @@ app.get('/mcp', async (req: Request, res: Response) => { res.status(405).set('Allow', 'POST').send('Method Not Allowed'); }); -// Helper function to detect initialize requests -function isInitializeRequest(body: unknown): boolean { - if (Array.isArray(body)) { - return body.some(msg => typeof msg === 'object' && msg !== null && 'method' in msg && msg.method === 'initialize'); - } - return typeof body === 'object' && body !== null && 'method' in body && body.method === 'initialize'; -} - // Start the server const PORT = 3000; app.listen(PORT, () => { diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 98333730..0ae0f910 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -3,7 +3,7 @@ import { randomUUID } from 'node:crypto'; import { McpServer } from '../../server/mcp.js'; import { EventStore, StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; import { z } from 'zod'; -import { CallToolResult, GetPromptResult, JSONRPCMessage, ReadResourceResult } from '../../types.js'; +import { CallToolResult, GetPromptResult, isInitializeRequest, JSONRPCMessage, ReadResourceResult } from '../../types.js'; // Create a simple in-memory EventStore for resumability class InMemoryEventStore implements EventStore { @@ -36,7 +36,7 @@ class InMemoryEventStore implements EventStore { * Replays events that occurred after a specific event ID * Implements EventStore.replayEventsAfter */ - async replayEventsAfter(lastEventId: string, + async replayEventsAfter(lastEventId: string, { send }: { send: (eventId: string, message: JSONRPCMessage) => Promise } ): Promise { if (!lastEventId || !this.events.has(lastEventId)) { @@ -247,19 +247,28 @@ app.post('/mcp', async (req: Request, res: Response) => { transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore, // Enable resumability + onsessioninitialized: (sessionId) => { + // Store the transport by session ID when session is initialized + // This avoids race conditions where requests might come in before the session is stored + console.log(`Session initialized with ID: ${sessionId}`); + transports[sessionId] = transport; + } }); + // Set up onclose handler to clean up transport when closed + transport.onclose = () => { + const sid = transport.sessionId; + if (sid && transports[sid]) { + console.log(`Transport closed for session ${sid}, removing from transports map`); + delete transports[sid]; + } + }; + // Connect the transport to the MCP server BEFORE handling the request // so responses can flow back through the same transport await server.connect(transport); - // After handling the request, if we get a session ID back, store the transport await transport.handleRequest(req, res, req.body); - - // Store the transport by session ID for future requests - if (transport.sessionId) { - transports[transport.sessionId] = transport; - } return; // Already handled } else { // Invalid request - no session ID or not initialization request @@ -312,13 +321,26 @@ app.get('/mcp', async (req: Request, res: Response) => { await transport.handleRequest(req, res); }); -// Helper function to detect initialize requests -function isInitializeRequest(body: unknown): boolean { - if (Array.isArray(body)) { - return body.some(msg => typeof msg === 'object' && msg !== null && 'method' in msg && msg.method === 'initialize'); +// Handle DELETE requests for session termination (according to MCP spec) +app.delete('/mcp', async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (!sessionId || !transports[sessionId]) { + res.status(400).send('Invalid or missing session ID'); + return; } - return typeof body === 'object' && body !== null && 'method' in body && body.method === 'initialize'; -} + + console.log(`Received session termination request for session ${sessionId}`); + + try { + const transport = transports[sessionId]; + await transport.handleRequest(req, res); + } catch (error) { + console.error('Error handling session termination:', error); + if (!res.headersSent) { + res.status(500).send('Error processing session termination'); + } + } +}); // Start the server const PORT = 3000; @@ -351,6 +373,18 @@ app.listen(PORT, () => { // Handle server shutdown process.on('SIGINT', async () => { console.log('Shutting down server...'); + + // Close all active transports to properly clean up resources + for (const sessionId in transports) { + try { + console.log(`Closing transport for session ${sessionId}`); + await transports[sessionId].close(); + delete transports[sessionId]; + } catch (error) { + console.error(`Error closing transport for session ${sessionId}:`, error); + } + } await server.close(); + console.log('Server shutdown complete'); process.exit(0); }); \ No newline at end of file diff --git a/src/examples/server/standaloneSseWithGetStreamableHttp.ts b/src/examples/server/standaloneSseWithGetStreamableHttp.ts index f9d3696b..8c8c3baa 100644 --- a/src/examples/server/standaloneSseWithGetStreamableHttp.ts +++ b/src/examples/server/standaloneSseWithGetStreamableHttp.ts @@ -2,7 +2,7 @@ import express, { Request, Response } from 'express'; import { randomUUID } from 'node:crypto'; import { McpServer } from '../../server/mcp.js'; import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; -import { ReadResourceResult } from '../../types.js'; +import { isInitializeRequest, ReadResourceResult } from '../../types.js'; // Create an MCP server with implementation details const server = new McpServer({ @@ -52,17 +52,19 @@ app.post('/mcp', async (req: Request, res: Response) => { // New initialization request transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), + onsessioninitialized: (sessionId) => { + // Store the transport by session ID when session is initialized + // This avoids race conditions where requests might come in before the session is stored + console.log(`Session initialized with ID: ${sessionId}`); + transports[sessionId] = transport; + } }); // Connect the transport to the MCP server await server.connect(transport); + // Handle the request - the onsessioninitialized callback will store the transport await transport.handleRequest(req, res, req.body); - - // Store the transport by session ID for future requests - if (transport.sessionId) { - transports[transport.sessionId] = transport; - } return; // Already handled } else { // Invalid request - no session ID or not initialization request @@ -107,13 +109,6 @@ app.get('/mcp', async (req: Request, res: Response) => { await transport.handleRequest(req, res); }); -// Helper function to detect initialize requests -function isInitializeRequest(body: unknown): boolean { - if (Array.isArray(body)) { - return body.some(msg => typeof msg === 'object' && msg !== null && 'method' in msg && msg.method === 'initialize'); - } - return typeof body === 'object' && body !== null && 'method' in body && body.method === 'initialize'; -} // Start the server const PORT = 3000; diff --git a/src/integration-tests/stateManagementStreamableHttp.test.ts b/src/integration-tests/stateManagementStreamableHttp.test.ts new file mode 100644 index 00000000..1e80b7b8 --- /dev/null +++ b/src/integration-tests/stateManagementStreamableHttp.test.ts @@ -0,0 +1,265 @@ +import { createServer, type Server } from 'node:http'; +import { AddressInfo } from 'node:net'; +import { randomUUID } from 'node:crypto'; +import { Client } from '../client/index.js'; +import { StreamableHTTPClientTransport } from '../client/streamableHttp.js'; +import { McpServer } from '../server/mcp.js'; +import { StreamableHTTPServerTransport } from '../server/streamableHttp.js'; +import { CallToolResultSchema, ListToolsResultSchema, ListResourcesResultSchema, ListPromptsResultSchema } from '../types.js'; +import { z } from 'zod'; + +describe('Streamable HTTP Transport Session Management', () => { + // Function to set up the server with optional session management + async function setupServer(withSessionManagement: boolean) { + const server: Server = createServer(); + const mcpServer = new McpServer( + { name: 'test-server', version: '1.0.0' }, + { + capabilities: { + logging: {}, + tools: {}, + resources: {}, + prompts: {} + } + } + ); + + // Add a simple resource + mcpServer.resource( + 'test-resource', + '/test', + { description: 'A test resource' }, + async () => ({ + contents: [{ + uri: '/test', + text: 'This is a test resource content' + }] + }) + ); + + mcpServer.prompt( + 'test-prompt', + 'A test prompt', + async () => ({ + messages: [{ + role: 'user', + content: { + type: 'text', + text: 'This is a test prompt' + } + }] + }) + ); + + mcpServer.tool( + 'greet', + 'A simple greeting tool', + { + name: z.string().describe('Name to greet').default('World'), + }, + async ({ name }) => { + return { + content: [{ type: 'text', text: `Hello, ${name}!` }] + }; + } + ); + + // Create transport with or without session management + const serverTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: withSessionManagement + ? () => randomUUID() // With session management, generate UUID + : () => undefined // Without session management, return undefined + }); + + await mcpServer.connect(serverTransport); + + server.on('request', async (req, res) => { + await serverTransport.handleRequest(req, res); + }); + + // Start the server on a random port + const baseUrl = await new Promise((resolve) => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); + + return { server, mcpServer, serverTransport, baseUrl }; + } + + describe('Stateless Mode', () => { + let server: Server; + let mcpServer: McpServer; + let serverTransport: StreamableHTTPServerTransport; + let baseUrl: URL; + + beforeEach(async () => { + const setup = await setupServer(false); + server = setup.server; + mcpServer = setup.mcpServer; + serverTransport = setup.serverTransport; + baseUrl = setup.baseUrl; + }); + + afterEach(async () => { + // Clean up resources + await mcpServer.close().catch(() => { }); + await serverTransport.close().catch(() => { }); + server.close(); + }); + + it('should operate without session management', async () => { + // Create and connect a client + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Verify that no session ID was set + expect(transport.sessionId).toBeUndefined(); + + // List available tools + const toolsResult = await client.request({ + method: 'tools/list', + params: {} + }, ListToolsResultSchema); + + // Verify tools are accessible + expect(toolsResult.tools).toContainEqual(expect.objectContaining({ + name: 'greet' + })); + + // List available resources + const resourcesResult = await client.request({ + method: 'resources/list', + params: {} + }, ListResourcesResultSchema); + + // Verify resources result structure + expect(resourcesResult).toHaveProperty('resources'); + + // List available prompts + const promptsResult = await client.request({ + method: 'prompts/list', + params: {} + }, ListPromptsResultSchema); + + // Verify prompts result structure + expect(promptsResult).toHaveProperty('prompts'); + expect(promptsResult.prompts).toContainEqual(expect.objectContaining({ + name: 'test-prompt' + })); + + // Call the greeting tool + const greetingResult = await client.request({ + method: 'tools/call', + params: { + name: 'greet', + arguments: { + name: 'Stateless Transport' + } + } + }, CallToolResultSchema); + + // Verify tool result + expect(greetingResult.content).toEqual([ + { type: 'text', text: 'Hello, Stateless Transport!' } + ]); + + // Clean up + await transport.close(); + }); + }); + + describe('Stateful Mode', () => { + let server: Server; + let mcpServer: McpServer; + let serverTransport: StreamableHTTPServerTransport; + let baseUrl: URL; + + beforeEach(async () => { + const setup = await setupServer(true); + server = setup.server; + mcpServer = setup.mcpServer; + serverTransport = setup.serverTransport; + baseUrl = setup.baseUrl; + }); + + afterEach(async () => { + // Clean up resources + await mcpServer.close().catch(() => { }); + await serverTransport.close().catch(() => { }); + server.close(); + }); + + it('should operate with session management', async () => { + // Create and connect a client + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Verify that a session ID was set + expect(transport.sessionId).toBeDefined(); + expect(typeof transport.sessionId).toBe('string'); + + // List available tools + const toolsResult = await client.request({ + method: 'tools/list', + params: {} + }, ListToolsResultSchema); + + // Verify tools are accessible + expect(toolsResult.tools).toContainEqual(expect.objectContaining({ + name: 'greet' + })); + + // List available resources + const resourcesResult = await client.request({ + method: 'resources/list', + params: {} + }, ListResourcesResultSchema); + + // Verify resources result structure + expect(resourcesResult).toHaveProperty('resources'); + + // List available prompts + const promptsResult = await client.request({ + method: 'prompts/list', + params: {} + }, ListPromptsResultSchema); + + // Verify prompts result structure + expect(promptsResult).toHaveProperty('prompts'); + expect(promptsResult.prompts).toContainEqual(expect.objectContaining({ + name: 'test-prompt' + })); + + // Call the greeting tool + const greetingResult = await client.request({ + method: 'tools/call', + params: { + name: 'greet', + arguments: { + name: 'Stateful Transport' + } + } + }, CallToolResultSchema); + + // Verify tool result + expect(greetingResult.content).toEqual([ + { type: 'text', text: 'Hello, Stateful Transport!' } + ]); + + // Clean up + await transport.close(); + }); + }); +}); \ No newline at end of file diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index efd5de1c..0794e4bb 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -187,15 +187,11 @@ describe("StreamableHTTPServerTransport", () => { expect(sessionId).toBeDefined(); // Try second initialize - const secondInitMessage: JSONRPCMessage = { - jsonrpc: "2.0", - method: "initialize", - params: { - clientInfo: { name: "test-client-2", version: "1.0" }, - protocolVersion: "2025-03-26", - }, - id: "init-2", + const secondInitMessage = { + ...TEST_MESSAGES.initialize, + id: "second-init" }; + const response = await sendPostRequest(baseUrl, secondInitMessage); expect(response.status).toBe(400); @@ -1092,14 +1088,7 @@ describe("StreamableHTTPServerTransport in stateless mode", () => { }); it("should handle POST requests with various session IDs in stateless mode", async () => { - // Initialize the server first - await fetch(baseUrl, { - method: "POST", - headers: { "Content-Type": "application/json", Accept: "application/json, text/event-stream" }, - body: JSON.stringify({ - jsonrpc: "2.0", method: "initialize", params: { clientInfo: { name: "test-client", version: "1.0" }, protocolVersion: "2025-03-26" }, id: "init-1" - }), - }); + await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); // Try with a random session ID - should be accepted const response1 = await fetch(baseUrl, { @@ -1131,13 +1120,7 @@ describe("StreamableHTTPServerTransport in stateless mode", () => { // one standalone SSE stream at a time // Initialize the server first - await fetch(baseUrl, { - method: "POST", - headers: { "Content-Type": "application/json", Accept: "application/json, text/event-stream" }, - body: JSON.stringify({ - jsonrpc: "2.0", method: "initialize", params: { clientInfo: { name: "test-client", version: "1.0" }, protocolVersion: "2025-03-26" }, id: "init-1" - }), - }); + await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); // Open first SSE stream const stream1 = await fetch(baseUrl, { diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 31aad09c..ed52eb77 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -1,6 +1,6 @@ import { IncomingMessage, ServerResponse } from "node:http"; import { Transport } from "../shared/transport.js"; -import { isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema, RequestId } from "../types.js"; +import { isInitializeRequest, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema, RequestId } from "../types.js"; import getRawBody from "raw-body"; import contentType from "content-type"; import { randomUUID } from "node:crypto"; @@ -39,6 +39,15 @@ export interface StreamableHTTPServerTransportOptions { */ sessionIdGenerator: () => string | undefined; + /** + * A callback for session initialization events + * This is called when the server initializes a new session. + * Useful in cases when you need to register multiple mcp sessions + * and need to keep track of them. + * @param sessionId The generated session ID + */ + onsessioninitialized?: (sessionId: string) => void; + /** * If true, the server will return JSON responses instead of starting an SSE stream. * This can be useful for simple request/response scenarios without streaming. @@ -98,6 +107,7 @@ export class StreamableHTTPServerTransport implements Transport { private _enableJsonResponse: boolean = false; private _standaloneSseStreamId: string = '_GET_stream'; private _eventStore?: EventStore; + private _onsessioninitialized?: (sessionId: string) => void; sessionId?: string | undefined; onclose?: () => void; @@ -108,6 +118,7 @@ export class StreamableHTTPServerTransport implements Transport { this.sessionIdGenerator = options.sessionIdGenerator; this._enableJsonResponse = options.enableJsonResponse ?? false; this._eventStore = options.eventStore; + this._onsessioninitialized = options.onsessioninitialized; } /** @@ -328,9 +339,7 @@ export class StreamableHTTPServerTransport implements Transport { // Check if this is an initialization request // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/ - const isInitializationRequest = messages.some( - msg => 'method' in msg && msg.method === 'initialize' - ); + const isInitializationRequest = messages.some(isInitializeRequest); if (isInitializationRequest) { // If it's a server with session management and the session ID is already set we should reject the request // to avoid re-initialization. @@ -359,6 +368,12 @@ export class StreamableHTTPServerTransport implements Transport { this.sessionId = this.sessionIdGenerator(); this._initialized = true; + // If we have a session ID and an onsessioninitialized handler, call it immediately + // This is needed in cases where the server needs to keep track of multiple sessions + if (this.sessionId && this._onsessioninitialized) { + this._onsessioninitialized(this.sessionId); + } + } // If an Mcp-Session-Id is returned by the server during initialization, // clients using the Streamable HTTP transport MUST include it @@ -400,7 +415,7 @@ export class StreamableHTTPServerTransport implements Transport { // Store the response for this request to send messages back through this connection // We need to track by request ID to maintain the connection for (const message of messages) { - if ('method' in message && 'id' in message) { + if (isJSONRPCRequest(message)) { this._streamMapping.set(streamId, res); this._requestToStreamMapping.set(message.id, streamId); } @@ -520,7 +535,7 @@ export class StreamableHTTPServerTransport implements Transport { async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise { let requestId = options?.relatedRequestId; - if ('result' in message || 'error' in message) { + if (isJSONRPCResponse(message) || isJSONRPCError(message)) { // If the message is a response, use the request ID from the message requestId = message.id; } @@ -530,7 +545,7 @@ export class StreamableHTTPServerTransport implements Transport { // Those will be sent via dedicated response SSE streams if (requestId === undefined) { // For standalone SSE streams, we can only send requests and notifications - if ('result' in message || 'error' in message) { + if (isJSONRPCResponse(message) || isJSONRPCError(message)) { throw new Error("Cannot send a response on a standalone SSE stream unless resuming a previous client request"); } const standaloneSse = this._streamMapping.get(this._standaloneSseStreamId) diff --git a/src/types.ts b/src/types.ts index db6bf125..8ac41372 100644 --- a/src/types.ts +++ b/src/types.ts @@ -248,6 +248,10 @@ export const InitializeRequestSchema = RequestSchema.extend({ }), }); +export const isInitializeRequest = (value: unknown): value is InitializeRequest => + InitializeRequestSchema.safeParse(value).success; + + /** * Capabilities that a server may support. Known capabilities are defined here, in this schema, but this is not a closed set: any server can define its own, additional capabilities. */ @@ -337,6 +341,9 @@ export const InitializedNotificationSchema = NotificationSchema.extend({ method: z.literal("notifications/initialized"), }); +export const isInitializedNotification = (value: unknown): value is InitializedNotification => + InitializedNotificationSchema.safeParse(value).success; + /* Ping */ /** * A ping, issued by either the server or the client, to check that the other party is still alive. The receiver must promptly respond, or else may be disconnected.