From 6ecc178886f367285eb2ae73b5d0719bd4067d24 Mon Sep 17 00:00:00 2001 From: marselester Date: Tue, 29 Sep 2020 22:39:36 -0400 Subject: [PATCH 1/4] Add keepalives support --- conn.go | 2 ++ connector.go | 39 ++++++++++++++++++++++++-- connector_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++++ doc.go | 6 ++++ 4 files changed, 115 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index f313c149..000cad90 100644 --- a/conn.go +++ b/conn.go @@ -1068,6 +1068,8 @@ func isDriverSetting(key string) bool { return true case "fallback_application_name": return true + case "keepalives", "keepalives_interval": + return true case "connect_timeout": return true case "disable_prepared_binary_result": diff --git a/connector.go b/connector.go index d7d47261..ecfcc09a 100644 --- a/connector.go +++ b/connector.go @@ -5,8 +5,11 @@ import ( "database/sql/driver" "errors" "fmt" + "net" "os" + "strconv" "strings" + "time" ) // Connector represents a fixed configuration for the pq driver with a given @@ -107,9 +110,41 @@ func NewConnector(dsn string) (*Connector, error) { } // SSL is not necessary or supported over UNIX domain sockets - if network, _ := network(o); network == "unix" { + ntw, _ := network(o) + if ntw == "unix" { o["sslmode"] = "disable" } - return &Connector{opts: o, dialer: defaultDialer{}}, nil + var d net.Dialer + if ntw == "tcp" { + d.KeepAlive, err = keepalive(o) + if err != nil { + return nil, err + } + } + + return &Connector{opts: o, dialer: defaultDialer{d}}, nil +} + +// keepalive returns the interval between keep-alive probes controlled by keepalives_interval. +// If zero, keep-alive probes are sent with a default value (see net.Dialer). +// If negative, keep-alive probes are disabled. +// +// The keepalives parameter controls whether client-side TCP keepalives are used. +// The default value is 1, meaning on, but you can change this to 0, meaning off, if keepalives are not wanted. +func keepalive(o values) (time.Duration, error) { + v, ok := o["keepalives"] + if ok && v == "0" { + return -1, nil + } + + if v, ok = o["keepalives_interval"]; !ok { + return 0, nil + } + + keepintvl, err := strconv.ParseInt(v, 10, 0) + if err != nil { + return 0, fmt.Errorf("invalid value for parameter keepalives_interval: %w", err) + } + return time.Duration(keepintvl) * time.Second, nil } diff --git a/connector_test.go b/connector_test.go index 3d2c67b0..054a06ce 100644 --- a/connector_test.go +++ b/connector_test.go @@ -6,7 +6,10 @@ import ( "context" "database/sql" "database/sql/driver" + "errors" + "strconv" "testing" + "time" ) func TestNewConnector_WorksWithOpenDB(t *testing.T) { @@ -65,3 +68,70 @@ func TestNewConnector_Driver(t *testing.T) { } txn.Rollback() } + +func TestNewConnectorKeepalive(t *testing.T) { + c, err := NewConnector("keepalives=1 keepalives_interval=10") + if err != nil { + t.Fatal(err) + } + db := sql.OpenDB(c) + defer db.Close() + // database/sql might not call our Open at all unless we do something with + // the connection + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + txn.Rollback() + + d, _ := c.dialer.(defaultDialer) + want := 10 * time.Second + if want != d.d.KeepAlive { + t.Fatalf("expected: %v, got: %v", want, d.d.KeepAlive) + } +} + +func TestKeepalive(t *testing.T) { + var tt = map[string]struct { + input values + want time.Duration + }{ + "keepalives on": {values{"keepalives": "1"}, 0}, + "keepalives on by default": {nil, 0}, + "keepalives off": {values{"keepalives": "0"}, -1}, + "keepalives_interval 5 seconds": {values{"keepalives_interval": "5"}, 5 * time.Second}, + "keepalives_interval default": {values{"keepalives_interval": "0"}, 0}, + "keepalives_interval off": {values{"keepalives_interval": "-1"}, -1 * time.Second}, + } + + for name, tc := range tt { + t.Run(name, func(t *testing.T) { + got, err := keepalive(tc.input) + if err != nil { + t.Fatal(err) + } + if tc.want != got { + t.Fatalf("expected: %v, got: %v", tc.want, got) + } + }) + } +} + +func TestKeepaliveError(t *testing.T) { + var tt = map[string]struct { + input values + want error + }{ + "keepalives_interval whitespace": {values{"keepalives_interval": " "}, strconv.ErrSyntax}, + "keepalives_interval float": {values{"keepalives_interval": "1.1"}, strconv.ErrSyntax}, + } + + for name, tc := range tt { + t.Run(name, func(t *testing.T) { + _, err := keepalive(tc.input) + if !errors.Is(err, tc.want) { + t.Fatalf("expected: %v, got: %v", tc.want, err) + } + }) + } +} diff --git a/doc.go b/doc.go index b5718480..246d5ffd 100644 --- a/doc.go +++ b/doc.go @@ -51,6 +51,12 @@ supported: * sslmode - Whether or not to use SSL (default is require, this is not the default for libpq) * fallback_application_name - An application_name to fall back to if one isn't provided. + * keepalives - Whether or not to use client-side TCP keepalives + (the default value is 1, meaning on, but you can change this to 0, meaning off) + * keepalives_interval - The number of seconds after which a TCP keepalive message + that is not acknowledged by the server should be retransmitted. + If zero or not specified, keep-alive probes are sent with a default value (see net.Dialer). + If negative, keep-alive probes are disabled. * connect_timeout - Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. * sslcert - Cert file location. The file must contain PEM encoded data. From d9d4371dd431493406ad7add868301b802a95d7d Mon Sep 17 00:00:00 2001 From: marselester Date: Wed, 30 Sep 2020 00:12:33 -0400 Subject: [PATCH 2/4] Rewrite tests for Go 1.13 compatibility --- connector.go | 2 +- connector_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/connector.go b/connector.go index ecfcc09a..ed958351 100644 --- a/connector.go +++ b/connector.go @@ -144,7 +144,7 @@ func keepalive(o values) (time.Duration, error) { keepintvl, err := strconv.ParseInt(v, 10, 0) if err != nil { - return 0, fmt.Errorf("invalid value for parameter keepalives_interval: %w", err) + return 0, fmt.Errorf("invalid value for parameter keepalives_interval: %v", err) } return time.Duration(keepintvl) * time.Second, nil } diff --git a/connector_test.go b/connector_test.go index 054a06ce..9067f720 100644 --- a/connector_test.go +++ b/connector_test.go @@ -6,8 +6,8 @@ import ( "context" "database/sql" "database/sql/driver" - "errors" "strconv" + "strings" "testing" "time" ) @@ -129,7 +129,7 @@ func TestKeepaliveError(t *testing.T) { for name, tc := range tt { t.Run(name, func(t *testing.T) { _, err := keepalive(tc.input) - if !errors.Is(err, tc.want) { + if !strings.HasSuffix(err.Error(), tc.want.Error()) { t.Fatalf("expected: %v, got: %v", tc.want, err) } }) From 51502e0e0b6a7fc8467f8a9d336c992bbb6d3989 Mon Sep 17 00:00:00 2001 From: marselester Date: Wed, 30 Sep 2020 19:46:44 -0400 Subject: [PATCH 3/4] Add OpenConnector to Driver to satisfy DriverContext interface sql.Open will use driver's OpenConnector instead of its Open method. This will enable TCP keepalive support. Note, testDriver was added for TestRuntimeParameters since it expects old fashioned Driver interface (without OpenConnector). --- conn.go | 6 ++++++ conn_test.go | 13 ++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index 000cad90..63665156 100644 --- a/conn.go +++ b/conn.go @@ -48,6 +48,12 @@ func (d *Driver) Open(name string) (driver.Conn, error) { return Open(name) } +// OpenConnector parses the name in the same format that Driver.Open +// parses the name parameter. +func (d *Driver) OpenConnector(name string) (driver.Connector, error) { + return NewConnector(name) +} + func init() { sql.Register("postgres", &Driver{}) } diff --git a/conn_test.go b/conn_test.go index 0d25c955..a6b6a3e5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -14,6 +14,17 @@ import ( "time" ) +// testDriver is the Postgres database driver that doesn't implement DriverContext interface. +type testDriver struct{} + +func (d *testDriver) Open(name string) (driver.Conn, error) { + return Open(name) +} + +func init() { + sql.Register("postgres-test", &testDriver{}) +} + type Fatalistic interface { Fatal(args ...interface{}) } @@ -48,7 +59,7 @@ func testConninfo(conninfo string) string { } func openTestConnConninfo(conninfo string) (*sql.DB, error) { - return sql.Open("postgres", testConninfo(conninfo)) + return sql.Open("postgres-test", testConninfo(conninfo)) } func openTestConn(t Fatalistic) *sql.DB { From 5f9a0c9107af01eeca5f19b957065570c6333bf9 Mon Sep 17 00:00:00 2001 From: marselester Date: Thu, 1 Oct 2020 11:54:48 -0400 Subject: [PATCH 4/4] Update Open func to support keepalive Got rid of OpenConnector for simplicity's sake. Noticed "pq: current transaction is aborted" in go18_test.go:212, but it doesn't seem to be related. The same issue popped up in https://github.com/lib/pq/pull/921. --- conn.go | 12 +++++------- conn_test.go | 53 ++++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/conn.go b/conn.go index 63665156..571de61d 100644 --- a/conn.go +++ b/conn.go @@ -48,12 +48,6 @@ func (d *Driver) Open(name string) (driver.Conn, error) { return Open(name) } -// OpenConnector parses the name in the same format that Driver.Open -// parses the name parameter. -func (d *Driver) OpenConnector(name string) (driver.Connector, error) { - return NewConnector(name) -} - func init() { sql.Register("postgres", &Driver{}) } @@ -278,7 +272,11 @@ func (cn *conn) writeBuf(b byte) *writeBuf { // Most users should only use it through database/sql package from the standard // library. func Open(dsn string) (_ driver.Conn, err error) { - return DialOpen(defaultDialer{}, dsn) + c, err := NewConnector(dsn) + if err != nil { + return nil, err + } + return c.open(context.Background()) } // DialOpen opens a new connection to the database using a dialer. diff --git a/conn_test.go b/conn_test.go index a6b6a3e5..82b18d12 100644 --- a/conn_test.go +++ b/conn_test.go @@ -14,17 +14,6 @@ import ( "time" ) -// testDriver is the Postgres database driver that doesn't implement DriverContext interface. -type testDriver struct{} - -func (d *testDriver) Open(name string) (driver.Conn, error) { - return Open(name) -} - -func init() { - sql.Register("postgres-test", &testDriver{}) -} - type Fatalistic interface { Fatal(args ...interface{}) } @@ -59,7 +48,7 @@ func testConninfo(conninfo string) string { } func openTestConnConninfo(conninfo string) (*sql.DB, error) { - return sql.Open("postgres-test", testConninfo(conninfo)) + return sql.Open("postgres", testConninfo(conninfo)) } func openTestConn(t Fatalistic) *sql.DB { @@ -151,6 +140,46 @@ func TestOpenURL(t *testing.T) { testURL("postgresql://") } +func TestOpen(t *testing.T) { + dsn := "keepalives_interval=10" + c, err := Open(dsn) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + d := c.(*conn).dialer.(defaultDialer) + want := 10 * time.Second + if want != d.d.KeepAlive { + t.Fatalf("expected: %v, got: %v", want, d.d.KeepAlive) + } +} + +func TestSQLOpen(t *testing.T) { + dsn := "keepalives_interval=10" + db, err := sql.Open("postgres", dsn) + if err != nil { + t.Fatal(err) + } + defer db.Close() + if err = db.Ping(); err != nil { + t.Fatal(err) + } + + drv := db.Driver() + c, err := drv.Open(dsn) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + d := c.(*conn).dialer.(defaultDialer) + want := 10 * time.Second + if want != d.d.KeepAlive { + t.Fatalf("expected: %v, got: %v", want, d.d.KeepAlive) + } +} + const pgpassFile = "/tmp/pqgotest_pgpass" func TestPgpass(t *testing.T) {