Skip to content

Add AcceptSocketAsync to SocketTransportOptions #34345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@ static Microsoft.AspNetCore.Hosting.WebHostBuilderSocketExtensions.UseSockets(th
static Microsoft.AspNetCore.Hosting.WebHostBuilderSocketExtensions.UseSockets(this Microsoft.AspNetCore.Hosting.IWebHostBuilder! hostBuilder, System.Action<Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions!>! configureOptions) -> Microsoft.AspNetCore.Hosting.IWebHostBuilder!
static Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateDefaultBoundListenSocket(System.Net.EndPoint! endpoint) -> System.Net.Sockets.Socket!
Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateBoundListenSocket.get -> System.Func<System.Net.EndPoint!, System.Net.Sockets.Socket!>!
Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateBoundListenSocket.set -> void
Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateBoundListenSocket.set -> void
static Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.DefaultAcceptSocketAsync(System.Net.Sockets.Socket! listenSocket, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.ValueTask<System.Net.Sockets.Socket!>
Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.AcceptSocketAsync.get -> System.Func<System.Net.Sockets.Socket!, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask<System.Net.Sockets.Socket!>>!
Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.AcceptSocketAsync.set -> void
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ internal void Bind()
{
Debug.Assert(_listenSocket != null, "Bind must be called first.");

var acceptSocket = await _listenSocket.AcceptAsync(cancellationToken);
var acceptSocket = await _options.AcceptSocketAsync(_listenSocket, cancellationToken);

// Only apply no delay to Tcp based endpoints
if (acceptSocket.LocalEndPoint is IPEndPoint)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,25 @@ public class SocketTransportOptions
/// </remarks>
public Func<EndPoint, Socket> CreateBoundListenSocket { get; set; } = CreateDefaultBoundListenSocket;

/// <summary>
/// A function used to accept a new <see cref="Socket"/> given a listening <see cref="Socket"/>.
/// </summary>
/// <remarks>
/// The listening <see cref="Socket"/> passed is the one created by a previous call to <see cref="CreateBoundListenSocket"/>.
///
/// This property defaults to <see cref="DefaultAcceptSocketAsync"/>.
/// </remarks>
public Func<Socket, CancellationToken, ValueTask<Socket>> AcceptSocketAsync { get; set; } = DefaultAcceptSocketAsync;

/// <summary>
/// Accepts a new <see cref="Socket"/> from a listen <see cref="Socket"/> previously obtained from <see cref="CreateBoundListenSocket"/>.
/// </summary>
/// <param name="listenSocket">A listening <see cref="Socket"/>.</param>
/// <param name="cancellationToken">Indicates if the accept operation should be aborted.</param>
/// <returns>A newly accepted <see cref="Socket"/>.</returns>
public static ValueTask<Socket> DefaultAcceptSocketAsync(Socket listenSocket, CancellationToken cancellationToken)
=> listenSocket.AcceptAsync(cancellationToken);

/// <summary>
/// Creates a default instance of <see cref="Socket"/> for the given <see cref="EndPoint"/>
/// that can be used by a connection listener to listen for inbound requests. <see cref="Socket.Bind"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Http;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
Expand Down Expand Up @@ -80,6 +81,37 @@ public void CreateDefaultBoundListenSocket_PreservesLocalEndpointFromFileHandleE
Assert.Equal(fileHandleSocket.LocalEndPoint, listenSocket.LocalEndPoint);
}

[Fact]
public async Task VerifySocketTransportCallsAcceptSocketAsync()
{
var wasCalled = false;

ValueTask<Socket> AcceptSocketAsync(Socket socket, CancellationToken cancellationToken)
{
wasCalled = true;
return socket.AcceptAsync(cancellationToken);
}

using var host = CreateWebHost(
new IPEndPoint(IPAddress.Loopback, 0),
options =>
{
options.AcceptSocketAsync = AcceptSocketAsync;
}
);

await host.StartAsync();
using var client = new HttpClient();

var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/");
response.EnsureSuccessStatusCode();

await host.StopAsync();

Assert.True(wasCalled, $"Expected {nameof(SocketTransportOptions.AcceptSocketAsync)} to be called.");
await host.StopAsync();
}

public static IEnumerable<object[]> GetEndpoints()
{
// IPv4
Expand Down