Skip to content

ColumnType interfaces #667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Oct 17, 2017
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4f5c0b7
rows: implement driver.RowsColumnTypeScanType
julienschmidt May 5, 2017
0950d1b
rows: implement driver.RowsColumnTypeNullable
julienschmidt May 5, 2017
2f97a23
rows: move fields related code to fields.go
julienschmidt May 6, 2017
1b786bd
fields: use NullTime for nullable datetime fields
julienschmidt May 6, 2017
571f082
fields: make fieldType its own type
julienschmidt May 6, 2017
b6124b5
rows: implement driver.RowsColumnTypeDatabaseTypeName
julienschmidt May 6, 2017
3ed8bb2
fields: fix copyright year
julienschmidt May 6, 2017
1820148
rows: compile time interface implementation checks
julienschmidt May 12, 2017
0570286
rows: move tests to versioned driver test files
julienschmidt May 12, 2017
3240650
rows: cache parseTime in resultSet instead of mysqlConn
julienschmidt Sep 29, 2017
163ddcd
fields: fix string and time types
julienschmidt Sep 29, 2017
91e72b0
rows: implement ColumnTypeLength
julienschmidt Oct 4, 2017
6a18c41
rows: implement ColumnTypePrecisionScale
julienschmidt Oct 4, 2017
0a5e4cb
rows: fix ColumnTypeNullable
julienschmidt Oct 4, 2017
2042d73
rows: ColumnTypes tests part1
julienschmidt Oct 4, 2017
5dc4b61
rows: use keyed composite literals in ColumnTypes tests
julienschmidt Oct 4, 2017
bb35faa
rows: ColumnTypes tests part2
julienschmidt Oct 4, 2017
b1a9d25
rows: always use NullTime as ScanType for datetime
julienschmidt Oct 4, 2017
d03077c
rows: avoid errors through rounding of time values
julienschmidt Oct 4, 2017
65f1dfb
rows: remove parseTime cache
julienschmidt Oct 5, 2017
4023d9a
fields: remove unused scanTypes
julienschmidt Oct 5, 2017
e8324ff
rows: fix ColumnTypePrecisionScale implementation
julienschmidt Oct 6, 2017
4d657f6
fields: sort types alphabetical
julienschmidt Oct 6, 2017
c60820c
rows: remove ColumnTypeLength implementation for now
julienschmidt Oct 7, 2017
6416689
README: document ColumnType Support
julienschmidt Oct 7, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ type mysqlConn struct {
flags clientFlag
status statusFlag
sequence uint8
parseTime bool

// for context support (Go 1.8+)
watching bool
Expand Down Expand Up @@ -403,6 +402,9 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
return nil, err
}
}

rows.rs.parseTime = mc.cfg.ParseTime

// Columns
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
Expand Down
6 changes: 4 additions & 2 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ const (
)

// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
type fieldType byte
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not convinced this helps taking all the casting happening in packets.go ...


