Skip to content

Commit 0c65ab7

Browse files
committed
remove errBadConnNoWrite and markBadConn
1 parent af8d793 commit 0c65ab7

File tree

5 files changed

+46
-72
lines changed

5 files changed

+46
-72
lines changed

connection.go

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -99,23 +99,12 @@ func (mc *mysqlConn) handleParams() (err error) {
9999
return
100100
}
101101

102-
func (mc *mysqlConn) markBadConn(err error) error {
103-
if mc == nil {
104-
return err
105-
}
106-
if err != errBadConnNoWrite {
107-
return err
108-
}
109-
return driver.ErrBadConn
110-
}
111-
112102
func (mc *mysqlConn) Begin() (driver.Tx, error) {
113103
return mc.begin(false)
114104
}
115105

116106
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
117107
if mc.closed.Load() {
118-
mc.log(ErrInvalidConn)
119108
return nil, driver.ErrBadConn
120109
}
121110
var q string
@@ -128,7 +117,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
128117
if err == nil {
129118
return &mysqlTx{mc}, err
130119
}
131-
return nil, mc.markBadConn(err)
120+
return nil, err
132121
}
133122

134123
func (mc *mysqlConn) Close() (err error) {
@@ -177,15 +166,12 @@ func (mc *mysqlConn) error() error {
177166

178167
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
179168
if mc.closed.Load() {
180-
mc.log(ErrInvalidConn)
181169
return nil, driver.ErrBadConn
182170
}
183171
// Send command
184172
err := mc.writeCommandPacketStr(comStmtPrepare, query)
185173
if err != nil {
186-
// STMT_PREPARE is safe to retry. So we can return ErrBadConn here.
187-
mc.log(err)
188-
return nil, driver.ErrBadConn
174+
return nil, err
189175
}
190176

191177
stmt := &mysqlStmt{
@@ -218,8 +204,8 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
218204
buf, err := mc.buf.takeCompleteBuffer()
219205
if err != nil {
220206
// can not take the buffer. Something must be wrong with the connection
221-
mc.log(err)
222-
return "", ErrInvalidConn
207+
mc.cleanup()
208+
return "", err
223209
}
224210
buf = buf[:0]
225211
argPos := 0
@@ -310,7 +296,6 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
310296

311297
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
312298
if mc.closed.Load() {
313-
mc.log(ErrInvalidConn)
314299
return nil, driver.ErrBadConn
315300
}
316301
if len(args) != 0 {
@@ -330,15 +315,15 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
330315
copied := mc.result
331316
return &copied, err
332317
}
333-
return nil, mc.markBadConn(err)
318+
return nil, err
334319
}
335320

336321
// Internal function to execute commands
337322
func (mc *mysqlConn) exec(query string) error {
338323
handleOk := mc.clearResult()
339324
// Send command
340325
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
341-
return mc.markBadConn(err)
326+
return err
342327
}
343328

344329
// Read Result
@@ -370,7 +355,6 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
370355
handleOk := mc.clearResult()
371356

372357
if mc.closed.Load() {
373-
mc.log(ErrInvalidConn)
374358
return nil, driver.ErrBadConn
375359
}
376360
if len(args) != 0 {
@@ -410,7 +394,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
410394
return rows, err
411395
}
412396
}
413-
return nil, mc.markBadConn(err)
397+
return nil, err
414398
}
415399

416400
// Gets the value of the given MySQL System Variable
@@ -465,7 +449,6 @@ func (mc *mysqlConn) finish() {
465449
// Ping implements driver.Pinger interface
466450
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
467451
if mc.closed.Load() {
468-
mc.log(ErrInvalidConn)
469452
return driver.ErrBadConn
470453
}
471454

@@ -476,7 +459,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
476459

477460
handleOk := mc.clearResult()
478461
if err = mc.writeCommandPacket(comPing); err != nil {
479-
return mc.markBadConn(err)
462+
return err
480463
}
481464

482465
return handleOk.readResultOK()
@@ -682,8 +665,12 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
682665
return nil
683666
}
684667

668+
var _ driver.SessionResetter = &mysqlConn{}
669+
685670
// IsValid implements driver.Validator interface
686671
// (From Go 1.15)
687672
func (mc *mysqlConn) IsValid() bool {
688673
return !mc.closed.Load()
689674
}
675+
676+
var _ driver.Validator = &mysqlConn{}

connection_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,13 @@ func TestPingMarkBadConnection(t *testing.T) {
163163
netConn: nc,
164164
buf: newBuffer(nc),
165165
maxAllowedPacket: defaultMaxAllowedPacket,
166+
closech: make(chan struct{}),
166167
}
167168

168169
err := mc.Ping(context.Background())
169170

170-
if err != driver.ErrBadConn {
171-
t.Errorf("expected driver.ErrBadConn, got %#v", err)
171+
if !errors.Is(err, nc.err) {
172+
t.Errorf("expected %v, got %#v", nc.err, err)
172173
}
173174
}
174175

