@@ -28,7 +28,7 @@ func forceBinaryParameters() bool {
28
28
}
29
29
}
30
30
31
- func openTestConnConninfo (conninfo string ) ( * sql. DB , error ) {
31
+ func testConninfo (conninfo string ) string {
32
32
defaultTo := func (envvar string , value string ) {
33
33
if os .Getenv (envvar ) == "" {
34
34
os .Setenv (envvar , value )
@@ -43,8 +43,11 @@ func openTestConnConninfo(conninfo string) (*sql.DB, error) {
43
43
! strings .HasPrefix (conninfo , "postgresql://" ) {
44
44
conninfo = conninfo + " binary_parameters=yes"
45
45
}
46
+ return conninfo
47
+ }
46
48
47
- return sql .Open ("postgres" , conninfo )
49
+ func openTestConnConninfo (conninfo string ) (* sql.DB , error ) {
50
+ return sql .Open ("postgres" , testConninfo (conninfo ))
48
51
}
49
52
50
53
func openTestConn (t Fatalistic ) * sql.DB {
@@ -637,6 +640,57 @@ func TestErrorDuringStartup(t *testing.T) {
637
640
}
638
641
}
639
642
643
+ type testConn struct {
644
+ closed bool
645
+ net.Conn
646
+ }
647
+
648
+ func (c * testConn ) Close () error {
649
+ c .closed = true
650
+ return c .Conn .Close ()
651
+ }
652
+
653
+ type testDialer struct {
654
+ conns []* testConn
655
+ }
656
+
657
+ func (d * testDialer ) Dial (ntw , addr string ) (net.Conn , error ) {
658
+ c , err := net .Dial (ntw , addr )
659
+ if err != nil {
660
+ return nil , err
661
+ }
662
+ tc := & testConn {Conn : c }
663
+ d .conns = append (d .conns , tc )
664
+ return tc , nil
665
+ }
666
+
667
+ func (d * testDialer ) DialTimeout (ntw , addr string , timeout time.Duration ) (net.Conn , error ) {
668
+ c , err := net .DialTimeout (ntw , addr , timeout )
669
+ if err != nil {
670
+ return nil , err
671
+ }
672
+ tc := & testConn {Conn : c }
673
+ d .conns = append (d .conns , tc )
674
+ return tc , nil
675
+ }
676
+
677
+ func TestErrorDuringStartupClosesConn (t * testing.T ) {
678
+ // Don't use the normal connection setup, this is intended to
679
+ // blow up in the startup packet from a non-existent user.
680
+ var d testDialer
681
+ c , err := DialOpen (& d , testConninfo ("user=thisuserreallydoesntexist" ))
682
+ if err == nil {
683
+ c .Close ()
684
+ t .Fatal ("expected dial error" )
685
+ }
686
+ if len (d .conns ) != 1 {
687
+ t .Fatalf ("got len(d.conns) = %d, want = %d" , len (d .conns ), 1 )
688
+ }
689
+ if ! d .conns [0 ].closed {
690
+ t .Error ("connection leaked" )
691
+ }
692
+ }
693
+
640
694
func TestBadConn (t * testing.T ) {
641
695
var err error
642
696
0 commit comments