diff --git a/examples/readme_test.go b/examples/readme_test.go index 15da95b..6a60354 100644 --- a/examples/readme_test.go +++ b/examples/readme_test.go @@ -33,6 +33,6 @@ func ExampleNewConverter_readme() { fmt.Println(conditions) fmt.Printf("%#v\n", values) // Output: - // ((("meta"->>'map' ~* $1) OR ("meta"->>'map' ~* $2)) AND ("meta"->>'password' = $3) AND (("meta"->>'playerCount' >= $4) AND ("meta"->>'playerCount' < $5))) + // ((("meta"->>'map' ~* $1) OR ("meta"->>'map' ~* $2)) AND ("meta"->>'password' = $3) AND ((("meta"->>'playerCount')::numeric >= $4) AND (("meta"->>'playerCount')::numeric < $5))) // []interface {}{"aztec", "nuke", "", 2, 10} } diff --git a/filter/converter.go b/filter/converter.go index 8f3bec6..a17eced 100644 --- a/filter/converter.go +++ b/filter/converter.go @@ -10,11 +10,14 @@ import ( "sync" ) -var basicOperatorMap = map[string]string{ - "$gt": ">", - "$gte": ">=", - "$lt": "<", - "$lte": "<=", +var numericOperatorMap = map[string]string{ + "$gt": ">", + "$gte": ">=", + "$lt": "<", + "$lte": "<=", +} + +var textOperatorMap = map[string]string{ "$eq": "=", "$ne": "!=", "$regex": "~*", @@ -200,14 +203,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string values = append(values, v[operator]) case "$exists": // $exists only works on jsonb columns, so we need to check if the key is in the JSONB data first. - isNestedColumn := c.nestedColumn != "" - for _, exemption := range c.nestedExemptions { - if exemption == key { - isNestedColumn = false - break - } - } - if !isNestedColumn { + if !c.isNestedColumn(key) { // There is no way in Postgres to check if a column exists on a table. return "", nil, fmt.Errorf("$exists operator not supported on non-nested jsonb columns") } @@ -217,20 +213,14 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string } inner = append(inner, fmt.Sprintf("(%sjsonb_path_match(%s, 'exists($.%s)'))", neg, c.nestedColumn, key)) case "$elemMatch": - // $elemMatch needs a different implementation depending on if the column is in JSONB or not. - isNestedColumn := c.nestedColumn != "" - for _, exemption := range c.nestedExemptions { - if exemption == key { - isNestedColumn = false - break - } - } innerConditions, innerValues, err := c.convertFilter(map[string]any{c.placeholderName: v[operator]}, paramIndex) if err != nil { return "", nil, err } paramIndex += len(innerValues) - if isNestedColumn { + + // $elemMatch needs a different implementation depending on if the column is in JSONB or not. + if c.isNestedColumn(key) { // This will for example become: // // EXISTS (SELECT 1 FROM jsonb_array_elements("meta"->'foo') AS __filter_placeholder WHERE ("__filter_placeholder"::text = $1)) @@ -247,11 +237,27 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string values = append(values, innerValues...) default: value := v[operator] - op, ok := basicOperatorMap[operator] + isNumericOperator := false + op, ok := textOperatorMap[operator] if !ok { - return "", nil, fmt.Errorf("unknown operator: %s", operator) + op, ok = numericOperatorMap[operator] + if !ok { + return "", nil, fmt.Errorf("unknown operator: %s", operator) + } + isNumericOperator = true + } + + // Prevent cryptic errors like: + // unexpected error: sql: converting argument $1 type: unsupported type []interface {}, a slice of interface + if !isScalar(value) { + return "", nil, fmt.Errorf("invalid comparison value (must be a primitive): %v", value) + } + + if isNumericOperator && isNumeric(value) && c.isNestedColumn(key) { + inner = append(inner, fmt.Sprintf("((%s)::numeric %s $%d)", c.columnName(key), op, paramIndex)) + } else { + inner = append(inner, fmt.Sprintf("(%s %s $%d)", c.columnName(key), op, paramIndex)) } - inner = append(inner, fmt.Sprintf("(%s %s $%d)", c.columnName(key), op, paramIndex)) paramIndex++ values = append(values, value) } @@ -277,6 +283,8 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string conditions = append(conditions, fmt.Sprintf("(%s IS NULL)", c.columnName(key))) } default: + // Prevent cryptic errors like: + // unexpected error: sql: converting argument $1 type: unsupported type []interface {}, a slice of interface if !isScalar(value) { return "", nil, fmt.Errorf("invalid comparison value (must be a primitive): %v", value) } @@ -308,3 +316,15 @@ func (c *Converter) columnName(column string) string { } return fmt.Sprintf(`%q->>'%s'`, c.nestedColumn, column) } + +func (c *Converter) isNestedColumn(column string) bool { + if c.nestedColumn == "" { + return false + } + for _, exemption := range c.nestedExemptions { + if exemption == column { + return false + } + } + return true +} diff --git a/filter/converter_test.go b/filter/converter_test.go index 4dfba5a..d7b182a 100644 --- a/filter/converter_test.go +++ b/filter/converter_test.go @@ -358,6 +358,30 @@ func TestConverter_Convert(t *testing.T) { []any{float64(18)}, nil, }, + { + "numeric comparison bug with jsonb column", + filter.WithNestedJSONB("meta"), + `{"foo": {"$gt": 0}}`, + `(("meta"->>'foo')::numeric > $1)`, + []any{float64(0)}, + nil, + }, + { + "numeric comparison against null with jsonb column", + filter.WithNestedJSONB("meta"), + `{"foo": {"$gt": null}}`, + `("meta"->>'foo' > $1)`, + []any{nil}, + nil, + }, + { + "compare with non scalar", + nil, + `{"name": {"$eq": [1, 2]}}`, + ``, + nil, + fmt.Errorf("invalid comparison value (must be a primitive): [1 2]"), + }, } for _, tt := range tests { diff --git a/filter/util.go b/filter/util.go index bc762fe..9634b01 100644 --- a/filter/util.go +++ b/filter/util.go @@ -1,5 +1,12 @@ package filter +func isNumeric(v any) bool { + // json.Unmarshal returns float64 for all numbers + // so we only need to check for float64. + _, ok := v.(float64) + return ok +} + func isScalar(v any) bool { if v == nil { return true diff --git a/integration/postgres_test.go b/integration/postgres_test.go index acba5d1..d2a72c1 100644 --- a/integration/postgres_test.go +++ b/integration/postgres_test.go @@ -286,10 +286,16 @@ func TestIntegration_BasicOperators(t *testing.T) { nil, }, { - `invalid value`, + `invalid value type int`, `{"level": "town1"}`, // Level is an integer column, but the value is a string. nil, - errors.New("pq: invalid input syntax for type integer: \"town1\""), + errors.New(`pq: invalid input syntax for type integer: "town1"`), + }, + { + `invalid value type string`, + `{"name": 123}`, // Name is a string column, but the value is an integer. + []int{}, + nil, }, { `empty object`, @@ -381,6 +387,18 @@ func TestIntegration_BasicOperators(t *testing.T) { []int{3}, nil, }, + { + "$lt bug with jsonb column", + `{"guild_id": {"$lt": 100}}`, + []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + }, + { + "$lt with null and jsonb column", + `{"guild_id": {"$lt": null}}`, + []int{}, + nil, + }, } for _, tt := range tests {