Skip to content

Cr4 -- removing buffer from mysqlConn #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func BenchmarkRoundtripBin(b *testing.B) {
length = max
}
test := sample[0:length]
rows := tb.checkRows(stmt.Query(test))
rows := tb.checkRows(stmt.Query(test)) //run benchmark tests to test that bit of code
if !rows.Next() {
rows.Close()
b.Fatalf("crashed")
Expand All @@ -231,9 +231,10 @@ func BenchmarkInterpolation(b *testing.B) {
},
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
buf: newBuffer(nil),
}
mc.reader = &mc.buf

buf := newBuffer(nil)
mc.reader = &buf

args := []driver.Value{
int64(42424242),
Expand Down
16 changes: 5 additions & 11 deletions buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ func (b *buffer) readNext(need int) ([]byte, error) {
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) []byte {
func (b *buffer) reuseBuffer(length int) []byte {
if length == -1 {
return b.takeCompleteBuffer()
}

if b.length > 0 {
return nil
}
Expand All @@ -126,16 +130,6 @@ func (b *buffer) takeBuffer(length int) []byte {
return make([]byte, length)
}

// shortcut which can be used if the requested buffer is guaranteed to be
// smaller than defaultBufSize
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) []byte {
if b.length == 0 {
return b.buf[:length]
}
return nil
}

// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// Only one buffer (total) can be used at a time.
Expand Down
4 changes: 4 additions & 0 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ func (cr *compressedReader) readNext(need int) ([]byte, error) {
return data, nil
}

func (cr *compressedReader) reuseBuffer(length int) []byte {
return cr.buf.reuseBuffer(length)
}

func (cr *compressedReader) uncompressPacket() error {
header, err := cr.buf.readNext(7) // size of compressed header

Expand Down
4 changes: 4 additions & 0 deletions compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ func (mb *mockBuf) readNext(need int) ([]byte, error) {
return data, nil
}

func (mb *mockBuf) reuseBuffer(length int) []byte {
return make([]byte, length) //just give them a new buffer
}

// compressHelper compresses uncompressedPacket and checks state variables
func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte {
// get status variables
Expand Down
5 changes: 3 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ type mysqlContext interface {
}

type mysqlConn struct {
buf buffer
netConn net.Conn
reader packetReader
writer io.Writer
Expand All @@ -55,6 +54,7 @@ type mysqlConn struct {

type packetReader interface {
readNext(need int) ([]byte, error)
reuseBuffer(length int) []byte
}

// Handles parameters set in DSN after the connection is established
Expand Down Expand Up @@ -197,7 +197,8 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
return "", driver.ErrSkip
}

buf := mc.buf.takeCompleteBuffer()
buf := mc.reader.reuseBuffer(-1)

if buf == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand Down
12 changes: 6 additions & 6 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ import (

func TestInterpolateParams(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}
mc.reader = &mc.buf
buf := newBuffer(nil)
mc.reader = &buf

q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
if err != nil {
Expand All @@ -36,13 +36,13 @@ func TestInterpolateParams(t *testing.T) {

func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}
mc.reader = &mc.buf
buf := newBuffer(nil)
mc.reader = &buf

q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)})
if err != driver.ErrSkip {
Expand All @@ -54,14 +54,14 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
// https://github.com/go-sql-driver/mysql/pull/490
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}

mc.reader = &mc.buf
buf := newBuffer(nil)
mc.reader = &buf

q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
Expand Down
14 changes: 7 additions & 7 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,16 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
s.startWatcher()
}

mc.buf = newBuffer(mc.netConn)

// packet reader and writer in handshake are never compressed
mc.reader = &mc.buf
mc.writer = mc.netConn
buf := newBuffer(mc.netConn)

// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout

// packet reader and writer in handshake are never compressed
mc.reader = &buf
mc.writer = mc.netConn

// Reading Handshake Initialization Packet
cipher, err := mc.readInitPacket()
if err != nil {
Expand All @@ -124,7 +124,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
}

if mc.cfg.Compress {
mc.reader = newCompressedReader(&mc.buf, mc)
mc.reader = newCompressedReader(&buf, mc)
mc.writer = newCompressedWriter(mc.writer, mc)
}

Expand Down
36 changes: 25 additions & 11 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
}

// Calculate packet length and get buffer with that size
data := mc.buf.takeSmallBuffer(pktLen + 4)
data := mc.reader.reuseBuffer(pktLen + 4)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand Down Expand Up @@ -326,7 +327,10 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
return err
}
mc.netConn = tlsConn
mc.buf.nc = tlsConn
nc := tlsConn

newBuf := newBuffer(nc)
mc.reader = &newBuf

mc.writer = mc.netConn
}
Expand Down Expand Up @@ -373,7 +377,8 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {

// Calculate the packet length and add a tailing 0
pktLen := len(scrambleBuff) + 1
data := mc.buf.takeSmallBuffer(4 + pktLen)
data := mc.reader.reuseBuffer(4 + pktLen)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand All @@ -392,7 +397,8 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
func (mc *mysqlConn) writeClearAuthPacket() error {
// Calculate the packet length and add a tailing 0
pktLen := len(mc.cfg.Passwd) + 1
data := mc.buf.takeSmallBuffer(4 + pktLen)
data := mc.reader.reuseBuffer(4 + pktLen)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand All @@ -415,7 +421,8 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {

// Calculate the packet length and add a tailing 0
pktLen := len(scrambleBuff)
data := mc.buf.takeSmallBuffer(4 + pktLen)
data := mc.reader.reuseBuffer(4 + pktLen)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand All @@ -437,7 +444,8 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
mc.sequence = 0
mc.compressionSequence = 0

data := mc.buf.takeSmallBuffer(4 + 1)
data := mc.reader.reuseBuffer(4 + 1)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand All @@ -457,7 +465,8 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
mc.compressionSequence = 0

pktLen := 1 + len(arg)
data := mc.buf.takeBuffer(pktLen + 4)
data := mc.reader.reuseBuffer(pktLen + 4)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand All @@ -479,7 +488,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
mc.sequence = 0
mc.compressionSequence = 0

data := mc.buf.takeSmallBuffer(4 + 1 + 4)
data := mc.reader.reuseBuffer(4 + 1 + 4)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand Down Expand Up @@ -946,9 +956,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
var data []byte

if len(args) == 0 {
data = mc.buf.takeBuffer(minPktLen)
data = mc.reader.reuseBuffer(minPktLen)

} else {
data = mc.buf.takeCompleteBuffer()
data = mc.reader.reuseBuffer(-1)
}
if data == nil {
// can not take the buffer. Something must be wrong with the connection
Expand Down Expand Up @@ -1127,7 +1138,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
mc.buf.buf = data

bufBuf := mc.reader.reuseBuffer(-1)
bufBuf = data
fmt.Println(bufBuf) //dont know how to make it compile w/o some op here on bufBuf
}

pos += len(paramValues)
Expand Down
27 changes: 14 additions & 13 deletions packets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,11 @@ var _ net.Conn = new(mockConn)

func TestReadPacketSingleByte(t *testing.T) {
conn := new(mockConn)
buf := newBuffer(conn)
mc := &mysqlConn{
buf: newBuffer(conn),
reader: &buf,
}

mc.reader = &mc.buf

conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
conn.maxReads = 1
packet, err := mc.readPacket()
Expand All @@ -111,10 +110,10 @@ func TestReadPacketSingleByte(t *testing.T) {

func TestReadPacketWrongSequenceID(t *testing.T) {
conn := new(mockConn)
buf := newBuffer(conn)
mc := &mysqlConn{
buf: newBuffer(conn),
reader: &buf,
}
mc.reader = &mc.buf

// too low sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
Expand All @@ -128,7 +127,8 @@ func TestReadPacketWrongSequenceID(t *testing.T) {
// reset
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
newBuf := newBuffer(conn)
mc.reader = &newBuf

// too high sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
Expand All @@ -140,12 +140,11 @@ func TestReadPacketWrongSequenceID(t *testing.T) {

func TestReadPacketSplit(t *testing.T) {
conn := new(mockConn)
buf := newBuffer(conn)
mc := &mysqlConn{
buf: newBuffer(conn),
reader: &buf,
}

mc.reader = &mc.buf

data := make([]byte, maxPacketSize*2+4*3)
const pkt2ofs = maxPacketSize + 4
const pkt3ofs = 2 * (maxPacketSize + 4)
Expand Down Expand Up @@ -247,11 +246,11 @@ func TestReadPacketSplit(t *testing.T) {

func TestReadPacketFail(t *testing.T) {
conn := new(mockConn)
buf := newBuffer(conn)
mc := &mysqlConn{
buf: newBuffer(conn),
reader: &buf,
closech: make(chan struct{}),
}
mc.reader = &mc.buf

// illegal empty (stand-alone) packet
conn.data = []byte{0x00, 0x00, 0x00, 0x00}
Expand All @@ -264,7 +263,8 @@ func TestReadPacketFail(t *testing.T) {
// reset
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
newBuf := newBuffer(conn)
mc.reader = &newBuf

// fail to read header
conn.closed = true
Expand All @@ -277,7 +277,8 @@ func TestReadPacketFail(t *testing.T) {
conn.closed = false
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
newBuf = newBuffer(conn)
mc.reader = &newBuf

// fail to read body
conn.maxReads = 1
Expand Down