Skip to content

Commit 61fe37a

Browse files
authored
Merge pull request #698 from iangudger/master
Fix connection leak on conn.ssl or conn.startup failure
2 parents 27ea5d9 + 5253e15 commit 61fe37a

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

conn.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,15 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
339339
if err != nil {
340340
return nil, err
341341
}
342+
343+
// cn.ssl and cn.startup panic on error. Make sure we don't leak cn.c.
344+
panicking := true
345+
defer func() {
346+
if panicking {
347+
cn.c.Close()
348+
}
349+
}()
350+
342351
cn.ssl(o)
343352
cn.buf = bufio.NewReader(cn.c)
344353
cn.startup(o)
@@ -347,6 +356,7 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
347356
if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
348357
err = cn.c.SetDeadline(time.Time{})
349358
}
359+
panicking = false
350360
return cn, err
351361
}
352362

conn_test.go

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func forceBinaryParameters() bool {
2828
}
2929
}
3030

31-
func openTestConnConninfo(conninfo string) (*sql.DB, error) {
31+
func testConninfo(conninfo string) string {
3232
defaultTo := func(envvar string, value string) {
3333
if os.Getenv(envvar) == "" {
3434
os.Setenv(envvar, value)
@@ -43,8 +43,11 @@ func openTestConnConninfo(conninfo string) (*sql.DB, error) {
4343
!strings.HasPrefix(conninfo, "postgresql://") {
4444
conninfo = conninfo + " binary_parameters=yes"
4545
}
46+
return conninfo
47+
}
4648

47-
return sql.Open("postgres", conninfo)
49+
func openTestConnConninfo(conninfo string) (*sql.DB, error) {
50+
return sql.Open("postgres", testConninfo(conninfo))
4851
}
4952

5053
func openTestConn(t Fatalistic) *sql.DB {
@@ -637,6 +640,57 @@ func TestErrorDuringStartup(t *testing.T) {
637640
}
638641
}
639642

643+
type testConn struct {
644+
closed bool
645+
net.Conn
646+
}
647+
648+
func (c *testConn) Close() error {
649+
c.closed = true
650+
return c.Conn.Close()
651+
}
652+
653+
type testDialer struct {
654+
conns []*testConn
655+
}
656+
657+
func (d *testDialer) Dial(ntw, addr string) (net.Conn, error) {
658+
c, err := net.Dial(ntw, addr)
659+
if err != nil {
660+
return nil, err
661+
}
662+
tc := &testConn{Conn: c}
663+
d.conns = append(d.conns, tc)
664+
return tc, nil
665+
}
666+
667+
func (d *testDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
668+
c, err := net.DialTimeout(ntw, addr, timeout)
669+
if err != nil {
670+
return nil, err
671+
}
672+
tc := &testConn{Conn: c}
673+
d.conns = append(d.conns, tc)
674+
return tc, nil
675+
}
676+
677+
func TestErrorDuringStartupClosesConn(t *testing.T) {
678+
// Don't use the normal connection setup, this is intended to
679+
// blow up in the startup packet from a non-existent user.
680+
var d testDialer
681+
c, err := DialOpen(&d, testConninfo("user=thisuserreallydoesntexist"))
682+
if err == nil {
683+
c.Close()
684+
t.Fatal("expected dial error")
685+
}
686+
if len(d.conns) != 1 {
687+
t.Fatalf("got len(d.conns) = %d, want = %d", len(d.conns), 1)
688+
}
689+
if !d.conns[0].closed {
690+
t.Error("connection leaked")
691+
}
692+
}
693+
640694
func TestBadConn(t *testing.T) {
641695
var err error
642696

0 commit comments

Comments
 (0)