Skip to content

Commit 77f6792

Browse files
author
Brigitte Lamarche
committed
packets: implemented compression protocol CR changes
1 parent e6c682c commit 77f6792

File tree

4 files changed

+72
-26
lines changed

4 files changed

+72
-26
lines changed

benchmark_go18_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) {
4242
}
4343

4444
func BenchmarkQueryContext(b *testing.B) {
45-
db := initDB(b,
45+
db := initDB(b, false,
4646
"DROP TABLE IF EXISTS foo",
4747
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
4848
`INSERT INTO foo VALUES (1, "one")`,
@@ -78,7 +78,7 @@ func benchmarkExecContext(b *testing.B, db *sql.DB, p int) {
7878
}
7979

8080
func BenchmarkExecContext(b *testing.B) {
81-
db := initDB(b,
81+
db := initDB(b, false,
8282
"DROP TABLE IF EXISTS foo",
8383
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
8484
`INSERT INTO foo VALUES (1, "one")`,

benchmark_test.go

+16-3
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,13 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt {
4343
return stmt
4444
}
4545

46-
func initDB(b *testing.B, queries ...string) *sql.DB {
46+
func initDB(b *testing.B, useCompression bool, queries ...string) *sql.DB {
4747
tb := (*TB)(b)
48-
db := tb.checkDB(sql.Open("mysql", dsn))
48+
comprStr := ""
49+
if useCompression {
50+
comprStr = "&compress=1"
51+
}
52+
db := tb.checkDB(sql.Open("mysql", dsn+comprStr))
4953
for _, query := range queries {
5054
if _, err := db.Exec(query); err != nil {
5155
if w, ok := err.(MySQLWarnings); ok {
@@ -61,10 +65,19 @@ func initDB(b *testing.B, queries ...string) *sql.DB {
6165
const concurrencyLevel = 10
6266

6367
func BenchmarkQuery(b *testing.B) {
68+
benchmarkQueryHelper(b, false)
69+
}
70+
71+
func BenchmarkQueryCompression(b *testing.B) {
72+
benchmarkQueryHelper(b, true)
73+
}
74+
75+
func benchmarkQueryHelper(b *testing.B, compr bool) {
76+
6477
tb := (*TB)(b)
6578
b.StopTimer()
6679
b.ReportAllocs()
67-
db := initDB(b,
80+
db := initDB(b, compr,
6881
"DROP TABLE IF EXISTS foo",
6982
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
7083
`INSERT INTO foo VALUES (1, "one")`,

compress.go

+51-21
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ type compressedReader struct {
1818
buf packetReader
1919
bytesBuf []byte
2020
mc *mysqlConn
21+
br *bytes.Reader
22+
zr io.ReadCloser
2123
}
2224

2325
type compressedWriter struct {
@@ -48,12 +50,8 @@ func (cr *compressedReader) readNext(need int) ([]byte, error) {
4850
}
4951
}
5052

51-
data := make([]byte, need)
52-
53-
copy(data, cr.bytesBuf[:len(data)])
54-
55-
cr.bytesBuf = cr.bytesBuf[len(data):]
56-
53+
data := cr.bytesBuf[:need]
54+
cr.bytesBuf = cr.bytesBuf[need:]
5755
return data, nil
5856
}
5957

@@ -88,27 +86,43 @@ func (cr *compressedReader) uncompressPacket() error {
8886
}
8987

9088
// write comprData to a bytes.buffer, then read it using zlib into data
91-
var b bytes.Buffer
92-
b.Write(comprData)
93-
r, err := zlib.NewReader(&b)
89+
if cr.br == nil {
90+
cr.br = bytes.NewReader(comprData)
91+
} else {
92+
cr.br.Reset(comprData)
93+
}
94+
95+
resetter, ok := cr.zr.(zlib.Resetter)
9496

95-
if r != nil {
96-
defer r.Close()
97+
if ok {
98+
err := resetter.Reset(cr.br, []byte{})
99+
if err != nil {
100+
return err
101+
}
102+
} else {
103+
cr.zr, err = zlib.NewReader(cr.br)
104+
if err != nil {
105+
return err
106+
}
97107
}
98108

99-
if err != nil {
100-
return err
109+
defer cr.zr.Close()
110+
111+
//use existing capacity in bytesBuf if possible
112+
offset := len(cr.bytesBuf)
113+
if cap(cr.bytesBuf)-offset < uncompressedLength {
114+
old := cr.bytesBuf
115+
cr.bytesBuf = make([]byte, offset, offset+uncompressedLength)
116+
copy(cr.bytesBuf, old)
101117
}
102118

103-
data := make([]byte, uncompressedLength)
119+
data := cr.bytesBuf[offset : offset+uncompressedLength]
120+
104121
lenRead := 0
105122

106123
// http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate
107124
for lenRead < uncompressedLength {
108-
109-
tmp := data[lenRead:]
110-
111-
n, err := r.Read(tmp)
125+
n, err := cr.zr.Read(data[lenRead:])
112126
lenRead += n
113127

114128
if err == io.EOF {
@@ -152,7 +166,15 @@ func (cw *compressedWriter) Write(data []byte) (int, error) {
152166
return 0, err
153167
}
154168

155-
err = cw.writeComprPacketToNetwork(b.Bytes(), lenSmall)
169+
// if compression expands the payload, do not compress
170+
useData := b.Bytes()
171+
172+
if len(useData) > len(dataSmall) {
173+
useData = dataSmall
174+
lenSmall = 0
175+
}
176+
177+
err = cw.writeComprPacketToNetwork(useData, lenSmall)
156178
if err != nil {
157179
return 0, err
158180
}
@@ -163,7 +185,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) {
163185

164186
lenSmall := len(data)
165187

166-
// do not compress if packet is too small
188+
// do not attempt compression if packet is too small
167189
if lenSmall < minCompressLength {
168190
err := cw.writeComprPacketToNetwork(data, 0)
169191
if err != nil {
@@ -183,7 +205,15 @@ func (cw *compressedWriter) Write(data []byte) (int, error) {
183205
return 0, err
184206
}
185207

186-
err = cw.writeComprPacketToNetwork(b.Bytes(), lenSmall)
208+
// if compression expands the payload, do not compress
209+
useData := b.Bytes()
210+
211+
if len(useData) > len(data) {
212+
useData = data
213+
lenSmall = 0
214+
}
215+
216+
err = cw.writeComprPacketToNetwork(useData, lenSmall)
187217

188218
if err != nil {
189219
return 0, err

packets.go

+3
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
881881
}
882882

883883
stmt.mc.sequence = 0
884+
stmt.mc.compressionSequence = 0
884885
// Add command byte [1 byte]
885886
data[4] = comStmtSendLongData
886887

@@ -906,6 +907,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
906907

907908
// Reset Packet Sequence
908909
stmt.mc.sequence = 0
910+
stmt.mc.compressionSequence = 0
909911
return nil
910912
}
911913

@@ -925,6 +927,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
925927

926928
// Reset packet-sequence
927929
mc.sequence = 0
930+
mc.compressionSequence = 0
928931

929932
var data []byte
930933

0 commit comments

Comments
 (0)