@@ -442,24 +442,15 @@ func (r *limitedCommandResult) moreResultsNeeded(ctx context.Context) error {
442
442
}
443
443
switch c := cmd .(type ) {
444
444
case sql.DeletePreparedStmt :
445
- // The client wants to close a portal or statement. We
446
- // support the case where it is exactly this
447
- // portal. This is done by closing the portal in
448
- // the same way implicit transactions do, but also
449
- // rewinding the stmtBuf to still point to the portal
450
- // close so that the state machine can do its part of
451
- // the cleanup. We are in effect peeking to see if the
445
+ // The client wants to close a portal or statement. We support the case
446
+ // where it is exactly this portal. We are in effect peeking to see if the
452
447
// next message is a delete portal.
453
448
if c .Type != pgwirebase .PreparePortal || c .Name != r .portalName {
454
449
telemetry .Inc (sqltelemetry .InterleavedPortalRequestCounter )
455
450
return errors .WithDetail (sql .ErrLimitedResultNotSupported ,
456
451
"cannot close a portal while a different one is open" )
457
452
}
458
- r .typ = noCompletionMsg
459
- // Rewind to before the delete so the AdvanceOne in
460
- // connExecutor.execCmd ends up back on it.
461
- r .conn .stmtBuf .Rewind (ctx , prevPos )
462
- return sql .ErrLimitedResultClosed
453
+ return r .rewindAndClosePortal (ctx , prevPos )
463
454
case sql.ExecPortal :
464
455
// The happy case: the client wants more rows from the portal.
465
456
if c .Name != r .portalName {
@@ -483,6 +474,13 @@ func (r *limitedCommandResult) moreResultsNeeded(ctx context.Context) error {
483
474
return err
484
475
}
485
476
default :
477
+ // If the portal is immediately followed by a COMMIT, we can proceed and
478
+ // let the portal be destroyed at the end of the transaction.
479
+ if isCommit , err := r .isCommit (); err != nil {
480
+ return err
481
+ } else if isCommit {
482
+ return r .rewindAndClosePortal (ctx , prevPos )
483
+ }
486
484
// We got some other message, but we only support executing to completion.
487
485
telemetry .Inc (sqltelemetry .InterleavedPortalRequestCounter )
488
486
return errors .WithDetail (sql .ErrLimitedResultNotSupported ,
@@ -491,3 +489,77 @@ func (r *limitedCommandResult) moreResultsNeeded(ctx context.Context) error {
491
489
prevPos = curPos
492
490
}
493
491
}
492
+
493
+ // isCommit checks if the statement buffer has a COMMIT at the current
494
+ // position. It may either be (1) a COMMIT in the simple protocol, or (2) a
495
+ // Parse/Bind/Execute sequence for a COMMIT query.
496
+ func (r * limitedCommandResult ) isCommit () (bool , error ) {
497
+ cmd , _ , err := r .conn .stmtBuf .CurCmd ()
498
+ if err != nil {
499
+ return false , err
500
+ }
501
+ // Case 1: Check if cmd is a simple COMMIT statement.
502
+ if execStmt , ok := cmd .(sql.ExecStmt ); ok {
503
+ if _ , isCommit := execStmt .AST .(* tree.CommitTransaction ); isCommit {
504
+ return true , nil
505
+ }
506
+ }
507
+
508
+ commitStmtName := ""
509
+ commitPortalName := ""
510
+ // Case 2a: Check if cmd is a prepared COMMIT statement.
511
+ if prepareStmt , ok := cmd .(sql.PrepareStmt ); ok {
512
+ if _ , isCommit := prepareStmt .AST .(* tree.CommitTransaction ); isCommit {
513
+ commitStmtName = prepareStmt .Name
514
+ } else {
515
+ return false , nil
516
+ }
517
+ } else {
518
+ return false , nil
519
+ }
520
+
521
+ r .conn .stmtBuf .AdvanceOne ()
522
+ cmd , _ , err = r .conn .stmtBuf .CurCmd ()
523
+ if err != nil {
524
+ return false , err
525
+ }
526
+ // Case 2b: The next cmd must be a bind command.
527
+ if bindStmt , ok := cmd .(sql.BindStmt ); ok {
528
+ // This bind command must be for the COMMIT statement that we just saw.
529
+ if bindStmt .PreparedStatementName == commitStmtName {
530
+ commitPortalName = bindStmt .PortalName
531
+ } else {
532
+ return false , nil
533
+ }
534
+ } else {
535
+ return false , nil
536
+ }
537
+
538
+ r .conn .stmtBuf .AdvanceOne ()
539
+ cmd , _ , err = r .conn .stmtBuf .CurCmd ()
540
+ if err != nil {
541
+ return false , err
542
+ }
543
+ // Case 2c: The next cmd must be an exec portal command.
544
+ if execPortal , ok := cmd .(sql.ExecPortal ); ok {
545
+ // This exec command must be for the portal that was just bound.
546
+ if execPortal .Name == commitPortalName {
547
+ return true , nil
548
+ }
549
+ }
550
+ return false , nil
551
+ }
552
+
553
+ // rewindAndClosePortal closes the portal in the same way implicit transactions
554
+ // do, but also rewinds the stmtBuf to still point to the portal close so that
555
+ // the state machine can do its part of the cleanup.
556
+ func (r * limitedCommandResult ) rewindAndClosePortal (
557
+ ctx context.Context , rewindTo sql.CmdPos ,
558
+ ) error {
559
+ // Don't send an CommandComplete for the portal; it got suspended.
560
+ r .typ = noCompletionMsg
561
+ // Rewind to before the delete so the AdvanceOne in connExecutor.execCmd ends
562
+ // up back on it.
563
+ r .conn .stmtBuf .Rewind (ctx , rewindTo )
564
+ return sql .ErrLimitedResultClosed
565
+ }
0 commit comments