diff --git a/connection.go b/connection.go index 911be2060..f74235519 100644 --- a/connection.go +++ b/connection.go @@ -595,22 +595,21 @@ func (mc *mysqlConn) watchCancel(ctx context.Context) error { mc.cleanup() return nil } + // When ctx is already cancelled, don't watch it. + if err := ctx.Err(); err != nil { + return err + } + // When ctx is not cancellable, don't watch it. if ctx.Done() == nil { return nil } - - mc.watching = true - select { - default: - case <-ctx.Done(): - return ctx.Err() - } + // When watcher is not alive, can't watch it. if mc.watcher == nil { return nil } + mc.watching = true mc.watcher <- ctx - return nil } diff --git a/connection_test.go b/connection_test.go index dec376117..352c54ed7 100644 --- a/connection_test.go +++ b/connection_test.go @@ -9,6 +9,7 @@ package mysql import ( + "context" "database/sql/driver" "testing" ) @@ -79,3 +80,31 @@ func TestCheckNamedValue(t *testing.T) { t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value) } } + +// TestCleanCancel tests passed context is cancelled at start. +// No packet should be sent. Connection should keep current status. +func TestCleanCancel(t *testing.T) { + mc := &mysqlConn{ + closech: make(chan struct{}), + } + mc.startWatcher() + defer mc.cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + for i := 0; i < 3; i++ { // Repeat same behavior + err := mc.Ping(ctx) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %#v", err) + } + + if mc.closed.IsSet() { + t.Error("expected mc is not closed, closed actually") + } + + if mc.watching { + t.Error("expected watching is false, but true") + } + } +}