diff --git a/src/client/index.test.ts b/src/client/index.test.ts index a1b43b2b..36dd6518 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -66,6 +66,9 @@ test("should initialize with matching protocol version", async () => { protocolVersion: LATEST_PROTOCOL_VERSION, }), }), + expect.objectContaining({ + relatedRequestId: undefined, + }), ); // Should have the instructions returned diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 2e91a568..f33c669f 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -11,6 +11,7 @@ import { ListPromptsResultSchema, GetPromptResultSchema, CompleteResultSchema, + LoggingMessageNotificationSchema, } from "../types.js"; import { ResourceTemplate } from "./mcp.js"; import { completable } from "./completable.js"; @@ -85,6 +86,8 @@ describe("ResourceTemplate", () => { const abortController = new AbortController(); const result = await template.listCallback?.({ signal: abortController.signal, + sendRequest: () => { throw new Error("Not implemented") }, + sendNotification: () => { throw new Error("Not implemented") } }); expect(result?.resources).toHaveLength(1); expect(list).toHaveBeenCalled(); @@ -318,7 +321,7 @@ describe("tool()", () => { // This should succeed mcpServer.tool("tool1", () => ({ content: [] })); - + // This should also succeed and not throw about request handlers mcpServer.tool("tool2", () => ({ content: [] })); }); @@ -376,6 +379,63 @@ describe("tool()", () => { expect(receivedSessionId).toBe("test-session-123"); }); + test("should provide sendNotification within tool call", async () => { + const mcpServer = new McpServer( + { + name: "test server", + version: "1.0", + }, + { capabilities: { logging: {} } }, + ); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + let receivedLogMessage: string | undefined; + const loggingMessage = "hello here is log message 1"; + + client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { + receivedLogMessage = notification.params.data as string; + }); + + mcpServer.tool("test-tool", async ({ sendNotification }) => { + await sendNotification({ method: "notifications/message", params: { level: "debug", data: loggingMessage } }); + return { + content: [ + { + type: "text", + text: "Test response", + }, + ], + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + await client.request( + { + method: "tools/call", + params: { + name: "test-tool", + }, + }, + CallToolResultSchema, + ); + expect(receivedLogMessage).toBe(loggingMessage); + }); + test("should allow client to call server tools", async () => { const mcpServer = new McpServer({ name: "test server", @@ -815,7 +875,7 @@ describe("resource()", () => { }, ], })); - + // This should also succeed and not throw about request handlers mcpServer.resource("resource2", "test://resource2", async () => ({ contents: [ @@ -1321,7 +1381,7 @@ describe("prompt()", () => { }, ], })); - + // This should also succeed and not throw about request handlers mcpServer.prompt("prompt2", async () => ({ messages: [ diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 8f4a909c..484084fc 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -37,6 +37,8 @@ import { PromptArgument, GetPromptResult, ReadResourceResult, + ServerRequest, + ServerNotification, } from "../types.js"; import { Completable, CompletableDef } from "./completable.js"; import { UriTemplate, Variables } from "../shared/uriTemplate.js"; @@ -694,9 +696,9 @@ export type ToolCallback = Args extends ZodRawShape ? ( args: z.objectOutputType, - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => CallToolResult | Promise - : (extra: RequestHandlerExtra) => CallToolResult | Promise; + : (extra: RequestHandlerExtra) => CallToolResult | Promise; type RegisteredTool = { description?: string; @@ -717,7 +719,7 @@ export type ResourceMetadata = Omit; * Callback to list all resources matching a given template. */ export type ListResourcesCallback = ( - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => ListResourcesResult | Promise; /** @@ -725,7 +727,7 @@ export type ListResourcesCallback = ( */ export type ReadResourceCallback = ( uri: URL, - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => ReadResourceResult | Promise; type RegisteredResource = { @@ -740,7 +742,7 @@ type RegisteredResource = { export type ReadResourceTemplateCallback = ( uri: URL, variables: Variables, - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => ReadResourceResult | Promise; type RegisteredResourceTemplate = { @@ -760,9 +762,9 @@ export type PromptCallback< > = Args extends PromptArgsRawShape ? ( args: z.objectOutputType, - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => GetPromptResult | Promise - : (extra: RequestHandlerExtra) => GetPromptResult | Promise; + : (extra: RequestHandlerExtra) => GetPromptResult | Promise; type RegisteredPrompt = { description?: string; diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts new file mode 100644 index 00000000..aff9e511 --- /dev/null +++ b/src/server/streamableHttp.test.ts @@ -0,0 +1,1224 @@ +import { IncomingMessage, ServerResponse } from "node:http"; +import { StreamableHTTPServerTransport } from "./streamableHttp.js"; +import { JSONRPCMessage } from "../types.js"; +import { Readable } from "node:stream"; +import { randomUUID } from "node:crypto"; +// Mock IncomingMessage +function createMockRequest(options: { + method: string; + headers: Record; + body?: string; +}): IncomingMessage { + const readable = new Readable(); + readable._read = () => { }; + if (options.body) { + readable.push(options.body); + readable.push(null); + } + + return Object.assign(readable, { + method: options.method, + headers: options.headers, + }) as IncomingMessage; +} + +// Mock ServerResponse +function createMockResponse(): jest.Mocked { + const response = { + writeHead: jest.fn().mockReturnThis(), + write: jest.fn().mockReturnThis(), + end: jest.fn().mockReturnThis(), + on: jest.fn().mockReturnThis(), + emit: jest.fn().mockReturnThis(), + getHeader: jest.fn(), + setHeader: jest.fn(), + } as unknown as jest.Mocked; + return response; +} + +describe("StreamableHTTPServerTransport", () => { + let transport: StreamableHTTPServerTransport; + let mockResponse: jest.Mocked; + let mockRequest: string; + + beforeEach(() => { + transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + }); + mockResponse = createMockResponse(); + mockRequest = JSON.stringify({ + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + describe("Session Management", () => { + it("should generate session ID during initialization", async () => { + const initializeMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initializeMessage), + }); + + expect(transport.sessionId).toBeUndefined(); + expect(transport["_initialized"]).toBe(false); + + await transport.handleRequest(req, mockResponse); + + expect(transport.sessionId).toBeDefined(); + expect(transport["_initialized"]).toBe(true); + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "mcp-session-id": transport.sessionId, + }) + ); + }); + + it("should reject second initialization request", async () => { + // First initialize + const initMessage1: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const req1 = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage1), + }); + + await transport.handleRequest(req1, mockResponse); + expect(transport["_initialized"]).toBe(true); + + // Reset mock for second request + mockResponse.writeHead.mockClear(); + mockResponse.end.mockClear(); + + // Try second initialize + const initMessage2: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-2", + }; + + const req2 = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage2), + }); + + await transport.handleRequest(req2, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Invalid Request: Server already initialized"')); + }); + + it("should reject batch initialize request", async () => { + const batchInitialize: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }, + { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client-2", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-2", + } + ]; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(batchInitialize), + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Invalid Request: Only one initialization request is allowed"')); + }); + + it("should reject invalid session ID", async () => { + // First initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + + // Now try with an invalid session ID + const req = createMockRequest({ + method: "POST", + headers: { + "mcp-session-id": "invalid-session-id", + "accept": "application/json, text/event-stream", + "content-type": "application/json", + }, + body: mockRequest, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(404); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Session not found"')); + }); + + it("should reject non-initialization requests without session ID with 400 Bad Request", async () => { + // First initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + + // Now try without session ID + const req = createMockRequest({ + method: "POST", + headers: { + "accept": "application/json, text/event-stream", + "content-type": "application/json", + // No mcp-session-id header + }, + body: mockRequest + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Bad Request: Mcp-Session-Id header is required"')); + }); + + it("should reject requests to uninitialized server", async () => { + // Create a new transport that hasn't been initialized + const uninitializedTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + }); + + const req = createMockRequest({ + method: "POST", + headers: { + "accept": "application/json, text/event-stream", + "content-type": "application/json", + "mcp-session-id": "any-session-id", + }, + body: mockRequest + }); + + await uninitializedTransport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Bad Request: Server not initialized"')); + }); + + it("should reject session ID as array", async () => { + // First initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + + // Now try with an array session ID + const req = createMockRequest({ + method: "POST", + headers: { + "mcp-session-id": ["session1", "session2"], + "accept": "application/json, text/event-stream", + "content-type": "application/json", + }, + body: mockRequest, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Bad Request: Mcp-Session-Id header must be a single value"')); + }); + }); + describe("Mode without state management", () => { + let transportWithoutState: StreamableHTTPServerTransport; + let mockResponse: jest.Mocked; + + beforeEach(async () => { + transportWithoutState = new StreamableHTTPServerTransport({ sessionIdGenerator: () => undefined }); + mockResponse = createMockResponse(); + + // Initialize the transport for each test + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transportWithoutState.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + }); + + it("should not include session ID in response headers when in mode without state management", async () => { + // Use a non-initialization request + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(message), + }); + + await transportWithoutState.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalled(); + // Extract the headers from writeHead call + const headers = mockResponse.writeHead.mock.calls[0][1]; + expect(headers).not.toHaveProperty("mcp-session-id"); + }); + + it("should not validate session ID in mode without state management", async () => { + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": "invalid-session-id", // This would cause a 404 in mode with state management + }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1 + }), + }); + + await transportWithoutState.handleRequest(req, mockResponse); + + // Should still get 200 OK, not 404 Not Found + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.not.objectContaining({ + "mcp-session-id": expect.anything(), + }) + ); + }); + + it("should handle POST requests without session validation in mode without state management", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": "non-existent-session-id", // This would be rejected in mode with state management + }, + body: JSON.stringify(message), + }); + + const onMessageMock = jest.fn(); + transportWithoutState.onmessage = onMessageMock; + + await transportWithoutState.handleRequest(req, mockResponse); + + // Message should be processed despite invalid session ID + expect(onMessageMock).toHaveBeenCalledWith(message); + }); + + it("should work with a mix of requests with and without session IDs in mode without state management", async () => { + // First request without session ID + const req1 = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + accept: "application/json, text/event-stream", + }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id" + }) + }); + + await transportWithoutState.handleRequest(req1, mockResponse); + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + }) + ); + + // Reset mock for second request + mockResponse.writeHead.mockClear(); + + // Second request with a session ID (which would be invalid in mode with state management) + const req2 = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + accept: "application/json, text/event-stream", + "mcp-session-id": "some-random-session-id", + }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "test2", + params: {}, + id: "test-id-2" + }) + }); + + await transportWithoutState.handleRequest(req2, mockResponse); + + // Should still succeed + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + }) + ); + }); + + it("should handle initialization in mode without state management", async () => { + const transportWithoutState = new StreamableHTTPServerTransport({ sessionIdGenerator: () => undefined }); + + // Initialize message + const initializeMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + expect(transportWithoutState.sessionId).toBeUndefined(); + expect(transportWithoutState["_initialized"]).toBe(false); + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initializeMessage), + }); + + const newResponse = createMockResponse(); + await transportWithoutState.handleRequest(req, newResponse); + + // After initialization, the sessionId should still be undefined + expect(transportWithoutState.sessionId).toBeUndefined(); + expect(transportWithoutState["_initialized"]).toBe(true); + + // Headers should NOT include session ID in mode without state management + const headers = newResponse.writeHead.mock.calls[0][1]; + expect(headers).not.toHaveProperty("mcp-session-id"); + }); + }); + + describe("Request Handling", () => { + // Initialize the transport before tests that need initialization + beforeEach(async () => { + // For tests that need initialization, initialize here + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + }); + + it("should reject GET requests for SSE with 405 Method Not Allowed", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId, + }, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(405, expect.objectContaining({ + "Allow": "POST, DELETE" + })); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('Method not allowed')); + }); + + it("should reject POST requests without proper Accept header", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "mcp-session-id": transport.sessionId, + }, + body: JSON.stringify(message), + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(406); + }); + + it("should properly handle JSON-RPC request messages in POST requests", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId, + }, + body: JSON.stringify(message), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + await transport.handleRequest(req, mockResponse); + + expect(onMessageMock).toHaveBeenCalledWith(message); + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + }) + ); + }); + + it("should properly handle JSON-RPC notification or response messages in POST requests", async () => { + const notification: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId, + }, + body: JSON.stringify(notification), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + await transport.handleRequest(req, mockResponse); + + expect(onMessageMock).toHaveBeenCalledWith(notification); + expect(mockResponse.writeHead).toHaveBeenCalledWith(202); + }); + + it("should handle batch notification messages properly with 202 response", async () => { + const batchMessages: JSONRPCMessage[] = [ + { jsonrpc: "2.0", method: "test1", params: {} }, + { jsonrpc: "2.0", method: "test2", params: {} }, + ]; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId, + }, + body: JSON.stringify(batchMessages), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + await transport.handleRequest(req, mockResponse); + + expect(onMessageMock).toHaveBeenCalledTimes(2); + expect(mockResponse.writeHead).toHaveBeenCalledWith(202); + }); + + it("should handle batch request messages with SSE when Accept header includes text/event-stream", async () => { + const batchMessages: JSONRPCMessage[] = [ + { jsonrpc: "2.0", method: "test1", params: {}, id: "req1" }, + { jsonrpc: "2.0", method: "test2", params: {}, id: "req2" }, + ]; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "text/event-stream, application/json", + "mcp-session-id": transport.sessionId, + }, + body: JSON.stringify(batchMessages), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + mockResponse = createMockResponse(); // Create fresh mock + await transport.handleRequest(req, mockResponse); + + // Should establish SSE connection + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream" + }) + ); + expect(onMessageMock).toHaveBeenCalledTimes(2); + // Stream should remain open until responses are sent + expect(mockResponse.end).not.toHaveBeenCalled(); + }); + + it("should reject unsupported Content-Type", async () => { + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "text/plain", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId, + }, + body: "test", + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(415); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + }); + + it("should properly handle DELETE requests and close session", async () => { + // First initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + + // Now try DELETE with proper session ID + const req = createMockRequest({ + method: "DELETE", + headers: { + "mcp-session-id": transport.sessionId, + }, + }); + + const onCloseMock = jest.fn(); + transport.onclose = onCloseMock; + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(200); + expect(onCloseMock).toHaveBeenCalled(); + }); + + it("should reject DELETE requests with invalid session ID", async () => { + // First initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + + // Now try DELETE with invalid session ID + const req = createMockRequest({ + method: "DELETE", + headers: { + "mcp-session-id": "invalid-session-id", + }, + }); + + const onCloseMock = jest.fn(); + transport.onclose = onCloseMock; + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(404); + expect(onCloseMock).not.toHaveBeenCalled(); + }); + }); + + describe("SSE Response Handling", () => { + beforeEach(async () => { + // Initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + }); + + it("should send response messages as SSE events", async () => { + // Setup a POST request with JSON-RPC request that accepts SSE + const requestMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-req-id" + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId + }, + body: JSON.stringify(requestMessage) + }); + + await transport.handleRequest(req, mockResponse); + + // Send a response to the request + const responseMessage: JSONRPCMessage = { + jsonrpc: "2.0", + result: { value: "test-result" }, + id: "test-req-id" + }; + + await transport.send(responseMessage, { relatedRequestId: "test-req-id" }); + + // Verify response was sent as SSE event + expect(mockResponse.write).toHaveBeenCalledWith( + expect.stringContaining(`event: message\ndata: ${JSON.stringify(responseMessage)}\n\n`) + ); + + // Stream should be closed after sending response + expect(mockResponse.end).toHaveBeenCalled(); + }); + + it("should keep stream open when sending intermediate notifications and requests", async () => { + // Setup a POST request with JSON-RPC request that accepts SSE + const requestMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-req-id" + }; + + // Create fresh response for this test + mockResponse = createMockResponse(); + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId + }, + body: JSON.stringify(requestMessage) + }); + + await transport.handleRequest(req, mockResponse); + + // Send an intermediate notification + const notification: JSONRPCMessage = { + jsonrpc: "2.0", + method: "progress", + params: { progress: "50%" } + }; + + await transport.send(notification, { relatedRequestId: "test-req-id" }); + + // Stream should remain open + expect(mockResponse.end).not.toHaveBeenCalled(); + + // Send the final response + const responseMessage: JSONRPCMessage = { + jsonrpc: "2.0", + result: { value: "test-result" }, + id: "test-req-id" + }; + + await transport.send(responseMessage, { relatedRequestId: "test-req-id" }); + + // Now stream should be closed + expect(mockResponse.end).toHaveBeenCalled(); + }); + }); + + describe("Message Targeting", () => { + beforeEach(async () => { + // Initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + }); + + it("should send response messages to the connection that sent the request", async () => { + // Create request with two separate connections + const requestMessage1: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test1", + params: {}, + id: "req-id-1", + }; + + const mockResponse1 = createMockResponse(); + const req1 = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId + }, + body: JSON.stringify(requestMessage1), + }); + await transport.handleRequest(req1, mockResponse1); + + const requestMessage2: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test2", + params: {}, + id: "req-id-2", + }; + + const mockResponse2 = createMockResponse(); + const req2 = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId + }, + body: JSON.stringify(requestMessage2), + }); + await transport.handleRequest(req2, mockResponse2); + + // Send responses with matching IDs + const responseMessage1: JSONRPCMessage = { + jsonrpc: "2.0", + result: { success: true }, + id: "req-id-1", + }; + + await transport.send(responseMessage1, { relatedRequestId: "req-id-1" }); + + const responseMessage2: JSONRPCMessage = { + jsonrpc: "2.0", + result: { success: true }, + id: "req-id-2", + }; + + await transport.send(responseMessage2, { relatedRequestId: "req-id-2" }); + + // Verify responses were sent to the right connections + expect(mockResponse1.write).toHaveBeenCalledWith( + expect.stringContaining(JSON.stringify(responseMessage1)) + ); + + expect(mockResponse2.write).toHaveBeenCalledWith( + expect.stringContaining(JSON.stringify(responseMessage2)) + ); + + // Verify responses were not sent to the wrong connections + const resp1HasResp2 = mockResponse1.write.mock.calls.some(call => + typeof call[0] === 'string' && call[0].includes(JSON.stringify(responseMessage2)) + ); + expect(resp1HasResp2).toBe(false); + + const resp2HasResp1 = mockResponse2.write.mock.calls.some(call => + typeof call[0] === 'string' && call[0].includes(JSON.stringify(responseMessage1)) + ); + expect(resp2HasResp1).toBe(false); + }); + }); + + describe("Error Handling", () => { + it("should return 400 error for invalid JSON data", async () => { + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: "invalid json", + }); + + const onErrorMock = jest.fn(); + transport.onerror = onErrorMock; + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"code":-32700')); + expect(onErrorMock).toHaveBeenCalled(); + }); + + it("should return 400 error for invalid JSON-RPC messages", async () => { + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify({ invalid: "message" }), + }); + + const onErrorMock = jest.fn(); + transport.onerror = onErrorMock; + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(onErrorMock).toHaveBeenCalled(); + }); + }); + + describe("Handling Pre-Parsed Body", () => { + beforeEach(async () => { + // Initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + }); + + it("should accept pre-parsed request body", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "pre-parsed-test", + }; + + // Create a request without actual body content + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId, + }, + // No body provided here - it will be passed as parsedBody + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + // Pass the pre-parsed body directly + await transport.handleRequest(req, mockResponse, message); + + // Verify the message was processed correctly + expect(onMessageMock).toHaveBeenCalledWith(message); + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + }) + ); + }); + + it("should handle pre-parsed batch messages", async () => { + const batchMessages: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + method: "method1", + params: { data: "test1" }, + id: "batch1" + }, + { + jsonrpc: "2.0", + method: "method2", + params: { data: "test2" }, + id: "batch2" + }, + ]; + + // Create a request without actual body content + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId, + }, + // No body provided here - it will be passed as parsedBody + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + // Pass the pre-parsed body directly + await transport.handleRequest(req, mockResponse, batchMessages); + + // Should be called for each message in the batch + expect(onMessageMock).toHaveBeenCalledTimes(2); + expect(onMessageMock).toHaveBeenCalledWith(batchMessages[0]); + expect(onMessageMock).toHaveBeenCalledWith(batchMessages[1]); + }); + + it("should prefer pre-parsed body over request body", async () => { + const requestBodyMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "fromRequestBody", + params: {}, + id: "request-body", + }; + + const parsedBodyMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "fromParsedBody", + params: {}, + id: "parsed-body", + }; + + // Create a request with actual body content + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId, + }, + body: JSON.stringify(requestBodyMessage), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + // Pass the pre-parsed body directly + await transport.handleRequest(req, mockResponse, parsedBodyMessage); + + // Should use the parsed body instead of the request body + expect(onMessageMock).toHaveBeenCalledWith(parsedBodyMessage); + expect(onMessageMock).not.toHaveBeenCalledWith(requestBodyMessage); + }); + }); +}); \ No newline at end of file diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts new file mode 100644 index 00000000..34b4fd95 --- /dev/null +++ b/src/server/streamableHttp.ts @@ -0,0 +1,397 @@ +import { IncomingMessage, ServerResponse } from "node:http"; +import { Transport } from "../shared/transport.js"; +import { JSONRPCMessage, JSONRPCMessageSchema, RequestId } from "../types.js"; +import getRawBody from "raw-body"; +import contentType from "content-type"; + +const MAXIMUM_MESSAGE_SIZE = "4mb"; + +/** + * Configuration options for StreamableHTTPServerTransport + */ +export interface StreamableHTTPServerTransportOptions { + /** + * Function that generates a session ID for the transport. + * The session ID SHOULD be globally unique and cryptographically secure (e.g., a securely generated UUID, a JWT, or a cryptographic hash) + * + * Return undefined to disable session management. + */ + sessionIdGenerator: () => string | undefined; + + + +} + +/** + * Server transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. + * It supports both SSE streaming and direct HTTP responses. + * + * Usage example: + * + * ```typescript + * // Stateful mode - server sets the session ID + * const statefulTransport = new StreamableHTTPServerTransport({ + * sessionId: randomUUID(), + * }); + * + * // Stateless mode - explicitly set session ID to undefined + * const statelessTransport = new StreamableHTTPServerTransport({ + * sessionId: undefined, + * }); + * + * // Using with pre-parsed request body + * app.post('/mcp', (req, res) => { + * transport.handleRequest(req, res, req.body); + * }); + * ``` + * + * In stateful mode: + * - Session ID is generated and included in response headers + * - Session ID is always included in initialization responses + * - Requests with invalid session IDs are rejected with 404 Not Found + * - Non-initialization requests without a session ID are rejected with 400 Bad Request + * - State is maintained in-memory (connections, message history) + * + * In stateless mode: + * - Session ID is only included in initialization responses + * - No session validation is performed + */ +export class StreamableHTTPServerTransport implements Transport { + // when sessionId is not set (undefined), it means the transport is in stateless mode + private sessionIdGenerator: () => string | undefined; + private _started: boolean = false; + private _sseResponseMapping: Map = new Map(); + private _initialized: boolean = false; + + sessionId?: string | undefined; + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage) => void; + + constructor(options: StreamableHTTPServerTransportOptions) { + this.sessionIdGenerator = options.sessionIdGenerator; + } + + /** + * Starts the transport. This is required by the Transport interface but is a no-op + * for the Streamable HTTP transport as connections are managed per-request. + */ + async start(): Promise { + if (this._started) { + throw new Error("Transport already started"); + } + this._started = true; + } + + /** + * Handles an incoming HTTP request, whether GET or POST + */ + async handleRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise { + if (req.method === "POST") { + await this.handlePostRequest(req, res, parsedBody); + } else if (req.method === "DELETE") { + await this.handleDeleteRequest(req, res); + } else { + await this.handleUnsupportedRequest(res); + } + } + + /** + * Handles unsupported requests (GET, PUT, PATCH, etc.) + * For now we support only POST and DELETE requests. Support for GET for SSE connections will be added later. + */ + private async handleUnsupportedRequest(res: ServerResponse): Promise { + res.writeHead(405, { + "Allow": "POST, DELETE" + }).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Method not allowed." + }, + id: null + })); + } + + /** + * Handles POST requests containing JSON-RPC messages + */ + private async handlePostRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise { + try { + // Validate the Accept header + const acceptHeader = req.headers.accept; + // The client MUST include an Accept header, listing both application/json and text/event-stream as supported content types. + if (!acceptHeader?.includes("application/json") || !acceptHeader.includes("text/event-stream")) { + res.writeHead(406).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Not Acceptable: Client must accept both application/json and text/event-stream" + }, + id: null + })); + return; + } + + const ct = req.headers["content-type"]; + if (!ct || !ct.includes("application/json")) { + res.writeHead(415).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Unsupported Media Type: Content-Type must be application/json" + }, + id: null + })); + return; + } + + let rawMessage; + if (parsedBody !== undefined) { + rawMessage = parsedBody; + } else { + const parsedCt = contentType.parse(ct); + const body = await getRawBody(req, { + limit: MAXIMUM_MESSAGE_SIZE, + encoding: parsedCt.parameters.charset ?? "utf-8", + }); + rawMessage = JSON.parse(body.toString()); + } + + let messages: JSONRPCMessage[]; + + // handle batch and single messages + if (Array.isArray(rawMessage)) { + messages = rawMessage.map(msg => JSONRPCMessageSchema.parse(msg)); + } else { + messages = [JSONRPCMessageSchema.parse(rawMessage)]; + } + + // Check if this is an initialization request + // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/ + const isInitializationRequest = messages.some( + msg => 'method' in msg && msg.method === 'initialize' + ); + if (isInitializationRequest) { + if (this._initialized) { + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32600, + message: "Invalid Request: Server already initialized" + }, + id: null + })); + return; + } + if (messages.length > 1) { + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32600, + message: "Invalid Request: Only one initialization request is allowed" + }, + id: null + })); + return; + } + this.sessionId = this.sessionIdGenerator(); + this._initialized = true; + const headers: Record = {}; + + if (this.sessionId !== undefined) { + headers["mcp-session-id"] = this.sessionId; + } + + // Process initialization messages before responding + for (const message of messages) { + this.onmessage?.(message); + } + + res.writeHead(200, headers).end(); + return; + } + // If an Mcp-Session-Id is returned by the server during initialization, + // clients using the Streamable HTTP transport MUST include it + // in the Mcp-Session-Id header on all of their subsequent HTTP requests. + if (!isInitializationRequest && !this.validateSession(req, res)) { + return; + } + + + // check if it contains requests + const hasRequests = messages.some(msg => 'method' in msg && 'id' in msg); + const hasOnlyNotificationsOrResponses = messages.every(msg => + ('method' in msg && !('id' in msg)) || ('result' in msg || 'error' in msg)); + + if (hasOnlyNotificationsOrResponses) { + // if it only contains notifications or responses, return 202 + res.writeHead(202).end(); + + // handle each message + for (const message of messages) { + this.onmessage?.(message); + } + } else if (hasRequests) { + const headers: Record = { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }; + + // After initialization, always include the session ID if we have one + if (this.sessionId !== undefined) { + headers["mcp-session-id"] = this.sessionId; + } + + res.writeHead(200, headers); + + // Store the response for this request to send messages back through this connection + // We need to track by request ID to maintain the connection + for (const message of messages) { + if ('method' in message && 'id' in message) { + this._sseResponseMapping.set(message.id, res); + } + } + + // handle each message + for (const message of messages) { + this.onmessage?.(message); + } + // The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses + // This will be handled by the send() method when responses are ready + } + } catch (error) { + // return JSON-RPC formatted error + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32700, + message: "Parse error", + data: String(error) + }, + id: null + })); + this.onerror?.(error as Error); + } + } + + /** + * Handles DELETE requests to terminate sessions + */ + private async handleDeleteRequest(req: IncomingMessage, res: ServerResponse): Promise { + if (!this.validateSession(req, res)) { + return; + } + await this.close(); + res.writeHead(200).end(); + } + + /** + * Validates session ID for non-initialization requests + * Returns true if the session is valid, false otherwise + */ + private validateSession(req: IncomingMessage, res: ServerResponse): boolean { + if (!this._initialized) { + // If the server has not been initialized yet, reject all requests + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Bad Request: Server not initialized" + }, + id: null + })); + return false; + } + if (this.sessionId === undefined) { + // If the session ID is not set, the session management is disabled + // and we don't need to validate the session ID + return true; + } + const sessionId = req.headers["mcp-session-id"]; + + if (!sessionId) { + // Non-initialization requests without a session ID should return 400 Bad Request + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Bad Request: Mcp-Session-Id header is required" + }, + id: null + })); + return false; + } else if (Array.isArray(sessionId)) { + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Bad Request: Mcp-Session-Id header must be a single value" + }, + id: null + })); + return false; + } + else if (sessionId !== this.sessionId) { + // Reject requests with invalid session ID with 404 Not Found + res.writeHead(404).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32001, + message: "Session not found" + }, + id: null + })); + return false; + } + + return true; + } + + + async close(): Promise { + // Close all SSE connections + this._sseResponseMapping.forEach((response) => { + response.end(); + }); + this._sseResponseMapping.clear(); + this.onclose?.(); + } + + async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise { + const relatedRequestId = options?.relatedRequestId; + // SSE connections are established per POST request, for now we don't support it through the GET + // this will be changed when we implement the GET SSE connection + if (relatedRequestId === undefined) { + throw new Error("relatedRequestId is required for Streamable HTTP transport"); + } + + const sseResponse = this._sseResponseMapping.get(relatedRequestId); + if (!sseResponse) { + throw new Error(`No SSE connection established for request ID: ${String(relatedRequestId)}`); + } + + // Send the message as an SSE event + sseResponse.write( + `event: message\ndata: ${JSON.stringify(message)}\n\n`, + ); + + // If this is a response message with the same ID as the request, we can check + // if we need to close the stream after sending the response + if ('result' in message || 'error' in message) { + if (message.id === relatedRequestId) { + // This is a response to the original request, we can close the stream + // after sending all related responses + this._sseResponseMapping.delete(relatedRequestId); + + // Only close the connection if it's not needed by other requests + const canCloseConnection = ![...this._sseResponseMapping.entries()].some(([id, res]) => res === sseResponse && id !== relatedRequestId); + if (canCloseConnection) { + sseResponse.end(); + } + } + } + } + +} \ No newline at end of file diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index a6e47184..b072e578 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -78,22 +78,52 @@ export type RequestOptions = { * If not specified, there is no maximum total timeout. */ maxTotalTimeout?: number; + + /** + * May be used to indicate to the transport which incoming request to associate this outgoing request with. + */ + relatedRequestId?: RequestId; }; /** - * Extra data given to request handlers. + * Options that can be given per notification. */ -export type RequestHandlerExtra = { +export type NotificationOptions = { /** - * An abort signal used to communicate if the request was cancelled from the sender's side. + * May be used to indicate to the transport which incoming request to associate this outgoing notification with. */ - signal: AbortSignal; + relatedRequestId?: RequestId; +} - /** - * The session ID from the transport, if available. - */ - sessionId?: string; -}; +/** + * Extra data given to request handlers. + */ +export type RequestHandlerExtra = { + /** + * An abort signal used to communicate if the request was cancelled from the sender's side. + */ + signal: AbortSignal; + + /** + * The session ID from the transport, if available. + */ + sessionId?: string; + + /** + * Sends a notification that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + sendNotification: (notification: SendNotificationT) => Promise; + + /** + * Sends a request that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + sendRequest: >(request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise>; + }; /** * Information about a request's timeout state @@ -122,7 +152,7 @@ export abstract class Protocol< string, ( request: JSONRPCRequest, - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => Promise > = new Map(); private _requestHandlerAbortControllers: Map = @@ -316,9 +346,14 @@ export abstract class Protocol< this._requestHandlerAbortControllers.set(request.id, abortController); // Create extra object with both abort signal and sessionId from transport - const extra: RequestHandlerExtra = { + const extra: RequestHandlerExtra = { signal: abortController.signal, sessionId: this._transport?.sessionId, + sendNotification: + (notification) => + this.notification(notification, { relatedRequestId: request.id }), + sendRequest: (r, resultSchema, options?) => + this.request(r, resultSchema, { ...options, relatedRequestId: request.id }) }; // Starting with Promise.resolve() puts any synchronous errors into the monad as well. @@ -364,7 +399,7 @@ export abstract class Protocol< private _onprogress(notification: ProgressNotification): void { const { progressToken, ...params } = notification.params; const messageId = Number(progressToken); - + const handler = this._progressHandlers.get(messageId); if (!handler) { this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`)); @@ -373,7 +408,7 @@ export abstract class Protocol< const responseHandler = this._responseHandlers.get(messageId); const timeoutInfo = this._timeoutInfo.get(messageId); - + if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) { try { this._resetTimeout(messageId); @@ -460,6 +495,8 @@ export abstract class Protocol< resultSchema: T, options?: RequestOptions, ): Promise> { + const { relatedRequestId } = options ?? {}; + return new Promise((resolve, reject) => { if (!this._transport) { reject(new Error("Not connected")); @@ -500,7 +537,7 @@ export abstract class Protocol< requestId: messageId, reason: String(reason), }, - }) + }, { relatedRequestId }) .catch((error) => this._onerror(new Error(`Failed to send cancellation: ${error}`)), ); @@ -538,7 +575,7 @@ export abstract class Protocol< this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false); - this._transport.send(jsonrpcRequest).catch((error) => { + this._transport.send(jsonrpcRequest, { relatedRequestId }).catch((error) => { this._cleanupTimeout(messageId); reject(error); }); @@ -548,7 +585,7 @@ export abstract class Protocol< /** * Emits a notification, which is a one-way message that does not expect a response. */ - async notification(notification: SendNotificationT): Promise { + async notification(notification: SendNotificationT, options?: NotificationOptions): Promise { if (!this._transport) { throw new Error("Not connected"); } @@ -560,7 +597,7 @@ export abstract class Protocol< jsonrpc: "2.0", }; - await this._transport.send(jsonrpcNotification); + await this._transport.send(jsonrpcNotification, options); } /** @@ -576,14 +613,15 @@ export abstract class Protocol< requestSchema: T, handler: ( request: z.infer, - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => SendResultT | Promise, ): void { const method = requestSchema.shape.method.value; this.assertRequestHandlerCapability(method); - this._requestHandlers.set(method, (request, extra) => - Promise.resolve(handler(requestSchema.parse(request), extra)), - ); + + this._requestHandlers.set(method, (request, extra) => { + return Promise.resolve(handler(requestSchema.parse(request), extra)); + }); } /** diff --git a/src/shared/transport.ts b/src/shared/transport.ts index b80e2a51..e464653b 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -1,4 +1,4 @@ -import { JSONRPCMessage } from "../types.js"; +import { JSONRPCMessage, RequestId } from "../types.js"; /** * Describes the minimal contract for a MCP transport that a client or server can communicate over. @@ -15,8 +15,10 @@ export interface Transport { /** * Sends a JSON-RPC message (request or response). + * + * If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with. */ - send(message: JSONRPCMessage): Promise; + send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise; /** * Closes the connection.