|
7 | 7 | package http
|
8 | 8 |
|
9 | 9 | import (
|
| 10 | + "bytes" |
| 11 | + "crypto/tls" |
10 | 12 | "errors"
|
| 13 | + "io" |
| 14 | + "io/ioutil" |
11 | 15 | "net"
|
| 16 | + "net/http/internal" |
12 | 17 | "strings"
|
13 | 18 | "testing"
|
14 | 19 | )
|
@@ -178,3 +183,81 @@ func TestTransportShouldRetryRequest(t *testing.T) {
|
178 | 183 | }
|
179 | 184 | }
|
180 | 185 | }
|
| 186 | + |
| 187 | +type roundTripFunc func(r *Request) (*Response, error) |
| 188 | + |
| 189 | +func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { |
| 190 | + return f(r) |
| 191 | +} |
| 192 | + |
| 193 | +// Issue 25009 |
| 194 | +func TestTransportBodyAltRewind(t *testing.T) { |
| 195 | + cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) |
| 196 | + if err != nil { |
| 197 | + t.Fatal(err) |
| 198 | + } |
| 199 | + ln := newLocalListener(t) |
| 200 | + defer ln.Close() |
| 201 | + |
| 202 | + go func() { |
| 203 | + tln := tls.NewListener(ln, &tls.Config{ |
| 204 | + NextProtos: []string{"foo"}, |
| 205 | + Certificates: []tls.Certificate{cert}, |
| 206 | + }) |
| 207 | + for i := 0; i < 2; i++ { |
| 208 | + sc, err := tln.Accept() |
| 209 | + if err != nil { |
| 210 | + t.Error(err) |
| 211 | + return |
| 212 | + } |
| 213 | + if err := sc.(*tls.Conn).Handshake(); err != nil { |
| 214 | + t.Error(err) |
| 215 | + return |
| 216 | + } |
| 217 | + sc.Close() |
| 218 | + } |
| 219 | + }() |
| 220 | + |
| 221 | + addr := ln.Addr().String() |
| 222 | + req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request")) |
| 223 | + roundTripped := false |
| 224 | + tr := &Transport{ |
| 225 | + DisableKeepAlives: true, |
| 226 | + TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{ |
| 227 | + "foo": func(authority string, c *tls.Conn) RoundTripper { |
| 228 | + return roundTripFunc(func(r *Request) (*Response, error) { |
| 229 | + n, _ := io.Copy(ioutil.Discard, r.Body) |
| 230 | + if n == 0 { |
| 231 | + t.Error("body length is zero") |
| 232 | + } |
| 233 | + if roundTripped { |
| 234 | + return &Response{ |
| 235 | + Body: NoBody, |
| 236 | + StatusCode: 200, |
| 237 | + }, nil |
| 238 | + } |
| 239 | + roundTripped = true |
| 240 | + return nil, http2noCachedConnError{} |
| 241 | + }) |
| 242 | + }, |
| 243 | + }, |
| 244 | + DialTLS: func(_, _ string) (net.Conn, error) { |
| 245 | + tc, err := tls.Dial("tcp", addr, &tls.Config{ |
| 246 | + InsecureSkipVerify: true, |
| 247 | + NextProtos: []string{"foo"}, |
| 248 | + }) |
| 249 | + if err != nil { |
| 250 | + return nil, err |
| 251 | + } |
| 252 | + if err := tc.Handshake(); err != nil { |
| 253 | + return nil, err |
| 254 | + } |
| 255 | + return tc, nil |
| 256 | + }, |
| 257 | + } |
| 258 | + c := &Client{Transport: tr} |
| 259 | + _, err = c.Do(req) |
| 260 | + if err != nil { |
| 261 | + t.Error(err) |
| 262 | + } |
| 263 | +} |
0 commit comments