Commit af3f569

server: prohibit more than MaxConcurrentStreams handlers from running at once (grpc#6703)
1 parent 7511ddf commit af3f569

5 files changed: +220 -45 lines
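
In short: grpc.MaxConcurrentStreams(n) is no longer only an advertised HTTP/2 SETTINGS value; after this commit the server also refuses to run more than n handler goroutines at once. A minimal sketch of opting in (the address and the limit of 100 are illustrative; no services are registered):

package main

import (
	"log"
	"net"

	"google.golang.org/grpc"
)

func main() {
	lis, err := net.Listen("tcp", "localhost:50051") // hypothetical address
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	// At most 100 handlers run concurrently; further incoming streams wait
	// in serveStreams until a running handler releases its quota.
	// Passing 0 is normalized to math.MaxUint32 ("no limit").
	srv := grpc.NewServer(grpc.MaxConcurrentStreams(100))
	if err := srv.Serve(lis); err != nil {
		log.Fatalf("serve error: %v", err)
	}
}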

benchmark/primitives/primitives_test.go (+39)

@@ -425,3 +425,42 @@ func BenchmarkRLockUnlock(b *testing.B) {
 		}
 	})
 }
+
+type ifNop interface {
+	nop()
+}
+
+type alwaysNop struct{}
+
+func (alwaysNop) nop() {}
+
+type concreteNop struct {
+	isNop atomic.Bool
+	i     int
+}
+
+func (c *concreteNop) nop() {
+	if c.isNop.Load() {
+		return
+	}
+	c.i++
+}
+
+func BenchmarkInterfaceNop(b *testing.B) {
+	n := ifNop(alwaysNop{})
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			n.nop()
+		}
+	})
+}
+
+func BenchmarkConcreteNop(b *testing.B) {
+	n := &concreteNop{}
+	n.isNop.Store(true)
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			n.nop()
+		}
+	})
+}
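
These microbenchmarks compare a no-op call dispatched through an interface against one on a concrete type gated by an atomic.Bool — the dispatch overhead relevant to the closure-based worker path this commit introduces. Assuming a local checkout of grpc-go, something like the following should run just these two:

go test -run='^$' -bench='Nop$' ./benchmark/primitives/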

internal/transport/http2_server.go (+3 -8)

@@ -171,15 +171,10 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
 		ID:  http2.SettingMaxFrameSize,
 		Val: http2MaxFrameLen,
 	}}
-	// TODO(zhaoq): Have a better way to signal "no limit" because 0 is
-	// permitted in the HTTP2 spec.
-	maxStreams := config.MaxStreams
-	if maxStreams == 0 {
-		maxStreams = math.MaxUint32
-	} else {
+	if config.MaxStreams != math.MaxUint32 {
 		isettings = append(isettings, http2.Setting{
 			ID:  http2.SettingMaxConcurrentStreams,
-			Val: maxStreams,
+			Val: config.MaxStreams,
 		})
 	}
 	dynamicWindow := true
@@ -258,7 +253,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
 		framer:      framer,
 		readerDone:  make(chan struct{}),
 		writerDone:  make(chan struct{}),
-		maxStreams:  maxStreams,
+		maxStreams:  config.MaxStreams,
 		inTapHandle: config.InTapHandle,
 		fc:          &trInFlow{limit: uint32(icwz)},
 		state:       reachable,
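
The net effect: math.MaxUint32 replaces 0 as the transport-level "no limit" sentinel, and in that case no SETTINGS_MAX_CONCURRENT_STREAMS entry is sent at all — which is how HTTP/2 expresses an unlimited stream count. A standalone sketch of that decision (maxStreamsSetting is a hypothetical helper, not part of the package):

package sketch

import (
	"math"

	"golang.org/x/net/http2"
)

// maxStreamsSetting mirrors the new logic in NewServerTransport: advertise a
// SETTINGS_MAX_CONCURRENT_STREAMS entry only when a real limit is configured.
func maxStreamsSetting(maxStreams uint32) []http2.Setting {
	if maxStreams == math.MaxUint32 {
		return nil // omit the setting entirely; the peer sees no stream limit
	}
	return []http2.Setting{{
		ID:  http2.SettingMaxConcurrentStreams,
		Val: maxStreams,
	}}
}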

internal/transport/transport_test.go (+19 -16)

@@ -337,6 +337,9 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
 			return
 		}
 		rawConn := conn
+		if serverConfig.MaxStreams == 0 {
+			serverConfig.MaxStreams = math.MaxUint32
+		}
 		transport, err := NewServerTransport(conn, serverConfig)
 		if err != nil {
 			return
@@ -443,8 +446,8 @@ func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server
 	return server
 }

-func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) {
-	return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{})
+func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) {
+	return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{})
 }

 func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
@@ -539,7 +542,7 @@ func (s) TestInflightStreamClosing(t *testing.T) {

 // Tests that when streamID > MaxStreamId, the current client transport drains.
 func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer server.stop()
 	callHdr := &CallHdr{
@@ -584,7 +587,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
 }

 func (s) TestClientSendAndReceive(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -624,7 +627,7 @@ func (s) TestClientSendAndReceive(t *testing.T) {
 }

 func (s) TestClientErrorNotify(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	go server.stop()
 	// ct.reader should detect the error and activate ct.Error().
@@ -658,7 +661,7 @@ func performOneRPC(ct ClientTransport) {
 }

 func (s) TestClientMix(t *testing.T) {
-	s, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	s, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	time.AfterFunc(time.Second, s.stop)
 	go func(ct ClientTransport) {
@@ -672,7 +675,7 @@ func (s) TestClientMix(t *testing.T) {
 }

 func (s) TestLargeMessage(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -807,7 +810,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
 // proceed until they complete naturally, while not allowing creation of new
 // streams during this window.
 func (s) TestGracefulClose(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong)
+	server, ct, cancel := setUp(t, 0, pingpong)
 	defer cancel()
 	defer func() {
 		// Stop the server's listener to make the server's goroutines terminate
@@ -873,7 +876,7 @@ func (s) TestGracefulClose(t *testing.T) {
 }

 func (s) TestLargeMessageSuspension(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
+	server, ct, cancel := setUp(t, 0, suspended)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -981,7 +984,7 @@ func (s) TestMaxStreams(t *testing.T) {
 }

 func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
+	server, ct, cancel := setUp(t, 0, suspended)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -1453,7 +1456,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
 var encodingTestStatus = status.New(codes.Internal, "\n")

 func (s) TestEncodingRequiredStatus(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
+	server, ct, cancel := setUp(t, 0, encodingRequiredStatus)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -1481,7 +1484,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
 }

 func (s) TestInvalidHeaderField(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
+	server, ct, cancel := setUp(t, 0, invalidHeaderField)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -1503,7 +1506,7 @@ func (s) TestInvalidHeaderField(t *testing.T) {
 }

 func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
+	server, ct, cancel := setUp(t, 0, invalidHeaderField)
 	defer cancel()
 	defer server.stop()
 	defer ct.Close(fmt.Errorf("closed manually by test"))
@@ -2171,7 +2174,7 @@ func (s) TestPingPong1MB(t *testing.T) {

 // This is a stress-test of flow control logic.
 func runPingPongTest(t *testing.T, msgSize int) {
-	server, client, cancel := setUp(t, 0, 0, pingpong)
+	server, client, cancel := setUp(t, 0, pingpong)
 	defer cancel()
 	defer server.stop()
 	defer client.Close(fmt.Errorf("closed manually by test"))
@@ -2253,7 +2256,7 @@ func (s) TestHeaderTblSize(t *testing.T) {
 		}
 	}()

-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer ct.Close(fmt.Errorf("closed manually by test"))
 	defer server.stop()
@@ -2612,7 +2615,7 @@ func TestConnectionError_Unwrap(t *testing.T) {

 func (s) TestPeerSetInServerContext(t *testing.T) {
 	// create client and server transports.
-	server, client, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, client, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer server.stop()
 	defer client.Close(fmt.Errorf("closed manually by test"))

server.go (+48 -21)

@@ -115,12 +115,6 @@ type serviceInfo struct {
 	mdata interface{}
 }

-type serverWorkerData struct {
-	st     transport.ServerTransport
-	wg     *sync.WaitGroup
-	stream *transport.Stream
-}
-
 // Server is a gRPC server to serve RPC requests.
 type Server struct {
 	opts serverOptions
@@ -145,7 +139,7 @@ type Server struct {
 	channelzID *channelz.Identifier
 	czData     *channelzData

-	serverWorkerChannel chan *serverWorkerData
+	serverWorkerChannel chan func()
 }

 type serverOptions struct {
@@ -178,6 +172,7 @@ type serverOptions struct {
 }

 var defaultServerOptions = serverOptions{
+	maxConcurrentStreams:  math.MaxUint32,
 	maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
 	maxSendMessageSize:    defaultServerMaxSendMessageSize,
 	connectionTimeout:     120 * time.Second,
@@ -389,6 +384,9 @@ func MaxSendMsgSize(m int) ServerOption {
 // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
 // of concurrent streams to each ServerTransport.
 func MaxConcurrentStreams(n uint32) ServerOption {
+	if n == 0 {
+		n = math.MaxUint32
+	}
 	return newFuncServerOption(func(o *serverOptions) {
 		o.maxConcurrentStreams = n
 	})
@@ -590,24 +588,19 @@ const serverWorkerResetThreshold = 1 << 16
 // [1] https://github.com/golang/go/issues/18138
 func (s *Server) serverWorker() {
 	for completed := 0; completed < serverWorkerResetThreshold; completed++ {
-		data, ok := <-s.serverWorkerChannel
+		f, ok := <-s.serverWorkerChannel
 		if !ok {
 			return
 		}
-		s.handleSingleStream(data)
+		f()
 	}
 	go s.serverWorker()
 }

-func (s *Server) handleSingleStream(data *serverWorkerData) {
-	defer data.wg.Done()
-	s.handleStream(data.st, data.stream, s.traceInfo(data.st, data.stream))
-}
-
 // initServerWorkers creates worker goroutines and a channel to process incoming
 // connections to reduce the time spent overall on runtime.morestack.
 func (s *Server) initServerWorkers() {
-	s.serverWorkerChannel = make(chan *serverWorkerData)
+	s.serverWorkerChannel = make(chan func())
 	for i := uint32(0); i < s.opts.numServerWorkers; i++ {
 		go s.serverWorker()
 	}
@@ -966,21 +959,26 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
 	defer st.Close(errors.New("finished serving streams for the server transport"))
 	var wg sync.WaitGroup

+	streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
 	st.HandleStreams(func(stream *transport.Stream) {
 		wg.Add(1)
+
+		streamQuota.acquire()
+		f := func() {
+			defer streamQuota.release()
+			defer wg.Done()
+			s.handleStream(st, stream, s.traceInfo(st, stream))
+		}
+
 		if s.opts.numServerWorkers > 0 {
-			data := &serverWorkerData{st: st, wg: &wg, stream: stream}
 			select {
-			case s.serverWorkerChannel <- data:
+			case s.serverWorkerChannel <- f:
 				return
 			default:
 				// If all stream workers are busy, fallback to the default code path.
 			}
 		}
-		go func() {
-			defer wg.Done()
-			s.handleStream(st, stream, s.traceInfo(st, stream))
-		}()
+		go f()
 	}, func(ctx context.Context, method string) context.Context {
 		if !EnableTracing {
 			return ctx
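
The dispatch shape above is worth calling out: the stream's work is captured once as the closure f, quota is acquired synchronously before dispatch (so HandleStreams itself blocks when the server is saturated), and f is then handed to an idle worker if one is waiting, else run on a fresh goroutine. A self-contained sketch of that select-with-default pattern (names are illustrative, not the gRPC source):

package main

import (
	"fmt"
	"sync"
)

// dispatch hands f to an idle worker if one is ready to receive, and
// otherwise falls back to a fresh goroutine so dispatch never blocks.
func dispatch(workers chan func(), f func()) {
	select {
	case workers <- f: // a pre-spawned worker picked up the closure
	default:
		go f() // all workers busy; default code path
	}
}

func main() {
	workers := make(chan func())
	for i := 0; i < 2; i++ { // e.g. numServerWorkers = 2
		go func() {
			for f := range workers {
				f()
			}
		}()
	}
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		i := i
		wg.Add(1)
		dispatch(workers, func() {
			defer wg.Done()
			fmt.Println("handled stream", i)
		})
	}
	wg.Wait()
}
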
@@ -2075,3 +2073,32 @@ func validateSendCompressor(name, clientCompressors string) error {
 	}
 	return fmt.Errorf("client does not support compressor %q", name)
 }
+
+// atomicSemaphore implements a blocking, counting semaphore. acquire should be
+// called synchronously; release may be called asynchronously.
+type atomicSemaphore struct {
+	n    int64 // accessed atomically
+	wait chan struct{}
+}
+
+func (q *atomicSemaphore) acquire() {
+	if atomic.AddInt64(&q.n, -1) < 0 {
+		// We ran out of quota.  Block until a release happens.
+		<-q.wait
+	}
+}
+
+func (q *atomicSemaphore) release() {
+	// N.B. the "<= 0" check below should allow for this to work with multiple
+	// concurrent calls to acquire, but also note that with synchronous calls to
+	// acquire, as our system does, n will never be less than -1.  There are
+	// fairness issues (queuing) to consider if this was to be generalized.
+	if atomic.AddInt64(&q.n, 1) <= 0 {
+		// An acquire was waiting on us.  Unblock it.
+		q.wait <- struct{}{}
+	}
+}
+
+func newHandlerQuota(n uint32) *atomicSemaphore {
+	return &atomicSemaphore{n: int64(n), wait: make(chan struct{}, 1)}
+}
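
One subtlety in the semaphore: release increments n (the +1 is easy to misread in a diff), and wait only needs capacity 1 because acquire is called synchronously from HandleStreams, so n never drops below -1 and at most one acquirer can ever be parked. A standalone sketch that exercises the invariant (the type is copied from the diff; main is illustrative):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

type atomicSemaphore struct {
	n    int64 // accessed atomically
	wait chan struct{}
}

func (q *atomicSemaphore) acquire() {
	if atomic.AddInt64(&q.n, -1) < 0 {
		<-q.wait // out of quota; block until a release happens
	}
}

func (q *atomicSemaphore) release() {
	if atomic.AddInt64(&q.n, 1) <= 0 {
		q.wait <- struct{}{} // an acquire was waiting on us; unblock it
	}
}

func newHandlerQuota(n uint32) *atomicSemaphore {
	return &atomicSemaphore{n: int64(n), wait: make(chan struct{}, 1)}
}

func main() {
	quota := newHandlerQuota(2) // cap "handlers" at 2
	var running int64
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		quota.acquire() // synchronous, as in serveStreams
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer quota.release()
			if atomic.AddInt64(&running, 1) > 2 {
				panic("more than 2 handlers running at once")
			}
			time.Sleep(time.Millisecond) // simulate handler work
			atomic.AddInt64(&running, -1)
		}()
	}
	wg.Wait()
	fmt.Println("done; concurrency stayed within quota")
}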
