Skip to content

Commit 3aa3c05

Browse files
ernadobradfitz
authored andcommitted
net/http: rewind request body unconditionally
When http2 fails with ErrNoCachedConn the request is retried with body that has already been read. Fixes #25009 Change-Id: I51ed5c8cf469dd8b17c73fff6140ab80162bf267 Reviewed-on: https://go-review.googlesource.com/c/131755 Run-TryBot: Iskander Sharipov <[email protected]> TryBot-Result: Gobot Gobot <[email protected]> Reviewed-by: Brad Fitzpatrick <[email protected]>
1 parent 0906d64 commit 3aa3c05

File tree

2 files changed

+85
-3
lines changed

2 files changed

+85
-3
lines changed

src/net/http/transport.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -478,9 +478,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
478478
}
479479
testHookRoundTripRetried()
480480

481-
// Rewind the body if we're able to. (HTTP/2 does this itself so we only
482-
// need to do it for HTTP/1.1 connections.)
483-
if req.GetBody != nil && pconn.alt == nil {
481+
// Rewind the body if we're able to.
482+
if req.GetBody != nil {
484483
newReq := *req
485484
var err error
486485
newReq.Body, err = req.GetBody()

src/net/http/transport_internal_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@
77
package http
88

99
import (
10+
"bytes"
11+
"crypto/tls"
1012
"errors"
13+
"io"
14+
"io/ioutil"
1115
"net"
16+
"net/http/internal"
1217
"strings"
1318
"testing"
1419
)
@@ -178,3 +183,81 @@ func TestTransportShouldRetryRequest(t *testing.T) {
178183
}
179184
}
180185
}
186+
187+
type roundTripFunc func(r *Request) (*Response, error)
188+
189+
func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
190+
return f(r)
191+
}
192+
193+
// Issue 25009
194+
func TestTransportBodyAltRewind(t *testing.T) {
195+
cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
196+
if err != nil {
197+
t.Fatal(err)
198+
}
199+
ln := newLocalListener(t)
200+
defer ln.Close()
201+
202+
go func() {
203+
tln := tls.NewListener(ln, &tls.Config{
204+
NextProtos: []string{"foo"},
205+
Certificates: []tls.Certificate{cert},
206+
})
207+
for i := 0; i < 2; i++ {
208+
sc, err := tln.Accept()
209+
if err != nil {
210+
t.Error(err)
211+
return
212+
}
213+
if err := sc.(*tls.Conn).Handshake(); err != nil {
214+
t.Error(err)
215+
return
216+
}
217+
sc.Close()
218+
}
219+
}()
220+
221+
addr := ln.Addr().String()
222+
req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request"))
223+
roundTripped := false
224+
tr := &Transport{
225+
DisableKeepAlives: true,
226+
TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
227+
"foo": func(authority string, c *tls.Conn) RoundTripper {
228+
return roundTripFunc(func(r *Request) (*Response, error) {
229+
n, _ := io.Copy(ioutil.Discard, r.Body)
230+
if n == 0 {
231+
t.Error("body length is zero")
232+
}
233+
if roundTripped {
234+
return &Response{
235+
Body: NoBody,
236+
StatusCode: 200,
237+
}, nil
238+
}
239+
roundTripped = true
240+
return nil, http2noCachedConnError{}
241+
})
242+
},
243+
},
244+
DialTLS: func(_, _ string) (net.Conn, error) {
245+
tc, err := tls.Dial("tcp", addr, &tls.Config{
246+
InsecureSkipVerify: true,
247+
NextProtos: []string{"foo"},
248+
})
249+
if err != nil {
250+
return nil, err
251+
}
252+
if err := tc.Handshake(); err != nil {
253+
return nil, err
254+
}
255+
return tc, nil
256+
},
257+
}
258+
c := &Client{Transport: tr}
259+
_, err = c.Do(req)
260+
if err != nil {
261+
t.Error(err)
262+
}
263+
}

0 commit comments

Comments
 (0)