Skip to content

Commit 4e25c3e

Browse files
committed
database/sql: make RawBytes safely usable with contexts
sql.RawBytes was added the very first Go release, Go 1. Its docs say: > RawBytes is a byte slice that holds a reference to memory owned by > the database itself. After a Scan into a RawBytes, the slice is only > valid until the next call to Next, Scan, or Close. That "only valid until the next call" bit was true at the time, until contexts were added to database/sql in Go 1.8. In the past ~dozen releases it's been unsafe to use QueryContext with a context that might become Done to get an *sql.Rows that's scanning into a RawBytes. The Scan can succeed, but then while the caller's reading the memory, a database/sql-managed goroutine can see the context becoming done and call Close on the database/sql/driver and make the caller's view of the RawBytes memory no longer valid, introducing races, crashes, or database corruption. See golang#60304 and golang#53970 for details. This change does the minimal surgery on database/sql to make it safe again: Rows.Scan was already acquiring a mutex to check whether the rows had been closed, so this change make Rows.Scan notice whether *RawBytes was used and, if so, doesn't release the mutex on exit before returning. That mean it's still locked while the user code operates on the RawBytes memory and the concurrent context-watching goroutine to close the database still runs, but if it fires, it then gets blocked on the mutex until the next call to a Rows method (Next, NextResultSet, Err, Close). Updates golang#60304 Updates golang#53970 (earlier one I'd missed) Change-Id: Ie41c0c6f32c24887b2f53ec3686c2aab73a1bfff Reviewed-on: https://go-review.googlesource.com/c/go/+/497675 TryBot-Result: Gopher Robot <[email protected]> Reviewed-by: Ian Lance Taylor <[email protected]> Run-TryBot: Brad Fitzpatrick <[email protected]> Auto-Submit: Ian Lance Taylor <[email protected]> Reviewed-by: Russ Cox <[email protected]>
1 parent 30a936a commit 4e25c3e

File tree

3 files changed

+141
-2
lines changed

3 files changed

+141
-2
lines changed

