diff --git a/AUTHORS b/AUTHORS index 0a22d46e7..5f27aa457 100644 --- a/AUTHORS +++ b/AUTHORS @@ -14,6 +14,7 @@ Aaron Hopkins Arne Hormann Asta Xie +Bulat Gaifullin Carlos Nieto Chris Moos Daniel Nichter @@ -21,11 +22,13 @@ Daniël van Eeden Dave Protasowski DisposaBoy Egor Smolyakov +Evan Shaw Frederick Mayle Gustavo Kristic Hanno Braun Henri Yandell Hirotaka Yamamoto +ICHINOSE Shogo INADA Naoki Jacek Szwec James Harr @@ -45,6 +48,7 @@ Luke Scott Michael Woolnough Nicola Peduzzi Olivier Mengué +oscarzhao Paul Bonser Peter Schultz Rebecca Chin diff --git a/README.md b/README.md index 76a35d32e..d39b29c90 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support) * [time.Time support](#timetime-support) * [Unicode support](#unicode-support) + * [context.Context Support](#contextcontext-support) * [Testing / Development](#testing--development) * [License](#license) @@ -443,6 +444,9 @@ Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAM See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support. +## `context.Context` Support +Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. +See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. ## Testing / Development To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. diff --git a/benchmark_go18_test.go b/benchmark_go18_test.go new file mode 100644 index 000000000..d6a7e9d6e --- /dev/null +++ b/benchmark_go18_test.go @@ -0,0 +1,93 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.8 + +package mysql + +import ( + "context" + "database/sql" + "fmt" + "runtime" + "testing" +) + +func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0)) + + tb := (*TB)(b) + stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?")) + defer stmt.Close() + + b.SetParallelism(p) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + var got string + for pb.Next() { + tb.check(stmt.QueryRow(1).Scan(&got)) + if got != "one" { + b.Fatalf("query = %q; want one", got) + } + } + }) +} + +func BenchmarkQueryContext(b *testing.B) { + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + defer db.Close() + for _, p := range []int{1, 2, 3, 4} { + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + benchmarkQueryContext(b, db, p) + }) + } +} + +func benchmarkExecContext(b *testing.B, db *sql.DB, p int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0)) + + tb := (*TB)(b) + stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1")) + defer stmt.Close() + + b.SetParallelism(p) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := stmt.ExecContext(ctx); err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkExecContext(b *testing.B) { + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + defer db.Close() + for _, p := range []int{1, 2, 3, 4} { + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + benchmarkQueryContext(b, db, p) + }) + } +} diff --git a/connection.go b/connection.go index cdce3e30f..89a4f464a 100644 --- a/connection.go +++ b/connection.go @@ -14,9 +14,21 @@ import ( "net" "strconv" "strings" + "sync" + "sync/atomic" "time" ) +// a copy of context.Context for Go 1.7 and later. +type mysqlContext interface { + Done() <-chan struct{} + Err() error + + // They are defined in context.Context, but go-mysql-driver does not use them. + // Deadline() (deadline time.Time, ok bool) + // Value(key interface{}) interface{} +} + type mysqlConn struct { buf buffer netConn net.Conn @@ -31,6 +43,19 @@ type mysqlConn struct { sequence uint8 parseTime bool strict bool + + // for context support (From Go 1.8) + watching bool + watcher chan<- mysqlContext + closech chan struct{} + finished chan<- struct{} + + // set non-zero when conn is closed, before closech is closed. + // accessed atomically. + closed int32 + + mu sync.Mutex // guards following fields + canceledErr error // set non-nil if conn is canceled } // Handles parameters set in DSN after the connection is established @@ -64,7 +89,7 @@ func (mc *mysqlConn) handleParams() (err error) { } func (mc *mysqlConn) Begin() (driver.Tx, error) { - if mc.netConn == nil { + if mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -78,7 +103,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent - if mc.netConn != nil { + if !mc.isBroken() { err = mc.writeCommandPacket(comQuit) } @@ -92,19 +117,36 @@ func (mc *mysqlConn) Close() (err error) { // is called before auth or on auth failure because MySQL will have already // closed the network connection. func (mc *mysqlConn) cleanup() { + if atomic.SwapInt32(&mc.closed, 1) != 0 { + return + } + // Makes cleanup idempotent - if mc.netConn != nil { - if err := mc.netConn.Close(); err != nil { - errLog.Print(err) + close(mc.closech) + if mc.netConn == nil { + return + } + if err := mc.netConn.Close(); err != nil { + errLog.Print(err) + } +} + +func (mc *mysqlConn) isBroken() bool { + return atomic.LoadInt32(&mc.closed) != 0 +} + +func (mc *mysqlConn) error() error { + if mc.isBroken() { + if err := mc.canceled(); err != nil { + return err } - mc.netConn = nil + return ErrInvalidConn } - mc.cfg = nil - mc.buf.nc = nil + return nil } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { - if mc.netConn == nil { + if mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -258,7 +300,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { - if mc.netConn == nil { + if mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -315,7 +357,11 @@ func (mc *mysqlConn) exec(query string) error { } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { - if mc.netConn == nil { + return mc.query(query, args) +} + +func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { + if mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -387,3 +433,30 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { } return nil, err } + +// finish is called when the query has canceled. +func (mc *mysqlConn) cancel(err error) { + mc.mu.Lock() + mc.canceledErr = err + mc.mu.Unlock() + mc.cleanup() +} + +// canceled returns non-nil if the connection was closed due to context cancelation. +func (mc *mysqlConn) canceled() error { + mc.mu.Lock() + defer mc.mu.Unlock() + return mc.canceledErr +} + +// finish is called when the query has succeeded. +func (mc *mysqlConn) finish() { + if !mc.watching || mc.finished == nil { + return + } + select { + case mc.finished <- struct{}{}: + mc.watching = false + case <-mc.closech: + } +} diff --git a/connection_go18.go b/connection_go18.go new file mode 100644 index 000000000..384603a9e --- /dev/null +++ b/connection_go18.go @@ -0,0 +1,207 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.8 + +package mysql + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" +) + +// Ping implements driver.Pinger interface +func (mc *mysqlConn) Ping(ctx context.Context) error { + if mc.isBroken() { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + if err := mc.watchCancel(ctx); err != nil { + return err + } + defer mc.finish() + + if err := mc.writeCommandPacket(comPing); err != nil { + return err + } + if _, err := mc.readResultOK(); err != nil { + return err + } + + return nil +} + +// BeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { + // TODO: support isolation levels + return nil, errors.New("mysql: isolation levels not supported") + } + if opts.ReadOnly { + // TODO: support read-only transactions + return nil, errors.New("mysql: read-only transactions not supported") + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + tx, err := mc.Begin() + mc.finish() + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + tx.Rollback() + return nil, ctx.Err() + } + return tx, err +} + +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := mc.query(query, dargs) + if err != nil { + mc.finish() + return nil, err + } + rows.finish = mc.finish + return rows, err +} + +func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + return mc.Exec(query, dargs) +} + +func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + stmt, err := mc.Prepare(query) + mc.finish() + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + stmt.Close() + return nil, ctx.Err() + } + return stmt, nil +} + +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := stmt.query(dargs) + if err != nil { + stmt.mc.finish() + return nil, err + } + rows.finish = stmt.mc.finish + return rows, err +} + +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + defer stmt.mc.finish() + + return stmt.Exec(dargs) +} + +func (mc *mysqlConn) watchCancel(ctx context.Context) error { + if mc.watching { + // Reach here if canceled, + // so the connection is already invalid + mc.cleanup() + return nil + } + if ctx.Done() == nil { + return nil + } + + mc.watching = true + select { + default: + case <-ctx.Done(): + return ctx.Err() + } + if mc.watcher == nil { + return nil + } + + mc.watcher <- ctx + + return nil +} + +func (mc *mysqlConn) startWatcher() { + watcher := make(chan mysqlContext, 1) + mc.watcher = watcher + finished := make(chan struct{}) + mc.finished = finished + go func() { + for { + var ctx mysqlContext + select { + case ctx = <-watcher: + case <-mc.closech: + return + } + + select { + case <-ctx.Done(): + mc.cancel(ctx.Err()) + case <-finished: + case <-mc.closech: + return + } + } + }() +} diff --git a/driver.go b/driver.go index e51d98a3c..f11e14462 100644 --- a/driver.go +++ b/driver.go @@ -22,6 +22,11 @@ import ( "net" ) +// watcher interface is used for context support (From Go 1.8) +type watcher interface { + startWatcher() +} + // MySQLDriver is exported to make the driver directly accessible. // In general the driver is used via the database/sql package. type MySQLDriver struct{} @@ -52,6 +57,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc := &mysqlConn{ maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), } mc.cfg, err = ParseDSN(dsn) if err != nil { @@ -60,6 +66,11 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.parseTime = mc.cfg.ParseTime mc.strict = mc.cfg.Strict + // Call startWatcher for context support (From Go 1.8) + if s, ok := interface{}(mc).(watcher); ok { + s.startWatcher() + } + // Connect to Server if dial, ok := dials[mc.cfg.Net]; ok { mc.netConn, err = dial(mc.cfg.Addr) diff --git a/driver_go18_test.go b/driver_go18_test.go index 5a5fa10ff..69d0a2b7d 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -11,10 +11,28 @@ package mysql import ( + "context" "database/sql" + "database/sql/driver" "fmt" "reflect" "testing" + "time" +) + +// static interface implementation checks of mysqlConn +var ( + _ driver.ConnBeginTx = &mysqlConn{} + _ driver.ConnPrepareContext = &mysqlConn{} + _ driver.ExecerContext = &mysqlConn{} + _ driver.Pinger = &mysqlConn{} + _ driver.QueryerContext = &mysqlConn{} +) + +// static interface implementation checks of mysqlStmt +var ( + _ driver.StmtExecContext = &mysqlStmt{} + _ driver.StmtQueryContext = &mysqlStmt{} ) func TestMultiResultSet(t *testing.T) { @@ -196,3 +214,257 @@ func TestSkipResults(t *testing.T) { } }) } + +func TestPingContext(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := dbt.db.PingContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelExec(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query has done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Errorf("expected val to be 1, got %d", v) + } + + // Context is already canceled, so error should come before execution. + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil { + dbt.Error("expected error") + } else if err.Error() != "context canceled" { + dbt.Fatalf("unexpected error: %s", err) + } + + // The second insert query will fail, so the table has no changes. + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Errorf("expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query has done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Errorf("expected val to be 1, got %d", v) + } + + // Context is already canceled, so error should come before execution. + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + + // The second insert query will fail, so the table has no changes. + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Errorf("expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelQueryRow(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + dbt.mustExec("INSERT INTO test VALUES (1), (2), (3)") + ctx, cancel := context.WithCancel(context.Background()) + + rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM test") + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + // the first row will be succeed. + var v int + if !rows.Next() { + dbt.Fatalf("unexpected end") + } + if err := rows.Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + + cancel() + // make sure the driver recieve cancel request. + time.Sleep(100 * time.Millisecond) + + if rows.Next() { + dbt.Errorf("expected end, but not") + } + if err := rows.Err(); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelPrepare(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := dbt.db.PrepareContext(ctx, "SELECT 1"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelStmtExec(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + if err != nil { + dbt.Fatalf("unexpected error: %v", err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := stmt.ExecContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query has done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Errorf("expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelStmtQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + if err != nil { + dbt.Fatalf("unexpected error: %v", err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := stmt.QueryContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query has done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Errorf("expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelBegin(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + tx, err := dbt.db.BeginTx(ctx, nil) + if err != nil { + dbt.Fatal(err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Transaction is canceled, so expect an error. + switch err := tx.Commit(); err { + case sql.ErrTxDone: + // because the transaction has already been rollbacked. + // the database/sql package watches ctx + // and rollbacks when ctx is canceled. + case context.Canceled: + // the database/sql package rollbacks on another goroutine, + // so the transaction may not be rollbacked depending on goroutine scheduling. + default: + dbt.Errorf("expected sql.ErrTxDone or context.Canceled, got %v", err) + } + + // Context is canceled, so cannot begin a transaction. + if _, err := dbt.db.BeginTx(ctx, nil); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} diff --git a/driver_test.go b/driver_test.go index 6ca5434a9..206e07cc9 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1991,3 +1991,11 @@ func TestRejectReadOnly(t *testing.T) { dbt.mustExec("DROP TABLE test") }) } + +func TestPing(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.fail("Ping", "Ping", err) + } + }) +} diff --git a/packets.go b/packets.go index 303405a17..0bc120c0e 100644 --- a/packets.go +++ b/packets.go @@ -30,6 +30,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // read packet header data, err := mc.buf.readNext(4) if err != nil { + if cerr := mc.canceled(); cerr != nil { + return nil, cerr + } errLog.Print(err) mc.Close() return nil, driver.ErrBadConn @@ -63,6 +66,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // read packet body [pktLen bytes] data, err = mc.buf.readNext(pktLen) if err != nil { + if cerr := mc.canceled(); cerr != nil { + return nil, cerr + } errLog.Print(err) mc.Close() return nil, driver.ErrBadConn @@ -125,8 +131,13 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handle error if err == nil { // n != len(data) + mc.cleanup() errLog.Print(ErrMalformPkt) } else { + if cerr := mc.canceled(); cerr != nil { + return cerr + } + mc.cleanup() errLog.Print(err) } return driver.ErrBadConn diff --git a/packets_test.go b/packets_test.go index b1d64f5c7..31c892d85 100644 --- a/packets_test.go +++ b/packets_test.go @@ -244,7 +244,8 @@ func TestReadPacketSplit(t *testing.T) { func TestReadPacketFail(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + buf: newBuffer(conn), + closech: make(chan struct{}), } // illegal empty (stand-alone) packet diff --git a/rows.go b/rows.go index 13905e216..c7f5ee26c 100644 --- a/rows.go +++ b/rows.go @@ -28,8 +28,9 @@ type resultSet struct { } type mysqlRows struct { - mc *mysqlConn - rs resultSet + mc *mysqlConn + rs resultSet + finish func() } type binaryRows struct { @@ -65,12 +66,17 @@ func (rows *mysqlRows) Columns() []string { } func (rows *mysqlRows) Close() (err error) { + if f := rows.finish; f != nil { + f() + rows.finish = nil + } + mc := rows.mc if mc == nil { return nil } - if mc.netConn == nil { - return ErrInvalidConn + if err := mc.error(); err != nil { + return err } // Remove unread packets from stream @@ -98,8 +104,8 @@ func (rows *mysqlRows) nextResultSet() (int, error) { if rows.mc == nil { return 0, io.EOF } - if rows.mc.netConn == nil { - return 0, ErrInvalidConn + if err := rows.mc.error(); err != nil { + return 0, err } // Remove unread packets from stream @@ -145,8 +151,8 @@ func (rows *binaryRows) NextResultSet() error { func (rows *binaryRows) Next(dest []driver.Value) error { if mc := rows.mc; mc != nil { - if mc.netConn == nil { - return ErrInvalidConn + if err := mc.error(); err != nil { + return err } // Fetch next row from stream @@ -167,8 +173,8 @@ func (rows *textRows) NextResultSet() (err error) { func (rows *textRows) Next(dest []driver.Value) error { if mc := rows.mc; mc != nil { - if mc.netConn == nil { - return ErrInvalidConn + if err := mc.error(); err != nil { + return err } // Fetch next row from stream diff --git a/statement.go b/statement.go index e5071276a..5855f943f 100644 --- a/statement.go +++ b/statement.go @@ -23,7 +23,7 @@ type mysqlStmt struct { } func (stmt *mysqlStmt) Close() error { - if stmt.mc == nil || stmt.mc.netConn == nil { + if stmt.mc == nil || stmt.mc.isBroken() { // driver.Stmt.Close can be called more than once, thus this function // has to be idempotent. // See also Issue #450 and golang/go#16019. @@ -45,7 +45,7 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { - if stmt.mc.netConn == nil { + if stmt.mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -89,7 +89,11 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { - if stmt.mc.netConn == nil { + return stmt.query(args) +} + +func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { + if stmt.mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } diff --git a/transaction.go b/transaction.go index 33c749b35..5d88c0399 100644 --- a/transaction.go +++ b/transaction.go @@ -13,7 +13,7 @@ type mysqlTx struct { } func (tx *mysqlTx) Commit() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.isBroken() { return ErrInvalidConn } err = tx.mc.exec("COMMIT") @@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { } func (tx *mysqlTx) Rollback() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.isBroken() { return ErrInvalidConn } err = tx.mc.exec("ROLLBACK") diff --git a/utils_go18.go b/utils_go18.go index 2aa9d0f18..eaeac4f84 100644 --- a/utils_go18.go +++ b/utils_go18.go @@ -10,8 +10,24 @@ package mysql -import "crypto/tls" +import ( + "crypto/tls" + "database/sql/driver" + "errors" +) func cloneTLSConfig(c *tls.Config) *tls.Config { return c.Clone() } + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + // TODO: support the use of Named Parameters #561 + return nil, errors.New("mysql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +}