From 592c1d6a77ed4e58d7f8e6eb394da00190c45feb Mon Sep 17 00:00:00 2001 From: Oliver Lambson Date: Tue, 11 Mar 2025 16:34:04 -0700 Subject: [PATCH] feat: add overrides config option --- README.md | 44 +++++++++ internal/config.go | 25 +++-- .../testdata/emit_pydantic_models/sqlc.yaml | 2 +- .../endtoend/testdata/emit_str_enum/sqlc.yaml | 2 +- .../testdata/emit_type_overrides/db/models.py | 11 +++ .../testdata/emit_type_overrides/db/query.py | 92 +++++++++++++++++++ .../emit_type_overrides/my_lib/__init__.py | 0 .../emit_type_overrides/my_lib/models.py | 7 ++ .../testdata/emit_type_overrides/query.sql | 12 +++ .../testdata/emit_type_overrides/schema.sql | 4 + .../testdata/emit_type_overrides/sqlc.yaml | 22 +++++ .../endtoend/testdata/exec_result/sqlc.yaml | 2 +- .../endtoend/testdata/exec_rows/sqlc.yaml | 2 +- .../inflection_exclude_table_names/sqlc.yaml | 2 +- .../query_parameter_limit_two/sqlc.yaml | 2 +- .../query_parameter_limit_undefined/sqlc.yaml | 2 +- .../query_parameter_limit_zero/sqlc.yaml | 2 +- .../query_parameter_no_limit/sqlc.yaml | 2 +- internal/gen.go | 34 +++++++ internal/imports.go | 28 ++++++ 20 files changed, 279 insertions(+), 18 deletions(-) create mode 100644 internal/endtoend/testdata/emit_type_overrides/db/models.py create mode 100644 internal/endtoend/testdata/emit_type_overrides/db/query.py create mode 100644 internal/endtoend/testdata/emit_type_overrides/my_lib/__init__.py create mode 100644 internal/endtoend/testdata/emit_type_overrides/my_lib/models.py create mode 100644 internal/endtoend/testdata/emit_type_overrides/query.sql create mode 100644 internal/endtoend/testdata/emit_type_overrides/schema.sql create mode 100644 internal/endtoend/testdata/emit_type_overrides/sqlc.yaml diff --git a/README.md b/README.md index c9f2531..5f7ecee 100644 --- a/README.md +++ b/README.md @@ -76,3 +76,47 @@ class Status(str, enum.Enum): OPEN = "op!en" CLOSED = "clo@sed" ``` + +### Override Column Types + +Option: `overrides` + +You can override the SQL to Python type mapping for specific columns using the `overrides` option. This is useful for columns with JSON data or other custom types. + +Example configuration: + +```yaml +options: + package: authors + emit_pydantic_models: true + overrides: + - column: "some_table.payload" + py_import: "my_lib.models" + py_type: "Payload" +``` + +This will: +1. Override the column `payload` in `some_table` to use the type `Payload` +2. Add an import for `my_lib.models` to the models file + +Example output: + +```python +# Code generated by sqlc. DO NOT EDIT. +# versions: +# sqlc v1.28.0 + +import datetime +import pydantic +from typing import Any + +import my_lib.models + + +class SomeTable(pydantic.BaseModel): + id: int + created_at: datetime.datetime + payload: my_lib.models.Payload +``` + +This is similar to the [overrides functionality in the Go version of sqlc](https://docs.sqlc.dev/en/stable/howto/overrides.html#overriding-types). diff --git a/internal/config.go b/internal/config.go index 1a8a565..e78112c 100644 --- a/internal/config.go +++ b/internal/config.go @@ -1,13 +1,20 @@ package python +type OverrideColumn struct { + Column string `json:"column"` + PyType string `json:"py_type"` + PyImport string `json:"py_import"` +} + type Config struct { - EmitExactTableNames bool `json:"emit_exact_table_names"` - EmitSyncQuerier bool `json:"emit_sync_querier"` - EmitAsyncQuerier bool `json:"emit_async_querier"` - Package string `json:"package"` - Out string `json:"out"` - EmitPydanticModels bool `json:"emit_pydantic_models"` - EmitStrEnum bool `json:"emit_str_enum"` - QueryParameterLimit *int32 `json:"query_parameter_limit"` - InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"` + EmitExactTableNames bool `json:"emit_exact_table_names"` + EmitSyncQuerier bool `json:"emit_sync_querier"` + EmitAsyncQuerier bool `json:"emit_async_querier"` + Package string `json:"package"` + Out string `json:"out"` + EmitPydanticModels bool `json:"emit_pydantic_models"` + EmitStrEnum bool `json:"emit_str_enum"` + QueryParameterLimit *int32 `json:"query_parameter_limit"` + InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"` + Overrides []OverrideColumn `json:"overrides"` } diff --git a/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml b/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml index beae200..62ec488 100644 --- a/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml +++ b/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/emit_str_enum/sqlc.yaml b/internal/endtoend/testdata/emit_str_enum/sqlc.yaml index 04e3feb..56fe8bf 100644 --- a/internal/endtoend/testdata/emit_str_enum/sqlc.yaml +++ b/internal/endtoend/testdata/emit_str_enum/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/emit_type_overrides/db/models.py b/internal/endtoend/testdata/emit_type_overrides/db/models.py new file mode 100644 index 0000000..1decb3d --- /dev/null +++ b/internal/endtoend/testdata/emit_type_overrides/db/models.py @@ -0,0 +1,11 @@ +# Code generated by sqlc. DO NOT EDIT. +# versions: +# sqlc v1.28.0 +import pydantic + +import my_lib.models + + +class Book(pydantic.BaseModel): + id: int + payload: my_lib.models.Payload diff --git a/internal/endtoend/testdata/emit_type_overrides/db/query.py b/internal/endtoend/testdata/emit_type_overrides/db/query.py new file mode 100644 index 0000000..0486a35 --- /dev/null +++ b/internal/endtoend/testdata/emit_type_overrides/db/query.py @@ -0,0 +1,92 @@ +# Code generated by sqlc. DO NOT EDIT. +# versions: +# sqlc v1.28.0 +# source: query.sql +from typing import AsyncIterator, Iterator, Optional + +import my_lib.models +import sqlalchemy +import sqlalchemy.ext.asyncio + +from db import models + + +CREATE_BOOK = """-- name: create_book \\:one +INSERT INTO books (payload) +VALUES (:p1) +RETURNING id, payload +""" + + +GET_BOOK = """-- name: get_book \\:one +SELECT id, payload FROM books +WHERE id = :p1 LIMIT 1 +""" + + +LIST_BOOKS = """-- name: list_books \\:many +SELECT id, payload FROM books +ORDER BY id +""" + + +class Querier: + def __init__(self, conn: sqlalchemy.engine.Connection): + self._conn = conn + + def create_book(self, *, payload: my_lib.models.Payload) -> Optional[models.Book]: + row = self._conn.execute(sqlalchemy.text(CREATE_BOOK), {"p1": payload}).first() + if row is None: + return None + return models.Book( + id=row[0], + payload=row[1], + ) + + def get_book(self, *, id: int) -> Optional[models.Book]: + row = self._conn.execute(sqlalchemy.text(GET_BOOK), {"p1": id}).first() + if row is None: + return None + return models.Book( + id=row[0], + payload=row[1], + ) + + def list_books(self) -> Iterator[models.Book]: + result = self._conn.execute(sqlalchemy.text(LIST_BOOKS)) + for row in result: + yield models.Book( + id=row[0], + payload=row[1], + ) + + +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + + async def create_book(self, *, payload: my_lib.models.Payload) -> Optional[models.Book]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_BOOK), {"p1": payload})).first() + if row is None: + return None + return models.Book( + id=row[0], + payload=row[1], + ) + + async def get_book(self, *, id: int) -> Optional[models.Book]: + row = (await self._conn.execute(sqlalchemy.text(GET_BOOK), {"p1": id})).first() + if row is None: + return None + return models.Book( + id=row[0], + payload=row[1], + ) + + async def list_books(self) -> AsyncIterator[models.Book]: + result = await self._conn.stream(sqlalchemy.text(LIST_BOOKS)) + async for row in result: + yield models.Book( + id=row[0], + payload=row[1], + ) diff --git a/internal/endtoend/testdata/emit_type_overrides/my_lib/__init__.py b/internal/endtoend/testdata/emit_type_overrides/my_lib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/internal/endtoend/testdata/emit_type_overrides/my_lib/models.py b/internal/endtoend/testdata/emit_type_overrides/my_lib/models.py new file mode 100644 index 0000000..1f1a052 --- /dev/null +++ b/internal/endtoend/testdata/emit_type_overrides/my_lib/models.py @@ -0,0 +1,7 @@ +from datetime import date + +from pydantic import BaseModel + +class Payload(BaseModel): + name: str + release_date: date diff --git a/internal/endtoend/testdata/emit_type_overrides/query.sql b/internal/endtoend/testdata/emit_type_overrides/query.sql new file mode 100644 index 0000000..ab1a3c1 --- /dev/null +++ b/internal/endtoend/testdata/emit_type_overrides/query.sql @@ -0,0 +1,12 @@ +-- name: GetBook :one +SELECT * FROM books +WHERE id = $1 LIMIT 1; + +-- name: ListBooks :many +SELECT * FROM books +ORDER BY id; + +-- name: CreateBook :one +INSERT INTO books (payload) +VALUES (sqlc.arg(payload)) +RETURNING *; diff --git a/internal/endtoend/testdata/emit_type_overrides/schema.sql b/internal/endtoend/testdata/emit_type_overrides/schema.sql new file mode 100644 index 0000000..51997ea --- /dev/null +++ b/internal/endtoend/testdata/emit_type_overrides/schema.sql @@ -0,0 +1,4 @@ +CREATE TABLE books ( + id SERIAL PRIMARY KEY, + payload JSONB NOT NULL +); diff --git a/internal/endtoend/testdata/emit_type_overrides/sqlc.yaml b/internal/endtoend/testdata/emit_type_overrides/sqlc.yaml new file mode 100644 index 0000000..e70c41a --- /dev/null +++ b/internal/endtoend/testdata/emit_type_overrides/sqlc.yaml @@ -0,0 +1,22 @@ +version: "2" +plugins: + - name: py + wasm: + url: file://../../../../bin/sqlc-gen-python.wasm + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" +sql: + - schema: schema.sql + queries: query.sql + engine: postgresql + codegen: + - plugin: py + out: db + options: + package: db + emit_pydantic_models: true + emit_sync_querier: true + emit_async_querier: true + overrides: + - column: "books.payload" + py_import: "my_lib.models" + py_type: "Payload" diff --git a/internal/endtoend/testdata/exec_result/sqlc.yaml b/internal/endtoend/testdata/exec_result/sqlc.yaml index ddffc83..e7fe6ff 100644 --- a/internal/endtoend/testdata/exec_result/sqlc.yaml +++ b/internal/endtoend/testdata/exec_result/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/exec_rows/sqlc.yaml b/internal/endtoend/testdata/exec_rows/sqlc.yaml index ddffc83..e7fe6ff 100644 --- a/internal/endtoend/testdata/exec_rows/sqlc.yaml +++ b/internal/endtoend/testdata/exec_rows/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml b/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml index efbb150..030d33e 100644 --- a/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml +++ b/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml index 336bca7..018e2db 100644 --- a/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml index c20cd57..91a7c07 100644 --- a/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml index 6e2cdeb..56644ee 100644 --- a/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml b/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml index c432e4f..2b8d205 100644 --- a/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/gen.go b/internal/gen.go index 6e50fae..9cd35b3 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -181,6 +181,40 @@ func (q Query) ArgDictNode() *pyast.Node { } func makePyType(req *plugin.GenerateRequest, col *plugin.Column) pyType { + // Parse the configuration + var conf Config + if len(req.PluginOptions) > 0 { + if err := json.Unmarshal(req.PluginOptions, &conf); err != nil { + log.Printf("failed to parse plugin options: %s", err) + } + } + + // Check for overrides + if len(conf.Overrides) > 0 && col.Table != nil { + tableName := col.Table.Name + if col.Table.Schema != "" && col.Table.Schema != req.Catalog.DefaultSchema { + tableName = col.Table.Schema + "." + tableName + } + + // Look for a matching override + for _, override := range conf.Overrides { + overrideKey := tableName + "." + col.Name + if override.Column == overrideKey { + // Found a match, use the override + typeStr := override.PyType + if override.PyImport != "" && !strings.Contains(typeStr, ".") { + typeStr = override.PyImport + "." + override.PyType + } + return pyType{ + InnerType: typeStr, + IsArray: col.IsArray, + IsNull: !col.NotNull, + } + } + } + } + + // No override found, use the standard type mapping typ := pyInnerType(req, col) return pyType{ InnerType: typ, diff --git a/internal/imports.go b/internal/imports.go index b88c58c..454eefd 100644 --- a/internal/imports.go +++ b/internal/imports.go @@ -97,6 +97,20 @@ func (i *importer) modelImportSpecs() (map[string]importSpec, map[string]importS pkg := make(map[string]importSpec) + // Add custom imports from overrides + for _, override := range i.C.Overrides { + if override.PyImport != "" { + // Check if it's a standard module or a package import + if strings.Contains(override.PyImport, ".") { + // It's a package import + pkg[override.PyImport] = importSpec{Module: override.PyImport} + } else { + // It's a standard import + std[override.PyImport] = importSpec{Module: override.PyImport} + } + } + } + return std, pkg } @@ -167,6 +181,20 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map } } + // Add custom imports from overrides for query files + for _, override := range i.C.Overrides { + if override.PyImport != "" { + // Check if it's a standard module or a package import + if strings.Contains(override.PyImport, ".") { + // It's a package import + pkg[override.PyImport] = importSpec{Module: override.PyImport} + } else { + // It's a standard import + std[override.PyImport] = importSpec{Module: override.PyImport} + } + } + } + return std, pkg }