src/database/sql/fakedb_test.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"strconv"
1616
"strings"
1717
"sync"
18+
"sync/atomic"
1819
"testing"
1920
"time"
2021
)
@@ -90,6 +91,8 @@ func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
9091
type fakeDB struct {
9192
name string
9293

94+
useRawBytes atomic.Bool
95+
9396
mu sync.Mutex
9497
tables map[string]*table
9598
badConn bool
@@ -697,6 +700,8 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
697700
switch cmd {
698701
case "WIPE":
699702
// Nothing
703+
case "USE_RAWBYTES":
704+
c.db.useRawBytes.Store(true)
700705
case "SELECT":
701706
stmt, err = c.prepareSelect(stmt, parts)
702707
case "CREATE":
@@ -800,6 +805,9 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
800805
case "WIPE":
801806
db.wipe()
802807
return driver.ResultNoRows, nil
808+
case "USE_RAWBYTES":
809+
s.c.db.useRawBytes.Store(true)
810+
return driver.ResultNoRows, nil
803811
case "CREATE":
804812
if err := db.createTable(s.table, s.colName, s.colType); err != nil {
805813
return nil, err
@@ -929,6 +937,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
929937
txStatus = "transaction"
930938
}
931939
cursor := &rowsCursor{
940+
db: s.c.db,
932941
parentMem: s.c,
933942
posRow: -1,
934943
rows: [][]*row{
@@ -1025,6 +1034,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
10251034
}
10261035

10271036
cursor := &rowsCursor{
1037+
db: s.c.db,
10281038
parentMem: s.c,
10291039
posRow: -1,
10301040
rows: setMRows,
@@ -1067,6 +1077,7 @@ func (tx *fakeTx) Rollback() error {
10671077
}
10681078

10691079
type rowsCursor struct {
1080+
db *fakeDB
10701081
parentMem memToucher
10711082
cols [][]string
10721083
colType [][]string
@@ -1141,7 +1152,7 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
11411152
// messing up conversions or doing them differently.
11421153
dest[i] = v
11431154

1144-
if bs, ok := v.([]byte); ok {
1155+
if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() {
11451156
if rc.bytesClone == nil {
11461157
rc.bytesClone = make(map[*byte][]byte)
11471158
}

src/database/sql/sql.go

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2893,6 +2893,8 @@ type Rows struct {
28932893
cancel func() // called when Rows is closed, may be nil.
28942894
closeStmt *driverStmt // if non-nil, statement to Close on close
28952895

2896+
contextDone atomic.Pointer[error] // error that awaitDone saw; set before close attempt
2897+
28962898
// closemu prevents Rows from closing while there
28972899
// is an active streaming result. It is held for read during non-close operations
28982900
// and exclusively during close.
@@ -2905,6 +2907,15 @@ type Rows struct {
29052907
// lastcols is only used in Scan, Next, and NextResultSet which are expected
29062908
// not to be called concurrently.
29072909
lastcols []driver.Value
2910+
2911+
// closemuScanHold is whether the previous call to Scan kept closemu RLock'ed
2912+
// without unlocking it. It does that when the user passes a *RawBytes scan
2913+
// target. In that case, we need to prevent awaitDone from closing the Rows
2914+
// while the user's still using the memory. See go.dev/issue/60304.
2915+
//
2916+
// It is only used by Scan, Next, and NextResultSet which are expected
2917+
// not to be called concurrently.
2918+
closemuScanHold bool
29082919
}
29092920

29102921
// lasterrOrErrLocked returns either lasterr or the provided err.
@@ -2942,7 +2953,11 @@ func (rs *Rows) awaitDone(ctx, txctx context.Context) {
29422953
}
29432954
select {
29442955
case <-ctx.Done():
2956+
err := ctx.Err()
2957+
rs.contextDone.Store(&err)
29452958
case <-txctxDone:
2959+
err := txctx.Err()
2960+
rs.contextDone.Store(&err)
29462961
}
29472962
rs.close(ctx.Err())
29482963
}
@@ -2954,6 +2969,15 @@ func (rs *Rows) awaitDone(ctx, txctx context.Context) {
29542969
//
29552970
// Every call to Scan, even the first one, must be preceded by a call to Next.
29562971
func (rs *Rows) Next() bool {
2972+
// If the user's calling Next, they're done with their previous row's Scan
2973+
// results (any RawBytes memory), so we can release the read lock that would
2974+
// be preventing awaitDone from calling close.
2975+
rs.closemuRUnlockIfHeldByScan()
2976+
2977+
if rs.contextDone.Load() != nil {
2978+
return false
2979+
}
2980+
29572981
var doClose, ok bool
29582982
withLock(rs.closemu.RLocker(), func() {
29592983
doClose, ok = rs.nextLocked()
@@ -3008,6 +3032,11 @@ func (rs *Rows) nextLocked() (doClose, ok bool) {
30083032
// scanning. If there are further result sets they may not have rows in the result
30093033
// set.
30103034
func (rs *Rows) NextResultSet() bool {
3035+
// If the user's calling NextResultSet, they're done with their previous
3036+
// row's Scan results (any RawBytes memory), so we can release the read lock
3037+
// that would be preventing awaitDone from calling close.
3038+
rs.closemuRUnlockIfHeldByScan()
3039+
30113040
var doClose bool
30123041
defer func() {
30133042
if doClose {
@@ -3044,6 +3073,10 @@ func (rs *Rows) NextResultSet() bool {
30443073
// Err returns the error, if any, that was encountered during iteration.
30453074
// Err may be called after an explicit or implicit Close.
30463075
func (rs *Rows) Err() error {
3076+
if errp := rs.contextDone.Load(); errp != nil {
3077+
return *errp
3078+
}
3079+
30473080
rs.closemu.RLock()
30483081
defer rs.closemu.RUnlock()
30493082
return rs.lasterrOrErrLocked(nil)
@@ -3237,6 +3270,11 @@ func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
32373270
// If any of the first arguments implementing Scanner returns an error,
32383271
// that error will be wrapped in the returned error.
32393272
func (rs *Rows) Scan(dest ...any) error {
3273+
if rs.closemuScanHold {
3274+
// This should only be possible if the user calls Scan twice in a row
3275+
// without calling Next.
3276+
return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
3277+
}
32403278
rs.closemu.RLock()
32413279

32423280
if rs.lasterr != nil && rs.lasterr != io.EOF {
@@ -3248,23 +3286,50 @@ func (rs *Rows) Scan(dest ...any) error {
32483286
rs.closemu.RUnlock()
32493287
return err
32503288
}
3251-
rs.closemu.RUnlock()
3289+
3290+
if scanArgsContainRawBytes(dest) {
3291+
rs.closemuScanHold = true
3292+
} else {
3293+
rs.closemu.RUnlock()
3294+
}
32523295

32533296
if rs.lastcols == nil {
3297+
rs.closemuRUnlockIfHeldByScan()
32543298
return errors.New("sql: Scan called without calling Next")
32553299
}
32563300
if len(dest) != len(rs.lastcols) {
3301+
rs.closemuRUnlockIfHeldByScan()
32573302
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
32583303
}
3304+
32593305
for i, sv := range rs.lastcols {
32603306
err := convertAssignRows(dest[i], sv, rs)
32613307
if err != nil {
3308+
rs.closemuRUnlockIfHeldByScan()
32623309
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
32633310
}
32643311
}
32653312
return nil
32663313
}
32673314

3315+
// closemuRUnlockIfHeldByScan releases any closemu.RLock held open by a previous
3316+
// call to Scan with *RawBytes.
3317+
func (rs *Rows) closemuRUnlockIfHeldByScan() {
3318+
if rs.closemuScanHold {
3319+
rs.closemuScanHold = false
3320+
rs.closemu.RUnlock()
3321+
}
3322+
}
3323+
3324+
func scanArgsContainRawBytes(args []any) bool {
3325+
for _, a := range args {
3326+
if _, ok := a.(*RawBytes); ok {
3327+
return true
3328+
}
3329+
}
3330+
return false
3331+
}
3332+
32683333
// rowsCloseHook returns a function so tests may install the
32693334
// hook through a test only mutex.
32703335
var rowsCloseHook = func() func(*Rows, *error) { return nil }
@@ -3274,6 +3339,11 @@ var rowsCloseHook = func() func(*Rows, *error) { return nil }
32743339
// the Rows are closed automatically and it will suffice to check the
32753340
// result of Err. Close is idempotent and does not affect the result of Err.
32763341
func (rs *Rows) Close() error {
3342+
// If the user's calling Close, they're done with their previous row's Scan
3343+
// results (any RawBytes memory), so we can release the read lock that would
3344+
// be preventing awaitDone from calling the unexported close before we do so.
3345+
rs.closemuRUnlockIfHeldByScan()
3346+
32773347
return rs.close(nil)
32783348
}
32793349

src/database/sql/sql_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4385,6 +4385,64 @@ func TestRowsScanProperlyWrapsErrors(t *testing.T) {
43854385
}
43864386
}
43874387

4388+
// From go.dev/issue/60304
4389+
func TestContextCancelDuringRawBytesScan(t *testing.T) {
4390+
db := newTestDB(t, "people")
4391+
defer closeDB(t, db)
4392+
4393+
if _, err := db.Exec("USE_RAWBYTES"); err != nil {
4394+
t.Fatal(err)
4395+
}
4396+
4397+
ctx, cancel := context.WithCancel(context.Background())
4398+
defer cancel()
4399+
4400+
r, err := db.QueryContext(ctx, "SELECT|people|name|")
4401+
if err != nil {
4402+
t.Fatal(err)
4403+
}
4404+
numRows := 0
4405+
var sink byte
4406+
for r.Next() {
4407+
numRows++
4408+
var s RawBytes
4409+
err = r.Scan(&s)
4410+
if !r.closemuScanHold {
4411+
t.Errorf("expected closemu to be held")
4412+
}
4413+
if err != nil {
4414+
t.Fatal(err)
4415+
}
4416+
t.Logf("read %q", s)
4417+
if numRows == 2 {
4418+
cancel() // invalidate the context, which used to call close asynchronously
4419+
}
4420+
for _, b := range s { // some operation reading from the raw memory
4421+
sink += b
4422+
}
4423+
}
4424+
if r.closemuScanHold {
4425+
t.Errorf("closemu held; should not be")
4426+
}
4427+
4428+
// There are 3 rows. We canceled after reading 2 so we expect either
4429+
// 2 or 3 depending on how the awaitDone goroutine schedules.
4430+
switch numRows {
4431+
case 0, 1:
4432+
t.Errorf("got %d rows; want 2+", numRows)
4433+
case 2:
4434+
if err := r.Err(); err != context.Canceled {
4435+
t.Errorf("unexpected error: %v (%T)", err, err)
4436+
}
4437+
default:
4438+
// Made it to the end. This is rare, but fine. Permit it.
4439+
}
4440+
4441+
if err := r.Close(); err != nil {
4442+
t.Fatal(err)
4443+
}
4444+
}
4445+
43884446
// badConn implements a bad driver.Conn, for TestBadDriver.
43894447
// The Exec method panics.
43904448
type badConn struct{}

0 commit comments

Comments
 (0)