@@ -184,8 +185,8 @@ func TestPingErrInvalidConn(t *testing.T) {
184185

185186
err := mc.Ping(context.Background())
186187

187-
if err != ErrInvalidConn {
188-
t.Errorf("expected ErrInvalidConn, got %#v", err)
188+
if !errors.Is(err, nc.err) {
189+
t.Errorf("expected %v, got %#v", nc.err, err)
189190
}
190191
}
191192

errors.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,6 @@ var (
2929
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
3030
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`")
3131
ErrBusyBuffer = errors.New("busy buffer")
32-
33-
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
34-
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
35-
// to trigger a resend.
36-
// See https://github.com/go-sql-driver/mysql/pull/302
37-
errBadConnNoWrite = errors.New("bad connection")
3832
)
3933

4034
var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))

packets.go

Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -122,32 +122,25 @@ func (mc *mysqlConn) writePacket(data []byte) error {
122122
}
123123

124124
n, err := mc.netConn.Write(data[:4+size])
125-
if err == nil && n == 4+size {
126-
mc.sequence++
127-
if size != maxPacketSize {
128-
return nil
129-
}
130-
pktLen -= size
131-
data = data[size:]
132-
continue
133-
}
134-
135-
// Handle error
136-
if err == nil { // n != len(data)
125+
if err != nil {
137126
mc.cleanup()
138-
mc.log(ErrMalformPkt)
139-
} else {
140127
if cerr := mc.canceled.Value(); cerr != nil {
141128
return cerr
142129
}
143-
if n == 0 && pktLen == len(data)-4 {
144-
// only for the first loop iteration when nothing was written yet
145-
return errBadConnNoWrite
146-
}
130+
return err
131+
}
132+
if n != size+4 {
147133
mc.cleanup()
148-
mc.log(err)
134+
return io.ErrShortWrite
149135
}
150-
return ErrInvalidConn
136+
137+
mc.sequence++
138+
if size != maxPacketSize {
139+
return nil
140+
}
141+
pktLen -= size
142+
data = data[size:]
143+
continue
151144
}
152145
}
153146

@@ -303,8 +296,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
303296
data, err := mc.buf.takeBuffer(pktLen + 4)
304297
if err != nil {
305298
// cannot take the buffer. Something must be wrong with the connection
306-
mc.log(err)
307-
return errBadConnNoWrite
299+
mc.cleanup() // Avoid repeated "busy buffer" errors.
300+
return err
308301
}
309302

310303
// ClientFlags [32 bit]
@@ -392,8 +385,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
392385
data, err := mc.buf.takeSmallBuffer(pktLen)
393386
if err != nil {
394387
// cannot take the buffer. Something must be wrong with the connection
395-
mc.log(err)
396-
return errBadConnNoWrite
388+
mc.cleanup()
389+
return err
397390
}
398391

399392
// Add the auth data [EOF]
@@ -412,8 +405,8 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
412405
data, err := mc.buf.takeSmallBuffer(4 + 1)
413406
if err != nil {
414407
// cannot take the buffer. Something must be wrong with the connection
415-
mc.log(err)
416-
return errBadConnNoWrite
408+
mc.cleanup()
409+
return err
417410
}
418411

419412
// Add command byte
@@ -431,8 +424,8 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
431424
data, err := mc.buf.takeBuffer(pktLen + 4)
432425
if err != nil {
433426
// cannot take the buffer. Something must be wrong with the connection
434-
mc.log(err)
435-
return errBadConnNoWrite
427+
mc.cleanup()
428+
return err
436429
}
437430

438431
// Add command byte
@@ -452,8 +445,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
452445
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
453446
if err != nil {
454447
// cannot take the buffer. Something must be wrong with the connection
455-
mc.log(err)
456-
return errBadConnNoWrite
448+
mc.cleanup()
449+
return err
457450
}
458451

459452
// Add command byte
@@ -994,8 +987,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
994987
}
995988
if err != nil {
996989
// cannot take the buffer. Something must be wrong with the connection
997-
mc.log(err)
998-
return errBadConnNoWrite
990+
mc.cleanup()
991+
return err
999992
}
1000993

1001994
// command [1 byte]
@@ -1193,8 +1186,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
11931186
if valuesCap != cap(paramValues) {
11941187
data = append(data[:pos], paramValues...)
11951188
if err = mc.buf.store(data); err != nil {
1196-
mc.log(err)
1197-
return errBadConnNoWrite
1189+
mc.cleanup()
1190+
return err
11981191
}
11991192
}
12001193

statement.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
5757
// Send command
5858
err := stmt.writeExecutePacket(args)
5959
if err != nil {
60-
return nil, stmt.mc.markBadConn(err)
60+
return nil, err
6161
}
6262

6363
mc := stmt.mc
@@ -95,13 +95,12 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
9595

9696
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
9797
if stmt.mc.closed.Load() {
98-
stmt.mc.log(ErrInvalidConn)
9998
return nil, driver.ErrBadConn
10099
}
101100
// Send command
102101
err := stmt.writeExecutePacket(args)
103102
if err != nil {
104-
return nil, stmt.mc.markBadConn(err)
103+
return nil, err
105104
}
106105

107106
mc := stmt.mc

0 commit comments

Comments
 (0)