Skip to content

Commit 5d4a831

Browse files
authored
Parse numbers on text protocol too (#1452)
1 parent 564dee9 commit 5d4a831

File tree

2 files changed

+90
-36
lines changed

2 files changed

+90
-36
lines changed

driver_test.go

+59-28
Original file line numberDiff line numberDiff line change
@@ -148,29 +148,18 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
148148
defer db2.Close()
149149
}
150150

151-
dsn3 := dsn + "&multiStatements=true"
152-
var db3 *sql.DB
153-
if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
154-
db3, err = sql.Open("mysql", dsn3)
155-
if err != nil {
156-
t.Fatalf("error connecting: %s", err.Error())
157-
}
158-
defer db3.Close()
159-
}
160-
161-
dbt := &DBTest{t, db}
162-
dbt2 := &DBTest{t, db2}
163-
dbt3 := &DBTest{t, db3}
164151
for _, test := range tests {
165-
test(dbt)
166-
dbt.db.Exec("DROP TABLE IF EXISTS test")
152+
t.Run("default", func(t *testing.T) {
153+
dbt := &DBTest{t, db}
154+
test(dbt)
155+
dbt.db.Exec("DROP TABLE IF EXISTS test")
156+
})
167157
if db2 != nil {
168-
test(dbt2)
169-
dbt2.db.Exec("DROP TABLE IF EXISTS test")
170-
}
171-
if db3 != nil {
172-
test(dbt3)
173-
dbt3.db.Exec("DROP TABLE IF EXISTS test")
158+
t.Run("interpolateParams", func(t *testing.T) {
159+
dbt2 := &DBTest{t, db2}
160+
test(dbt2)
161+
dbt2.db.Exec("DROP TABLE IF EXISTS test")
162+
})
174163
}
175164
}
176165
}
@@ -316,6 +305,48 @@ func TestCRUD(t *testing.T) {
316305
})
317306
}
318307

308+
// TestNumbers test that selecting numeric columns.
309+
// Both of textRows and binaryRows should return same type and value.
310+
func TestNumbersToAny(t *testing.T) {
311+
runTests(t, dsn, func(dbt *DBTest) {
312+
dbt.mustExec("CREATE TABLE `test` (id INT PRIMARY KEY, b BOOL, i8 TINYINT, " +
313+
"i16 SMALLINT, i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE)")
314+
dbt.mustExec("INSERT INTO `test` VALUES (1, true, 127, 32767, 2147483647, 9223372036854775807, 1.25, 2.5)")
315+
316+
// Use binaryRows for intarpolateParams=false and textRows for intarpolateParams=true.
317+
rows := dbt.mustQuery("SELECT b, i8, i16, i32, i64, f32, f64 FROM `test` WHERE id=?", 1)
318+
if !rows.Next() {
319+
dbt.Fatal("no data")
320+
}
321+
var b, i8, i16, i32, i64, f32, f64 any
322+
err := rows.Scan(&b, &i8, &i16, &i32, &i64, &f32, &f64)
323+
if err != nil {
324+
dbt.Fatal(err)
325+
}
326+
if b.(int64) != 1 {
327+
dbt.Errorf("b != 1")
328+
}
329+
if i8.(int64) != 127 {
330+
dbt.Errorf("i8 != 127")
331+
}
332+
if i16.(int64) != 32767 {
333+
dbt.Errorf("i16 != 32767")
334+
}
335+
if i32.(int64) != 2147483647 {
336+
dbt.Errorf("i32 != 2147483647")
337+
}
338+
if i64.(int64) != 9223372036854775807 {
339+
dbt.Errorf("i64 != 9223372036854775807")
340+
}
341+
if f32.(float32) != 1.25 {
342+
dbt.Errorf("f32 != 1.25")
343+
}
344+
if f64.(float64) != 2.5 {
345+
dbt.Errorf("f64 != 2.5")
346+
}
347+
})
348+
}
349+
319350
func TestMultiQuery(t *testing.T) {
320351
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
321352
// Create Table
@@ -1808,13 +1839,13 @@ func TestConcurrent(t *testing.T) {
18081839
}
18091840

18101841
runTests(t, dsn, func(dbt *DBTest) {
1811-
var version string
1812-
if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil {
1813-
dbt.Fatalf("%s", err.Error())
1814-
}
1815-
if strings.Contains(strings.ToLower(version), "mariadb") {
1816-
t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" on MariaDB`)
1817-
}
1842+
// var version string
1843+
// if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil {
1844+
// dbt.Fatal(err)
1845+
// }
1846+
// if strings.Contains(strings.ToLower(version), "mariadb") {
1847+
// t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" on MariaDB`)
1848+
// }
18181849

18191850
var max int
18201851
err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max)

packets.go

+31-8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"fmt"
1818
"io"
1919
"math"
20+
"strconv"
2021
"time"
2122
)
2223

@@ -834,7 +835,8 @@ func (rows *textRows) readRow(dest []driver.Value) error {
834835

835836
for i := range dest {
836837
// Read bytes and convert to string
837-
dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
838+
var buf []byte
839+
buf, isNull, n, err = readLengthEncodedString(data[pos:])
838840
pos += n
839841

840842
if err != nil {
@@ -846,19 +848,40 @@ func (rows *textRows) readRow(dest []driver.Value) error {
846848
continue
847849
}
848850

849-
if !mc.parseTime {
850-
continue
851-
}
852-
853-
// Parse time field
854851
switch rows.rs.columns[i].fieldType {
855852
case fieldTypeTimestamp,
856853
fieldTypeDateTime,
857854
fieldTypeDate,
858855
fieldTypeNewDate:
859-
if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil {
860-
return err
856+
if mc.parseTime {
857+
dest[i], err = parseDateTime(buf, mc.cfg.Loc)
858+
} else {
859+
dest[i] = buf
860+
}
861+
862+
case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong:
863+
dest[i], err = strconv.ParseInt(string(buf), 10, 32)
864+
865+
case fieldTypeLongLong:
866+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
867+
dest[i], err = strconv.ParseUint(string(buf), 10, 64)
868+
} else {
869+
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
861870
}
871+
872+
case fieldTypeFloat:
873+
var d float64
874+
d, err = strconv.ParseFloat(string(buf), 32)
875+
dest[i] = float32(d)
876+
877+
case fieldTypeDouble:
878+
dest[i], err = strconv.ParseFloat(string(buf), 64)
879+
880+
default:
881+
dest[i] = buf
882+
}
883+
if err != nil {
884+
return err
862885
}
863886
}
864887

0 commit comments

Comments
 (0)