@@ -291,7 +291,7 @@ var varGen = map[string]sessionVar{
291
291
},
292
292
// See https://www.postgresql.org/docs/9.3/static/runtime-config-client.html#GUC-DEFAULT-TRANSACTION-READ-ONLY
293
293
`default_transaction_read_only` : {
294
- GetStringVal : makeBoolGetStringValFn ("default_transaction_read_only" ),
294
+ GetStringVal : makePostgresBoolGetStringValFn ("default_transaction_read_only" ),
295
295
Set : func (_ context.Context , m * sessionDataMutator , s string ) error {
296
296
b , err := parsePostgresBool (s )
297
297
if err != nil {
@@ -326,7 +326,7 @@ var varGen = map[string]sessionVar{
326
326
327
327
// CockroachDB extension.
328
328
`experimental_force_split_at` : {
329
- GetStringVal : makeBoolGetStringValFn (`experimental_force_split_at` ),
329
+ GetStringVal : makePostgresBoolGetStringValFn (`experimental_force_split_at` ),
330
330
Set : func (_ context.Context , m * sessionDataMutator , s string ) error {
331
331
b , err := parsePostgresBool (s )
332
332
if err != nil {
@@ -343,7 +343,7 @@ var varGen = map[string]sessionVar{
343
343
344
344
// CockroachDB extension.
345
345
`enable_zigzag_join` : {
346
- GetStringVal : makeBoolGetStringValFn (`enable_zigzag_join` ),
346
+ GetStringVal : makePostgresBoolGetStringValFn (`enable_zigzag_join` ),
347
347
Set : func (_ context.Context , m * sessionDataMutator , s string ) error {
348
348
b , err := parsePostgresBool (s )
349
349
if err != nil {
@@ -445,7 +445,7 @@ var varGen = map[string]sessionVar{
445
445
446
446
// CockroachDB extension.
447
447
`experimental_optimizer_foreign_keys` : {
448
- GetStringVal : makeBoolGetStringValFn (`experimental_optimizer_foreign_keys` ),
448
+ GetStringVal : makePostgresBoolGetStringValFn (`experimental_optimizer_foreign_keys` ),
449
449
Set : func (_ context.Context , m * sessionDataMutator , s string ) error {
450
450
b , err := parsePostgresBool (s )
451
451
if err != nil {
@@ -513,7 +513,7 @@ var varGen = map[string]sessionVar{
513
513
Get : func (evalCtx * extendedEvalContext ) string {
514
514
return formatBoolAsPostgresSetting (evalCtx .SessionData .ForceSavepointRestart )
515
515
},
516
- GetStringVal : makeBoolGetStringValFn ("force_savepoint_restart" ),
516
+ GetStringVal : makePostgresBoolGetStringValFn ("force_savepoint_restart" ),
517
517
Set : func (_ context.Context , m * sessionDataMutator , val string ) error {
518
518
b , err := parsePostgresBool (val )
519
519
if err != nil {
@@ -577,7 +577,7 @@ var varGen = map[string]sessionVar{
577
577
Get : func (evalCtx * extendedEvalContext ) string {
578
578
return formatBoolAsPostgresSetting (evalCtx .SessionData .SafeUpdates )
579
579
},
580
- GetStringVal : makeBoolGetStringValFn ("sql_safe_updates" ),
580
+ GetStringVal : makePostgresBoolGetStringValFn ("sql_safe_updates" ),
581
581
Set : func (_ context.Context , m * sessionDataMutator , s string ) error {
582
582
b , err := parsePostgresBool (s )
583
583
if err != nil {
@@ -750,7 +750,7 @@ var varGen = map[string]sessionVar{
750
750
751
751
// See https://www.postgresql.org/docs/10/static/hot-standby.html#HOT-STANDBY-USERS
752
752
`transaction_read_only` : {
753
- GetStringVal : makeBoolGetStringValFn ("transaction_read_only" ),
753
+ GetStringVal : makePostgresBoolGetStringValFn ("transaction_read_only" ),
754
754
Set : func (_ context.Context , m * sessionDataMutator , s string ) error {
755
755
b , err := parsePostgresBool (s )
756
756
if err != nil {
@@ -824,11 +824,23 @@ func init() {
824
824
}
825
825
}
826
826
827
- func makeBoolGetStringValFn (varName string ) getStringValFn {
827
+ // makePostgresBoolGetStringValFn returns a function that evaluates and returns
828
+ // a string representation of the first argument value.
829
+ func makePostgresBoolGetStringValFn (varName string ) getStringValFn {
828
830
return func (
829
831
ctx context.Context , evalCtx * extendedEvalContext , values []tree.TypedExpr ,
830
832
) (string , error ) {
831
- s , err := getSingleBool (varName , evalCtx , values )
833
+ if len (values ) != 1 {
834
+ return "" , newSingleArgVarError (varName )
835
+ }
836
+ val , err := values [0 ].Eval (& evalCtx .EvalContext )
837
+ if err != nil {
838
+ return "" , err
839
+ }
840
+ if s , ok := val .(* tree.DString ); ok {
841
+ return string (* s ), nil
842
+ }
843
+ s , err := getSingleBool (varName , val )
832
844
if err != nil {
833
845
return "" , err
834
846
}
@@ -937,22 +949,15 @@ var varNames = func() []string {
937
949
return res
938
950
}()
939
951
940
- func getSingleBool (
941
- name string , evalCtx * extendedEvalContext , values []tree.TypedExpr ,
942
- ) (* tree.DBool , error ) {
943
- if len (values ) != 1 {
944
- return nil , newSingleArgVarError (name )
945
- }
946
- val , err := values [0 ].Eval (& evalCtx .EvalContext )
947
- if err != nil {
948
- return nil , err
949
- }
952
+ // getSingleBool returns the boolean if the input Datum is a DBool,
953
+ // and returns a detailed error message if not.
954
+ func getSingleBool (name string , val tree.Datum ) (* tree.DBool , error ) {
950
955
b , ok := val .(* tree.DBool )
951
956
if ! ok {
952
- err = pgerror .Newf (pgcode .InvalidParameterValue ,
957
+ err : = pgerror .Newf (pgcode .InvalidParameterValue ,
953
958
"parameter %q requires a Boolean value" , name )
954
959
err = errors .WithDetailf (err ,
955
- "%s is a %s" , values [ 0 ] , errors .Safe (val .ResolvedType ()))
960
+ "%s is a %s" , val , errors .Safe (val .ResolvedType ()))
956
961
return nil , err
957
962
}
958
963
return b , nil
0 commit comments