diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f541e866..3f639ad2b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Changes: - Go-MySQL-Driver now requires Go 1.1 - Connections now use the collation `utf8_general_ci` by default. Adding `&charset=UTF8` to the DSN should not be necessary anymore - Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors + - `byte(nil)` is now treated as a NULL value. Before it was treated like an empty string / `[]byte("")`. - New Logo - Changed the copyright header to include all contributors - Optimized the buffer for reading @@ -28,6 +29,7 @@ Bugfixes: - Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification - Convert to DB timezone when inserting time.Time - Splitted packets (more than 16MB) are now merged correctly + - Fixed empty string producing false nil values ## 1.0 (2013-05-14) diff --git a/driver_test.go b/driver_test.go index d4422075f..4c6c44710 100644 --- a/driver_test.go +++ b/driver_test.go @@ -108,143 +108,6 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) return rows } -func TestReuseClosedConnection(t *testing.T) { - // this test does not use sql.database, it uses the driver directly - if !available { - t.Skipf("MySQL-Server not running on %s", netAddr) - } - driver := &MySQLDriver{} - conn, err := driver.Open(dsn) - if err != nil { - t.Fatalf("Error connecting: %s", err.Error()) - } - stmt, err := conn.Prepare("DO 1") - if err != nil { - t.Fatalf("Error preparing statement: %s", err.Error()) - } - _, err = stmt.Exec(nil) - if err != nil { - t.Fatalf("Error executing statement: %s", err.Error()) - } - err = conn.Close() - if err != nil { - t.Fatalf("Error closing connection: %s", err.Error()) - } - defer func() { - if err := recover(); err != nil { - t.Errorf("Panic after reusing a closed connection: %v", err) - } - }() - _, err = stmt.Exec(nil) - if err != nil && err != errInvalidConn { - t.Errorf("Unexpected error '%s', expected '%s'", - err.Error(), errInvalidConn.Error()) - } -} - -func TestCharset(t *testing.T) { - if !available { - t.Skipf("MySQL-Server not running on %s", netAddr) - } - - mustSetCharset := func(charsetParam, expected string) { - runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) { - rows := dbt.mustQuery("SELECT @@character_set_connection") - defer rows.Close() - - if !rows.Next() { - dbt.Fatalf("Error getting connection charset: %s", rows.Err()) - } - - var got string - rows.Scan(&got) - - if got != expected { - dbt.Fatalf("Expected connection charset %s but got %s", expected, got) - } - }) - } - - // non utf8 test - mustSetCharset("charset=ascii", "ascii") - - // when the first charset is invalid, use the second - mustSetCharset("charset=none,utf8", "utf8") - - // when the first charset is valid, use it - mustSetCharset("charset=ascii,utf8", "ascii") - mustSetCharset("charset=utf8,ascii", "utf8") -} - -func TestFailingCharset(t *testing.T) { - runTests(t, dsn+"&charset=none", func(dbt *DBTest) { - // run query to really establish connection... - _, err := dbt.db.Exec("SELECT 1") - if err == nil { - dbt.db.Close() - t.Fatalf("Connection must not succeed without a valid charset") - } - }) -} - -func TestRawBytesResultExceedsBuffer(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - // defaultBufSize from buffer.go - expected := strings.Repeat("abc", defaultBufSize) - rows := dbt.mustQuery("SELECT '" + expected + "'") - defer rows.Close() - if !rows.Next() { - dbt.Error("expected result, got none") - } - var result sql.RawBytes - rows.Scan(&result) - if expected != string(result) { - dbt.Error("result did not match expected value") - } - }) -} - -func TestTimezoneConversion(t *testing.T) { - - zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} - - // Regression test for timezone handling - tzTest := func(dbt *DBTest) { - - // Create table - dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") - - // Insert local time into database (should be converted) - usCentral, _ := time.LoadLocation("US/Central") - now := time.Now().In(usCentral) - dbt.mustExec("INSERT INTO test VALUE (?)", now) - - // Retrieve time from DB - rows := dbt.mustQuery("SELECT ts FROM test") - if !rows.Next() { - dbt.Fatal("Didn't get any rows out") - } - - var nowDB time.Time - err := rows.Scan(&nowDB) - if err != nil { - dbt.Fatal("Err", err) - } - - // Check that dates match - if now.Unix() != nowDB.Unix() { - dbt.Errorf("Times don't match.\n") - dbt.Errorf(" Now(%v)=%v\n", usCentral, now) - dbt.Errorf(" Now(UTC)=%v\n", nowDB) - } - - } - - for _, tz := range zones { - runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest) - } -} - func TestCRUD(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { // Create Table @@ -548,44 +411,6 @@ func TestDateTime(t *testing.T) { } } -// This tests for https://github.com/go-sql-driver/mysql/pull/139 -// -// An extra (invisible) nil byte was being added to the beginning of positive -// time strings. -func TestTimeSign(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - var sTimes = []struct { - value string - fieldType string - }{ - {"12:34:56", "TIME"}, - {"-12:34:56", "TIME"}, - // As described in http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html - // they *should* work, but only in 5.6+. - // { "12:34:56.789", "TIME(3)" }, - // { "-12:34:56.789", "TIME(3)" }, - } - - for _, sTime := range sTimes { - dbt.db.Exec("DROP TABLE IF EXISTS test") - dbt.mustExec("CREATE TABLE test (id INT, time_field " + sTime.fieldType + ")") - dbt.mustExec("INSERT INTO test (id, time_field) VALUES(1, '" + sTime.value + "')") - rows := dbt.mustQuery("SELECT time_field FROM test WHERE id = ?", 1) - if rows.Next() { - var oTime string - rows.Scan(&oTime) - if oTime != sTime.value { - dbt.Errorf(`time values differ: got %q, expected %q.`, oTime, sTime.value) - } - } else { - dbt.Error("expecting at least one row.") - } - } - - }) - -} - func TestNULL(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { nullStmt, err := dbt.db.Prepare("SELECT NULL") @@ -603,16 +428,14 @@ func TestNULL(t *testing.T) { // NullBool var nb sql.NullBool // Invalid - err = nullStmt.QueryRow().Scan(&nb) - if err != nil { + if err = nullStmt.QueryRow().Scan(&nb); err != nil { dbt.Fatal(err) } if nb.Valid { dbt.Error("Valid NullBool which should be invalid") } // Valid - err = nonNullStmt.QueryRow().Scan(&nb) - if err != nil { + if err = nonNullStmt.QueryRow().Scan(&nb); err != nil { dbt.Fatal(err) } if !nb.Valid { @@ -624,16 +447,14 @@ func TestNULL(t *testing.T) { // NullFloat64 var nf sql.NullFloat64 // Invalid - err = nullStmt.QueryRow().Scan(&nf) - if err != nil { + if err = nullStmt.QueryRow().Scan(&nf); err != nil { dbt.Fatal(err) } if nf.Valid { dbt.Error("Valid NullFloat64 which should be invalid") } // Valid - err = nonNullStmt.QueryRow().Scan(&nf) - if err != nil { + if err = nonNullStmt.QueryRow().Scan(&nf); err != nil { dbt.Fatal(err) } if !nf.Valid { @@ -645,16 +466,14 @@ func TestNULL(t *testing.T) { // NullInt64 var ni sql.NullInt64 // Invalid - err = nullStmt.QueryRow().Scan(&ni) - if err != nil { + if err = nullStmt.QueryRow().Scan(&ni); err != nil { dbt.Fatal(err) } if ni.Valid { dbt.Error("Valid NullInt64 which should be invalid") } // Valid - err = nonNullStmt.QueryRow().Scan(&ni) - if err != nil { + if err = nonNullStmt.QueryRow().Scan(&ni); err != nil { dbt.Fatal(err) } if !ni.Valid { @@ -666,16 +485,14 @@ func TestNULL(t *testing.T) { // NullString var ns sql.NullString // Invalid - err = nullStmt.QueryRow().Scan(&ns) - if err != nil { + if err = nullStmt.QueryRow().Scan(&ns); err != nil { dbt.Fatal(err) } if ns.Valid { dbt.Error("Valid NullString which should be invalid") } // Valid - err = nonNullStmt.QueryRow().Scan(&ns) - if err != nil { + if err = nonNullStmt.QueryRow().Scan(&ns); err != nil { dbt.Fatal(err) } if !ns.Valid { @@ -684,6 +501,48 @@ func TestNULL(t *testing.T) { dbt.Error("Unexpected NullString value:" + ns.String + " (should be `1`)") } + // nil-bytes + var b []byte + // Read nil + if err = nullStmt.QueryRow().Scan(&b); err != nil { + dbt.Fatal(err) + } + if b != nil { + dbt.Error("Non-nil []byte wich should be nil") + } + // Read non-nil + if err = nonNullStmt.QueryRow().Scan(&b); err != nil { + dbt.Fatal(err) + } + if b == nil { + dbt.Error("Nil []byte wich should be non-nil") + } + // Insert nil + b = nil + success := false + if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil { + dbt.Fatal(err) + } + if !success { + dbt.Error("Inserting []byte(nil) as NULL failed") + } + // Check input==output with input==nil + b = nil + if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + dbt.Fatal(err) + } + if b != nil { + dbt.Error("Non-nil echo from nil input") + } + // Check input==output with input!=nil + b = []byte("") + if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + dbt.Fatal(err) + } + if b == nil { + dbt.Error("nil echo from non-nil input") + } + // Insert NULL dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)") @@ -995,6 +854,180 @@ func TestTLS(t *testing.T) { runTests(t, dsn+"&tls=custom-skip-verify", tlsTest) } +func TestReuseClosedConnection(t *testing.T) { + // this test does not use sql.database, it uses the driver directly + if !available { + t.Skipf("MySQL-Server not running on %s", netAddr) + } + + driver := &MySQLDriver{} + conn, err := driver.Open(dsn) + if err != nil { + t.Fatalf("Error connecting: %s", err.Error()) + } + stmt, err := conn.Prepare("DO 1") + if err != nil { + t.Fatalf("Error preparing statement: %s", err.Error()) + } + _, err = stmt.Exec(nil) + if err != nil { + t.Fatalf("Error executing statement: %s", err.Error()) + } + err = conn.Close() + if err != nil { + t.Fatalf("Error closing connection: %s", err.Error()) + } + + defer func() { + if err := recover(); err != nil { + t.Errorf("Panic after reusing a closed connection: %v", err) + } + }() + _, err = stmt.Exec(nil) + if err != nil && err != errInvalidConn { + t.Errorf("Unexpected error '%s', expected '%s'", + err.Error(), errInvalidConn.Error()) + } +} + +func TestCharset(t *testing.T) { + if !available { + t.Skipf("MySQL-Server not running on %s", netAddr) + } + + mustSetCharset := func(charsetParam, expected string) { + runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT @@character_set_connection") + defer rows.Close() + + if !rows.Next() { + dbt.Fatalf("Error getting connection charset: %s", rows.Err()) + } + + var got string + rows.Scan(&got) + + if got != expected { + dbt.Fatalf("Expected connection charset %s but got %s", expected, got) + } + }) + } + + // non utf8 test + mustSetCharset("charset=ascii", "ascii") + + // when the first charset is invalid, use the second + mustSetCharset("charset=none,utf8", "utf8") + + // when the first charset is valid, use it + mustSetCharset("charset=ascii,utf8", "ascii") + mustSetCharset("charset=utf8,ascii", "utf8") +} + +func TestFailingCharset(t *testing.T) { + runTests(t, dsn+"&charset=none", func(dbt *DBTest) { + // run query to really establish connection... + _, err := dbt.db.Exec("SELECT 1") + if err == nil { + dbt.db.Close() + t.Fatalf("Connection must not succeed without a valid charset") + } + }) +} + +func TestRawBytesResultExceedsBuffer(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // defaultBufSize from buffer.go + expected := strings.Repeat("abc", defaultBufSize) + + rows := dbt.mustQuery("SELECT '" + expected + "'") + defer rows.Close() + if !rows.Next() { + dbt.Error("expected result, got none") + } + var result sql.RawBytes + rows.Scan(&result) + if expected != string(result) { + dbt.Error("result did not match expected value") + } + }) +} + +func TestTimezoneConversion(t *testing.T) { + zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + + // Regression test for timezone handling + tzTest := func(dbt *DBTest) { + + // Create table + dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") + + // Insert local time into database (should be converted) + usCentral, _ := time.LoadLocation("US/Central") + now := time.Now().In(usCentral) + dbt.mustExec("INSERT INTO test VALUE (?)", now) + + // Retrieve time from DB + rows := dbt.mustQuery("SELECT ts FROM test") + if !rows.Next() { + dbt.Fatal("Didn't get any rows out") + } + + var nowDB time.Time + err := rows.Scan(&nowDB) + if err != nil { + dbt.Fatal("Err", err) + } + + // Check that dates match + if now.Unix() != nowDB.Unix() { + dbt.Errorf("Times don't match.\n") + dbt.Errorf(" Now(%v)=%v\n", usCentral, now) + dbt.Errorf(" Now(UTC)=%v\n", nowDB) + } + } + + for _, tz := range zones { + runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest) + } +} + +// This tests for https://github.com/go-sql-driver/mysql/pull/139 +// +// An extra (invisible) nil byte was being added to the beginning of positive +// time strings. +func TestTimeSign(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + var sTimes = []struct { + value string + fieldType string + }{ + {"12:34:56", "TIME"}, + {"-12:34:56", "TIME"}, + // As described in http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html + // they *should* work, but only in 5.6+. + // { "12:34:56.789", "TIME(3)" }, + // { "-12:34:56.789", "TIME(3)" }, + } + + for _, sTime := range sTimes { + dbt.db.Exec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE TABLE test (id INT, time_field " + sTime.fieldType + ")") + dbt.mustExec("INSERT INTO test (id, time_field) VALUES(1, '" + sTime.value + "')") + rows := dbt.mustQuery("SELECT time_field FROM test WHERE id = ?", 1) + if rows.Next() { + var oTime string + rows.Scan(&oTime) + if oTime != sTime.value { + dbt.Errorf(`time values differ: got %q, expected %q.`, oTime, sTime.value) + } + } else { + dbt.Error("expecting at least one row.") + } + } + }) +} + // Special cases func TestRowsClose(t *testing.T) { diff --git a/packets.go b/packets.go index aff0fad62..d2ad33d09 100644 --- a/packets.go +++ b/packets.go @@ -915,20 +915,29 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case []byte: - paramTypes[i+i] = fieldTypeString - paramTypes[i+i+1] = 0x00 - - if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { - paramValues = appendLengthEncodedInteger(paramValues, - uint64(len(v)), - ) - paramValues = append(paramValues, v...) - } else { - if err := stmt.writeCommandLongData(i, v); err != nil { - return err + // Common case (non-nil value) first + if v != nil { + paramTypes[i+i] = fieldTypeString + paramTypes[i+i+1] = 0x00 + + if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(v)), + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, v); err != nil { + return err + } } + continue } + // Handle []byte(nil) as a NULL value + nullMask |= 1 << uint(i) + paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i+1] = 0x00 + case string: paramTypes[i+i] = fieldTypeString paramTypes[i+i+1] = 0x00 diff --git a/utils.go b/utils.go index a58855e49..916ebefe5 100644 --- a/utils.go +++ b/utils.go @@ -600,11 +600,14 @@ func stringToInt(b []byte) int { return val } +// returns the string read as a bytes slice, wheter the value is NULL, +// the number of bytes read and an error, in case the string is longer than +// the input slice func readLengthEnodedString(b []byte) ([]byte, bool, int, error) { // Get length num, isNull, n := readLengthEncodedInteger(b) if num < 1 { - return nil, isNull, n, nil + return b[n:n], isNull, n, nil } n += int(num) @@ -616,6 +619,8 @@ func readLengthEnodedString(b []byte) ([]byte, bool, int, error) { return nil, false, n, io.EOF } +// returns the number of bytes skipped and an error, in case the string is +// longer than the input slice func skipLengthEnodedString(b []byte) (int, error) { // Get length num, _, n := readLengthEncodedInteger(b) @@ -632,42 +637,35 @@ func skipLengthEnodedString(b []byte) (int, error) { return n, io.EOF } -func readLengthEncodedInteger(b []byte) (num uint64, isNull bool, n int) { +// returns the number read, whether the value is NULL and the number of bytes read +func readLengthEncodedInteger(b []byte) (uint64, bool, int) { switch b[0] { // 251: NULL case 0xfb: - n = 1 - isNull = true - return + return 0, true, 1 // 252: value of following 2 case 0xfc: - num = uint64(b[1]) | uint64(b[2])<<8 - n = 3 - return + return uint64(b[1]) | uint64(b[2])<<8, false, 3 // 253: value of following 3 case 0xfd: - num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 - n = 4 - return + return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 // 254: value of following 8 case 0xfe: - num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | - uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | - uint64(b[7])<<48 | uint64(b[8])<<54 - n = 9 - return + return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | + uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | + uint64(b[7])<<48 | uint64(b[8])<<54, + false, 9 } // 0-250: value of first byte - num = uint64(b[0]) - n = 1 - return + return uint64(b[0]), false, 1 } +// encodes a uint64 value and appends it to the given bytes slice func appendLengthEncodedInteger(b []byte, n uint64) []byte { switch { case n <= 250: