Skip to content

Use Kestrel for all in-memory HTTP tests #225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 7, 2025
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.Routing.Patterns;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils.Json;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography;

namespace Microsoft.AspNetCore.Builder;
Expand All @@ -23,53 +26,87 @@ public static class McpEndpointRouteBuilderExtensions
/// Sets up endpoints for handling MCP HTTP Streaming transport.
/// </summary>
/// <param name="endpoints">The web application to attach MCP HTTP endpoints.</param>
/// <param name="runSession">Provides an optional asynchronous callback for handling new MCP sessions.</param>
/// <param name="pattern">The route pattern prefix to map to.</param>
/// <param name="configureOptionsAsync">Configure per-session options.</param>
/// <param name="runSessionAsync">Provides an optional asynchronous callback for handling new MCP sessions.</param>
/// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, Func<HttpContext, IMcpServer, CancellationToken, Task>? runSession = null)
public static IEndpointConventionBuilder MapMcp(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern = "",
Func<HttpContext, McpServerOptions, CancellationToken, Task>? configureOptionsAsync = null,
Func<HttpContext, IMcpServer, CancellationToken, Task>? runSessionAsync = null)
=> endpoints.MapMcp(RoutePatternFactory.Parse(pattern), configureOptionsAsync, runSessionAsync);

/// <summary>
/// Sets up endpoints for handling MCP HTTP Streaming transport.
/// </summary>
/// <param name="endpoints">The web application to attach MCP HTTP endpoints.</param>
/// <param name="pattern">The route pattern prefix to map to.</param>
/// <param name="configureOptionsAsync">Configure per-session options.</param>
/// <param name="runSessionAsync">Provides an optional asynchronous callback for handling new MCP sessions.</param>
/// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints,
RoutePattern pattern,
Func<HttpContext, McpServerOptions, CancellationToken, Task>? configureOptionsAsync = null,
Func<HttpContext, IMcpServer, CancellationToken, Task>? runSessionAsync = null)
{
ConcurrentDictionary<string, SseResponseStreamTransport> _sessions = new(StringComparer.Ordinal);

var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
var optionsFactory = endpoints.ServiceProvider.GetRequiredService<IOptionsFactory<McpServerOptions>>();
var hostApplicationLifetime = endpoints.ServiceProvider.GetRequiredService<IHostApplicationLifetime>();

var routeGroup = endpoints.MapGroup("");
var routeGroup = endpoints.MapGroup(pattern);

routeGroup.MapGet("/sse", async context =>
{
var response = context.Response;
var requestAborted = context.RequestAborted;
// If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout
// which defaults to 30 seconds.
using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping);
var cancellationToken = sseCts.Token;

var response = context.Response;
response.Headers.ContentType = "text/event-stream";
response.Headers.CacheControl = "no-cache,no-store";

// Make sure we disable all response buffering for SSE
context.Response.Headers.ContentEncoding = "identity";
context.Features.GetRequiredFeature<IHttpResponseBodyFeature>().DisableBuffering();

var sessionId = MakeNewSessionId();
await using var transport = new SseResponseStreamTransport(response.Body, $"/message?sessionId={sessionId}");
if (!_sessions.TryAdd(sessionId, transport))
{
throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
}

try
var options = optionsSnapshot.Value;
if (configureOptionsAsync is not null)
{
// Make sure we disable all response buffering for SSE
context.Response.Headers.ContentEncoding = "identity";
context.Features.GetRequiredFeature<IHttpResponseBodyFeature>().DisableBuffering();
options = optionsFactory.Create(Options.DefaultName);
await configureOptionsAsync.Invoke(context, options, cancellationToken);
}

var transportTask = transport.RunAsync(cancellationToken: requestAborted);
await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);
try
{
var transportTask = transport.RunAsync(cancellationToken);

try
{
runSession ??= RunSession;
await runSession(context, server, requestAborted);
await using var mcpServer = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider);
context.Features.Set(mcpServer);

runSessionAsync ??= RunSession;
await runSessionAsync(context, mcpServer, cancellationToken);
}
finally
{
await transport.DisposeAsync();
await transportTask;
}
}
catch (OperationCanceledException) when (requestAborted.IsCancellationRequested)
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
// RequestAborted always triggers when the client disconnects before a complete response body is written,
// but this is how SSE connections are typically closed.
Expand Down
5 changes: 4 additions & 1 deletion src/ModelContextProtocol/Client/McpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
// Connect transport
_sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
StartSession(_sessionTransport);
InitializeSession(_sessionTransport);
// We don't want the ConnectAsync token to cancel the session after we've successfully connected.
// The base class handles cleaning up the session in DisposeAsync without our help.
StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None);

