diff --git a/AUTHORS b/AUTHORS index 4702c83ab..61ba24d4a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -13,6 +13,7 @@ Aaron Hopkins Achille Roussel +Andrew Reid Arne Hormann Asta Xie Bulat Gaifullin diff --git a/driver_test.go b/driver_test.go index 7877aa979..2a9819381 100644 --- a/driver_test.go +++ b/driver_test.go @@ -547,6 +547,7 @@ func TestValuerWithValidation(t *testing.T) { var out string var rows *sql.Rows + dbt.mustExec("DROP TABLE IF EXISTS testValuer") dbt.mustExec("CREATE TABLE testValuer (value VARCHAR(255)) CHARACTER SET utf8") dbt.mustExec("INSERT INTO testValuer VALUES (?)", in) @@ -570,6 +571,10 @@ func TestValuerWithValidation(t *testing.T) { dbt.Errorf("Failed to check nil") } + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", (*testValuerWithValidation)(nil)); err != nil { + dbt.Errorf("Failed to check typed nil") + } + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", map[string]bool{}); err == nil { dbt.Errorf("Failed to check not valuer") } diff --git a/statement.go b/statement.go index 98e57bcd8..435c57e90 100644 --- a/statement.go +++ b/statement.go @@ -10,7 +10,6 @@ package mysql import ( "database/sql/driver" - "fmt" "io" "reflect" "strconv" @@ -132,47 +131,33 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { type converter struct{} +// ConvertValue differs from defaultConverter.ConverValue for uint64 with the high bit set only +// all other conversion requests return driver.ErrSkip to defer to the default converter func (c converter) ConvertValue(v interface{}) (driver.Value, error) { if driver.IsValue(v) { return v, nil } - if v != nil { - if valuer, ok := v.(driver.Valuer); ok { - return valuer.Value() - } + // even when uint64 is the underlying type, a custom Valuer should take precedence + if _, ok := v.(driver.Valuer); ok { + return v, driver.ErrSkip } rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr: - // indirect pointers if rv.IsNil() { return nil, nil } + // recursively handle *uint64, **uint64 etc return c.ConvertValue(rv.Elem().Interface()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return rv.Int(), nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: - return int64(rv.Uint()), nil case reflect.Uint64: u64 := rv.Uint() if u64 >= 1<<63 { + // The defaultConverter errors in this case - we convert to a string return strconv.FormatUint(u64, 10), nil } - return int64(u64), nil - case reflect.Float32, reflect.Float64: - return rv.Float(), nil - case reflect.Bool: - return rv.Bool(), nil - case reflect.Slice: - ek := rv.Type().Elem().Kind() - if ek == reflect.Uint8 { - return rv.Bytes(), nil - } - return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) - case reflect.String: - return rv.String(), nil } - return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) + + return v, driver.ErrSkip } diff --git a/statement_test.go b/statement_test.go index 98a6c1933..6dbfcee29 100644 --- a/statement_test.go +++ b/statement_test.go @@ -9,118 +9,116 @@ package mysql import ( - "bytes" + "database/sql/driver" "testing" + "time" ) -func TestConvertDerivedString(t *testing.T) { - type derived string +func TestValueThatIsValue(t *testing.T) { + now := time.Now() + inputs := []interface{}{nil, float64(1.0), int64(17), "ABC", now} - output, err := converter{}.ConvertValue(derived("value")) - if err != nil { - t.Fatal("Derived string type not convertible", err) - } - - if output != "value" { - t.Fatalf("Derived string type not converted, got %#v %T", output, output) + for _, in := range inputs { + out, err := converter{}.ConvertValue(in) + if err != nil { + t.Fatalf("Value %#v %T not needing conversion caused error: %s", in, in, err) + } + if out != in { + t.Fatalf("Value %#v %T altered in conversion got %#v %T", in, in, out, out) + } } } -func TestConvertDerivedByteSlice(t *testing.T) { - type derived []uint8 +func TestValueThatIsPtrToValue(t *testing.T) { + w := "ABC" + x := &w + y := &x + inputs := []interface{}{x, y} - output, err := converter{}.ConvertValue(derived("value")) - if err != nil { - t.Fatal("Byte slice not convertible", err) - } - - if bytes.Compare(output.([]byte), []byte("value")) != 0 { - t.Fatalf("Byte slice not converted, got %#v %T", output, output) + for _, in := range inputs { + out, err := converter{}.ConvertValue(in) + if err != nil { + t.Fatalf("Pointer %#v %T to value not needing conversion caused error: %s", in, in, err) + } + if out != w { + t.Fatalf("Value %#v %T not resolved to string in conversion (got %#v %T)", in, in, out, out) + } } } -func TestConvertDerivedUnsupportedSlice(t *testing.T) { - type derived []int +func TestValueThatIsTypedPtrToNil(t *testing.T) { + var w *string + x := &w + y := &x + inputs := []interface{}{x, y} - _, err := converter{}.ConvertValue(derived{1}) - if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" { - t.Fatal("Unexpected error", err) + for _, in := range inputs { + out, err := converter{}.ConvertValue(in) + if err != nil { + t.Fatalf("Pointer %#v %T to nil value caused error: %s", in, in, err) + } + if out != nil { + t.Fatalf("Pointer to nil did not Value as nil") + } } } -func TestConvertDerivedBool(t *testing.T) { - type derived bool +type implementsValuer uint64 - output, err := converter{}.ConvertValue(derived(true)) - if err != nil { - t.Fatal("Derived bool type not convertible", err) - } - - if output != true { - t.Fatalf("Derived bool type not converted, got %#v %T", output, output) - } +func (me implementsValuer) Value() (driver.Value, error) { + return string(me), nil } - -func TestConvertPointer(t *testing.T) { - str := "value" - - output, err := converter{}.ConvertValue(&str) - if err != nil { - t.Fatal("Pointer type not convertible", err) - } - - if output != "value" { - t.Fatalf("Pointer type not converted, got %#v %T", output, output) +func TestTypesThatImplementValuerAreSkipped(t *testing.T) { + // Have to test on a uint64 with high bit set - as we skip everything else anyhow + x := implementsValuer(^uint64(0)) + y := &x + z := &y + var a *implementsValuer + b := &a + c := &b + inputs := []interface{}{x, y, z, a, b, c} + + for _, in := range inputs { + _, err := converter{}.ConvertValue(in) + if err != driver.ErrSkip { + t.Fatalf("Conversion of Valuer implementing type %T not skipped", in) + } } } -func TestConvertSignedIntegers(t *testing.T) { - values := []interface{}{ - int8(-42), - int16(-42), - int32(-42), - int64(-42), - int(-42), - } - - for _, value := range values { - output, err := converter{}.ConvertValue(value) - if err != nil { - t.Fatalf("%T type not convertible %s", value, err) - } - - if output != int64(-42) { - t.Fatalf("%T type not converted, got %#v %T", value, output, output) +func TestTypesThatAreNotValuesAreSkipped(t *testing.T) { + type derived1 string // convertable + type derived2 []uint8 // convertable + type derived3 []int // not convertable + type derived4 uint64 // without the high bit set + inputs := []interface{}{derived1("ABC"), derived2([]uint8{'A', 'B'}), derived3([]int{17, 32}), derived3(nil), derived4(26)} + + for _, in := range inputs { + _, err := converter{}.ConvertValue(in) + if err != driver.ErrSkip { + t.Fatalf("Conversion of non-value value %#v %T not skipped", in, in) } } } -func TestConvertUnsignedIntegers(t *testing.T) { - values := []interface{}{ - uint8(42), - uint16(42), - uint32(42), - uint64(42), - uint(42), - } +func TestConvertLargeUnsignedIntegers(t *testing.T) { + type derived uint64 + type derived2 *uint64 + v := ^uint64(0) + w := &v + x := derived(v) + y := &x + z := derived2(w) - for _, value := range values { - output, err := converter{}.ConvertValue(value) + inputs := []interface{}{v, w, x, y, z} + + for _, in := range inputs { + out, err := converter{}.ConvertValue(in) if err != nil { - t.Fatalf("%T type not convertible %s", value, err) + t.Fatalf("uint64 high-bit not convertible for type %T", in) } - - if output != int64(42) { - t.Fatalf("%T type not converted, got %#v %T", value, output, output) + if out != "18446744073709551615" { + t.Fatalf("uint64 high-bit not converted, got %#v %T", out, out) } } - - output, err := converter{}.ConvertValue(^uint64(0)) - if err != nil { - t.Fatal("uint64 high-bit not convertible", err) - } - - if output != "18446744073709551615" { - t.Fatalf("uint64 high-bit not converted, got %#v %T", output, output) - } }