const (
fieldTypeDecimal byte = iota
fieldTypeDecimal fieldType = iota
fieldTypeTiny
fieldTypeShort
fieldTypeLong
Expand All @@ -107,7 +109,7 @@ const (
fieldTypeBit
)
const (
fieldTypeJSON byte = iota + 0xf5
fieldTypeJSON fieldType = iota + 0xf5
fieldTypeNewDecimal
fieldTypeEnum
fieldTypeSet
Expand Down
1 change: 0 additions & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
if err != nil {
return nil, err
}
mc.parseTime = mc.cfg.ParseTime

// Connect to Server
if dial, ok := dials[mc.cfg.Net]; ok {
Expand Down
12 changes: 12 additions & 0 deletions driver_go18_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ var (
_ driver.StmtQueryContext = &mysqlStmt{}
)

// Ensure that all the driver interfaces are implemented
var (
_ driver.RowsColumnTypeDatabaseTypeName = &binaryRows{}
_ driver.RowsColumnTypeDatabaseTypeName = &textRows{}
_ driver.RowsColumnTypeNullable = &binaryRows{}
_ driver.RowsColumnTypeNullable = &textRows{}
_ driver.RowsColumnTypeScanType = &binaryRows{}
_ driver.RowsColumnTypeScanType = &textRows{}
_ driver.RowsNextResultSet = &binaryRows{}
_ driver.RowsNextResultSet = &textRows{}
)

func TestMultiResultSet(t *testing.T) {
type result struct {
values [][]int
Expand Down
6 changes: 6 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ import (
"time"
)

// Ensure that all the driver interfaces are implemented
var (
_ driver.Rows = &binaryRows{}
_ driver.Rows = &textRows{}
)

var (
user string
pass string
Expand Down
152 changes: 152 additions & 0 deletions fields.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
"database/sql"
"reflect"
"time"
)

var typeDatabaseName = map[fieldType]string{
fieldTypeDecimal: "DECIMAL",
fieldTypeTiny: "TINYINT",
fieldTypeShort: "SMALLINT",
fieldTypeLong: "INT",
fieldTypeFloat: "FLOAT",
fieldTypeDouble: "DOUBLE",
fieldTypeNULL: "NULL",
fieldTypeTimestamp: "TIMESTAMP",
fieldTypeLongLong: "BIGINT",
fieldTypeInt24: "MEDIUMINT",
fieldTypeDate: "DATE",
fieldTypeTime: "TIME",
fieldTypeDateTime: "DATETIME",
fieldTypeYear: "YEAR",
fieldTypeNewDate: "DATE",
fieldTypeVarChar: "VARCHAR",
fieldTypeBit: "BIT",
fieldTypeJSON: "JSON",
fieldTypeNewDecimal: "DECIMAL",
fieldTypeEnum: "ENUM",
fieldTypeSet: "SET",
fieldTypeTinyBLOB: "TINYBLOB",
fieldTypeMediumBLOB: "MEDIUMBLOB",
fieldTypeLongBLOB: "LONGBLOB",
fieldTypeBLOB: "BLOB",
fieldTypeVarString: "VARCHAR",
fieldTypeString: "CHAR",
fieldTypeGeometry: "GEOMETRY",
}

var (
scanTypeNil = reflect.TypeOf(nil)
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
scanTypeUint8 = reflect.TypeOf(uint8(0))
scanTypeInt8 = reflect.TypeOf(int8(0))
scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeInt16 = reflect.TypeOf(int16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0))
scanTypeInt32 = reflect.TypeOf(int32(0))
scanTypeUint64 = reflect.TypeOf(uint64(0))
scanTypeInt64 = reflect.TypeOf(int64(0))
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
scanTypeFloat32 = reflect.TypeOf(float32(0))
scanTypeFloat64 = reflect.TypeOf(float64(0))
scanTypeNullString = reflect.TypeOf(sql.NullString{})
scanTypeString = reflect.TypeOf("")
scanTypeBytes = reflect.TypeOf([]byte{})
scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{})
scanTypeTime = reflect.TypeOf(time.Time{})
scanTypeNullTime = reflect.TypeOf(NullTime{})
scanTypeUnknown = reflect.TypeOf(new(interface{}))
)

type mysqlField struct {
tableName string
name string
flags fieldFlag
fieldType fieldType
decimals byte
}

func (mf *mysqlField) scanType(parseTime bool) reflect.Type {
switch mf.fieldType {
case fieldTypeNULL:
return scanTypeNil

case fieldTypeTiny:
if mf.flags&flagNotNULL != 0 {
if mf.flags&flagUnsigned != 0 {
return scanTypeUint8
}
return scanTypeInt8
}
return scanTypeNullInt

case fieldTypeShort, fieldTypeYear:
if mf.flags&flagNotNULL != 0 {
if mf.flags&flagUnsigned != 0 {
return scanTypeUint16
}
return scanTypeInt16
}
return scanTypeNullInt

case fieldTypeInt24, fieldTypeLong:
if mf.flags&flagNotNULL != 0 {
if mf.flags&flagUnsigned != 0 {
return scanTypeUint32
}
return scanTypeInt32
}
return scanTypeNullInt

case fieldTypeLongLong:
if mf.flags&flagNotNULL != 0 {
if mf.flags&flagUnsigned != 0 {
return scanTypeUint64
}
return scanTypeInt64
}
return scanTypeNullInt

case fieldTypeFloat:
if mf.flags&flagNotNULL != 0 {
return scanTypeFloat32
}
return scanTypeNullFloat

case fieldTypeDouble:
if mf.flags&flagNotNULL != 0 {
return scanTypeFloat64
}
return scanTypeNullFloat

case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
fieldTypeTime:
return scanTypeRawBytes
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I have to check nullable to decide converting to sql.NullString or string myself?

Copy link
Member Author

@julienschmidt julienschmidt Oct 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sql.RawBytes can also be nil, which is distinguishable from an empty string (""). I.e. if rawBytes == nil then it was a NULL value.
If you want to known if the column itself may contain NULL values, then you should check ColumnType.Nullable()


case fieldTypeDate, fieldTypeNewDate,
fieldTypeTimestamp, fieldTypeDateTime:
if parseTime {
if mf.flags&flagNotNULL != 0 {
return scanTypeTime
}
return scanTypeNullTime
}
return scanTypeRawBytes

default:
return scanTypeUnknown
}
}
22 changes: 11 additions & 11 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
pos += n + 1 + 2 + 4

// Field type [uint8]
columns[i].fieldType = data[pos]
columns[i].fieldType = fieldType(data[pos])
pos++

// Flags [uint16]
Expand Down Expand Up @@ -761,7 +761,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
pos += n
if err == nil {
if !isNull {
if !mc.parseTime {
if !rows.rs.parseTime {
continue
} else {
switch rows.rs.columns[i].fieldType {
Expand Down Expand Up @@ -980,15 +980,15 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// build NULL-bitmap
if arg == nil {
nullMask[i/8] |= 1 << (uint(i) & 7)
paramTypes[i+i] = fieldTypeNULL
paramTypes[i+i] = byte(fieldTypeNULL)
paramTypes[i+i+1] = 0x00
continue
}

// cache types and values
switch v := arg.(type) {
case int64:
paramTypes[i+i] = fieldTypeLongLong
paramTypes[i+i] = byte(fieldTypeLongLong)
paramTypes[i+i+1] = 0x00

if cap(paramValues)-len(paramValues)-8 >= 0 {
Expand All @@ -1004,7 +1004,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
}

case float64:
paramTypes[i+i] = fieldTypeDouble
paramTypes[i+i] = byte(fieldTypeDouble)
paramTypes[i+i+1] = 0x00

if cap(paramValues)-len(paramValues)-8 >= 0 {
Expand All @@ -1020,7 +1020,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
}

case bool:
paramTypes[i+i] = fieldTypeTiny
paramTypes[i+i] = byte(fieldTypeTiny)
paramTypes[i+i+1] = 0x00

if v {
Expand All @@ -1032,7 +1032,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
case []byte:
// Common case (non-nil value) first
if v != nil {
paramTypes[i+i] = fieldTypeString
paramTypes[i+i] = byte(fieldTypeString)
paramTypes[i+i+1] = 0x00

if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
Expand All @@ -1050,11 +1050,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {

// Handle []byte(nil) as a NULL value
nullMask[i/8] |= 1 << (uint(i) & 7)
paramTypes[i+i] = fieldTypeNULL
paramTypes[i+i] = byte(fieldTypeNULL)
paramTypes[i+i+1] = 0x00

case string:
paramTypes[i+i] = fieldTypeString
paramTypes[i+i] = byte(fieldTypeString)
paramTypes[i+i+1] = 0x00

if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
Expand All @@ -1069,7 +1069,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
}

case time.Time:
paramTypes[i+i] = fieldTypeString
paramTypes[i+i] = byte(fieldTypeString)
paramTypes[i+i+1] = 0x00

var a [64]byte
Expand Down Expand Up @@ -1265,7 +1265,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
)
}
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
case rows.mc.parseTime:
case rows.rs.parseTime:
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
default:
var dstlen uint8
Expand Down
25 changes: 17 additions & 8 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,14 @@ package mysql
import (
"database/sql/driver"
"io"
"reflect"
)

type mysqlField struct {
tableName string
name string
flags fieldFlag
fieldType byte
decimals byte
}

type resultSet struct {
columns []mysqlField
columnNames []string
done bool
parseTime bool // cached from cfg
}

type mysqlRows struct {
Expand Down Expand Up @@ -65,6 +59,21 @@ func (rows *mysqlRows) Columns() []string {
return columns
}

func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string {
if name, ok := typeDatabaseName[rows.rs.columns[i].fieldType]; ok {
return name
}
return ""
}

func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) {
return rows.rs.columns[i].flags&flagNotNULL != 0, true
}

func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type {
return rows.rs.columns[i].scanType(rows.rs.parseTime)
}

func (rows *mysqlRows) Close() (err error) {
if f := rows.finish; f != nil {
f()
Expand Down