// Perform initialization sequence
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,14 @@ private async Task CloseAsync()
{
try
{
if (!_connectionCts.IsCancellationRequested)
{
await _connectionCts.CancelAsync().ConfigureAwait(false);
_connectionCts.Dispose();
}
await _connectionCts.CancelAsync().ConfigureAwait(false);

if (_receiveTask != null)
{
await _receiveTask.ConfigureAwait(false);
}

_connectionCts.Dispose();
}
finally
{
Expand Down
12 changes: 7 additions & 5 deletions src/ModelContextProtocol/Server/McpServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ internal sealed class McpServer : McpEndpoint, IMcpServer
Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0",
};

private readonly ITransport _sessionTransport;

private readonly EventHandler? _toolsChangedDelegate;
private readonly EventHandler? _promptsChangedDelegate;

Expand All @@ -41,6 +43,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?

options ??= new();

_sessionTransport = transport;
ServerOptions = options;
Services = serviceProvider;
_endpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})";
Expand Down Expand Up @@ -81,8 +84,8 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
prompts.Changed += _promptsChangedDelegate;
}

// And start the session.
StartSession(transport);
// And initialize the session.
InitializeSession(transport);
}

public ServerCapabilities? ServerCapabilities { get; set; }
Expand Down Expand Up @@ -112,9 +115,8 @@ public async Task RunAsync(CancellationToken cancellationToken = default)

try
{
using var _ = cancellationToken.Register(static s => ((McpServer)s!).CancelSession(), this);
// The McpServer ctor always calls StartSession, so MessageProcessingTask is always set.
await MessageProcessingTask!.ConfigureAwait(false);
StartSession(_sessionTransport, fullSessionCancellationToken: cancellationToken);
await MessageProcessingTask.ConfigureAwait(false);
}
finally
{
Expand Down
15 changes: 10 additions & 5 deletions src/ModelContextProtocol/Shared/McpEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;

Expand Down Expand Up @@ -62,12 +63,16 @@ public IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcN
/// </summary>
protected Task? MessageProcessingTask { get; private set; }

[MemberNotNull(nameof(MessageProcessingTask))]
protected void StartSession(ITransport sessionTransport)
protected void InitializeSession(ITransport sessionTransport)
{
_sessionCts = new CancellationTokenSource();
_session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, RequestHandlers, NotificationHandlers, _logger);
MessageProcessingTask = _session.ProcessMessagesAsync(_sessionCts.Token);
}

[MemberNotNull(nameof(MessageProcessingTask))]
protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken)
{
_sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken);
MessageProcessingTask = GetSessionOrThrow().ProcessMessagesAsync(_sessionCts.Token);
}

protected void CancelSession() => _sessionCts?.Cancel();
Expand Down Expand Up @@ -122,5 +127,5 @@ public virtual async ValueTask DisposeUnsynchronizedAsync()
}

protected McpSession GetSessionOrThrow()
=> _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(StartSession)} before sending messages.");
=> _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages.");
}
33 changes: 26 additions & 7 deletions tests/ModelContextProtocol.TestSseServer/Program.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelContextProtocol.Protocol.Types;
using Microsoft.AspNetCore.Connections;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using Serilog;
using System.Text;
Expand Down Expand Up @@ -372,18 +373,34 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st
};
}

public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvider = null, CancellationToken cancellationToken = default)
public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvider = null, IConnectionListenerFactory? kestrelTransport = null, CancellationToken cancellationToken = default)
{
Console.WriteLine("Starting server...");

int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001;

var builder = WebApplication.CreateSlimBuilder(args);
builder.WebHost.ConfigureKestrel(options =>
var builder = WebApplication.CreateEmptyBuilder(new()
{
options.ListenLocalhost(port);
Args = args,
});

if (kestrelTransport is null)
{
int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001;
builder.WebHost.ConfigureKestrel(options =>
{
options.ListenLocalhost(port);
});
}
else
{
// Add passed-in transport before calling UseKestrelCore() to avoid the SocketsHttpHandler getting added.
builder.Services.AddSingleton(kestrelTransport);
}

builder.WebHost.UseKestrelCore();
builder.Services.AddLogging();
builder.Services.AddRoutingCore();

builder.Logging.AddConsole();
ConfigureSerilog(builder.Logging);
if (loggerProvider is not null)
{
Expand All @@ -393,6 +410,8 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide
builder.Services.AddMcpServer(ConfigureOptions);

var app = builder.Build();
app.UseRouting();
app.UseEndpoints(_ => { });

app.MapMcp();

Expand Down
Loading