From 7734a233e6e758192fab9f57ed33235d56c39af4 Mon Sep 17 00:00:00 2001 From: Julien Nakache Date: Mon, 6 May 2019 10:57:50 -0400 Subject: [PATCH 1/4] Implement selective overriding mechanism --- graphene_sqlalchemy/converter.py | 123 ++++--- graphene_sqlalchemy/enums.py | 22 +- graphene_sqlalchemy/fields.py | 8 +- graphene_sqlalchemy/tests/conftest.py | 9 +- graphene_sqlalchemy/tests/models.py | 39 +- graphene_sqlalchemy/tests/test_converter.py | 194 ++++------ graphene_sqlalchemy/tests/test_fields.py | 19 +- graphene_sqlalchemy/tests/test_query.py | 114 ++++-- graphene_sqlalchemy/tests/test_schema.py | 50 --- graphene_sqlalchemy/tests/test_types.py | 380 +++++++++++++------- graphene_sqlalchemy/types.py | 137 ++++--- setup.cfg | 2 +- 12 files changed, 622 insertions(+), 475 deletions(-) delete mode 100644 graphene_sqlalchemy/tests/test_schema.py diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 9466cbaf..fc6d1b39 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -16,6 +16,10 @@ ChoiceType = JSONType = ScalarListType = TSVectorType = object +def _get_attr_resolver(attr_name): + return lambda root, _info: getattr(root, attr_name, None) + + def get_column_doc(column): return getattr(column, "doc", None) @@ -24,43 +28,61 @@ def is_column_nullable(column): return bool(getattr(column, "nullable", True)) -def convert_sqlalchemy_relationship(relationship, registry, connection_field_factory): - direction = relationship.direction - model = relationship.mapper.entity +def convert_sqlalchemy_relationship(relationship_prop, registry, connection_field_factory, **field_kwargs): + direction = relationship_prop.direction + model = relationship_prop.mapper.entity def dynamic_type(): _type = registry.get_type_for_model(model) + if not _type: return None - if direction == interfaces.MANYTOONE or not relationship.uselist: - return Field(_type) + if direction == interfaces.MANYTOONE or not relationship_prop.uselist: + return Field( + _type, + resolver=_get_attr_resolver(relationship_prop.key), + **field_kwargs + ) elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): if _type._meta.connection: - return connection_field_factory(relationship, registry) - return Field(List(_type)) + # TODO Add a way to override connection_field_factory + return connection_field_factory(relationship_prop, registry, **field_kwargs) + return Field( + List(_type), + **field_kwargs + ) return Dynamic(dynamic_type) -def convert_sqlalchemy_hybrid_method(hybrid_item): - return String(description=getattr(hybrid_item, "__doc__", None), required=False) +def convert_sqlalchemy_hybrid_method(hybrid_prop, prop_name, **field_kwargs): + if 'type' not in field_kwargs: + # TODO The default type should be dependent on the type of the property propety. + field_kwargs['type'] = String + + return Field( + resolver=_get_attr_resolver(prop_name), + **field_kwargs + ) -def convert_sqlalchemy_composite(composite, registry): - converter = registry.get_converter_for_composite(composite.composite_class) +def convert_sqlalchemy_composite(composite_prop, registry): + converter = registry.get_converter_for_composite(composite_prop.composite_class) if not converter: try: raise Exception( "Don't know how to convert the composite field %s (%s)" - % (composite, composite.composite_class) + % (composite_prop, composite_prop.composite_class) ) except AttributeError: # handle fields that are not attached to a class yet (don't have a parent) raise Exception( "Don't know how to convert the composite field %r (%s)" - % (composite, composite.composite_class) + % (composite_prop, composite_prop.composite_class) ) - return converter(composite, registry) + + # TODO Add a way to override composite fields default parameters + return converter(composite_prop, registry) def _register_composite_class(cls, registry=None): @@ -78,8 +100,21 @@ def inner(fn): convert_sqlalchemy_composite.register = _register_composite_class -def convert_sqlalchemy_column(column, registry=None): - return convert_sqlalchemy_type(getattr(column, "type", None), column, registry) +def convert_sqlalchemy_column(column_prop, registry, **field_kwargs): + column = column_prop.columns[0] + if 'type' not in field_kwargs: + field_kwargs['type'] = convert_sqlalchemy_type(getattr(column, "type", None), column, registry) + + if 'required' not in field_kwargs: + field_kwargs['required'] = not is_column_nullable(column) + + if 'description' not in field_kwargs: + field_kwargs['description'] = get_column_doc(column) + + return Field( + resolver=_get_attr_resolver(column_prop.key), + **field_kwargs + ) @singledispatch @@ -101,93 +136,63 @@ def convert_sqlalchemy_type(type, column, registry=None): @convert_sqlalchemy_type.register(postgresql.CIDR) @convert_sqlalchemy_type.register(TSVectorType) def convert_column_to_string(type, column, registry=None): - return String( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + return String @convert_sqlalchemy_type.register(types.DateTime) def convert_column_to_datetime(type, column, registry=None): from graphene.types.datetime import DateTime - - return DateTime( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + return DateTime @convert_sqlalchemy_type.register(types.SmallInteger) @convert_sqlalchemy_type.register(types.Integer) def convert_column_to_int_or_id(type, column, registry=None): - if column.primary_key: - return ID( - description=get_column_doc(column), - required=not (is_column_nullable(column)), - ) - else: - return Int( - description=get_column_doc(column), - required=not (is_column_nullable(column)), - ) + return ID if column.primary_key else Int @convert_sqlalchemy_type.register(types.Boolean) def convert_column_to_boolean(type, column, registry=None): - return Boolean( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + return Boolean @convert_sqlalchemy_type.register(types.Float) @convert_sqlalchemy_type.register(types.Numeric) @convert_sqlalchemy_type.register(types.BigInteger) def convert_column_to_float(type, column, registry=None): - return Float( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + return Float @convert_sqlalchemy_type.register(types.Enum) def convert_enum_to_enum(type, column, registry=None): - return Field( - lambda: enum_for_sa_enum(type, registry or get_global_registry()), - description=get_column_doc(column), - required=not (is_column_nullable(column)), - ) + return lambda: enum_for_sa_enum(type, registry or get_global_registry()) +# TODO Make ChoiceType conversion consistent with other enums @convert_sqlalchemy_type.register(ChoiceType) def convert_choice_to_enum(type, column, registry=None): name = "{}_{}".format(column.table.name, column.name).upper() - return Enum(name, type.choices, description=get_column_doc(column)) + return Enum(name, type.choices) @convert_sqlalchemy_type.register(ScalarListType) def convert_scalar_list_to_list(type, column, registry=None): - return List(String, description=get_column_doc(column)) + return List(String) @convert_sqlalchemy_type.register(postgresql.ARRAY) def convert_postgres_array_to_list(_type, column, registry=None): - graphene_type = convert_sqlalchemy_type(column.type.item_type, column) - inner_type = type(graphene_type) - return List( - inner_type, - description=get_column_doc(column), - required=not (is_column_nullable(column)), - ) + inner_type = convert_sqlalchemy_type(column.type.item_type, column) + return List(inner_type) @convert_sqlalchemy_type.register(postgresql.HSTORE) @convert_sqlalchemy_type.register(postgresql.JSON) @convert_sqlalchemy_type.register(postgresql.JSONB) def convert_json_to_string(type, column, registry=None): - return JSONString( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + return JSONString @convert_sqlalchemy_type.register(JSONType) def convert_json_type_to_string(type, column, registry=None): - return JSONString( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + return JSONString diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index 6b84bf52..f100be19 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column +from sqlalchemy.orm import ColumnProperty from sqlalchemy.types import Enum as SQLAlchemyEnumType from graphene import Argument, Enum, List @@ -69,11 +69,12 @@ def enum_for_field(obj_type, field_name): orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) if orm_field is None: raise TypeError("Cannot get {}.{}".format(obj_type._meta.name, field_name)) - if not isinstance(orm_field, Column): + if not isinstance(orm_field, ColumnProperty): raise TypeError( "{}.{} does not map to model column".format(obj_type._meta.name, field_name) ) - sa_enum = orm_field.type + column = orm_field.columns[0] + sa_enum = column.type if not isinstance(sa_enum, SQLAlchemyEnumType): raise TypeError( "{}.{} does not map to enum column".format(obj_type._meta.name, field_name) @@ -138,15 +139,16 @@ def sort_enum_for_object_type( if only_fields and field_name not in only_fields: continue orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) - if not isinstance(orm_field, Column): + if not isinstance(orm_field, ColumnProperty): continue - if only_indexed and not (orm_field.primary_key or orm_field.index): + column = orm_field.columns[0] + if only_indexed and not (column.primary_key or column.index): continue - asc_name = get_name(orm_field.name, True) - asc_value = EnumValue(asc_name, orm_field.asc()) - desc_name = get_name(orm_field.name, False) - desc_value = EnumValue(desc_name, orm_field.desc()) - if orm_field.primary_key: + asc_name = get_name(column.name, True) + asc_value = EnumValue(asc_name, column.asc()) + desc_name = get_name(column.name, False) + desc_value = EnumValue(desc_name, column.desc()) + if column.primary_key: default.append(asc_value) members.extend(((asc_name, asc_value), (desc_name, desc_value))) enum = Enum(name, members) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 3ad15a92..e29d87be 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -97,22 +97,22 @@ def __init__(self, type, *args, **kwargs): super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs) -def default_connection_field_factory(relationship, registry): +def default_connection_field_factory(relationship, registry, **field_kwargs): model = relationship.mapper.entity model_type = registry.get_type_for_model(model) - return createConnectionField(model_type) + return createConnectionField(model_type, **field_kwargs) # TODO Remove in next major version __connectionFactory = UnsortedSQLAlchemyConnectionField -def createConnectionField(_type): +def createConnectionField(_type, **field_kwargs): log.warning( 'createConnectionField is deprecated and will be removed in the next ' 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' ) - return __connectionFactory(_type) + return __connectionFactory(_type, **field_kwargs) def registerConnectionFieldFactory(factoryMethod): diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 2825eb3c..1515f2b7 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -2,8 +2,9 @@ from sqlalchemy import create_engine from sqlalchemy.orm import scoped_session, sessionmaker +from ..converter import convert_sqlalchemy_composite from ..registry import reset_global_registry -from .models import Base +from .models import Base, CompositeFullName test_db_url = 'sqlite://' # use in-memory database for tests @@ -12,6 +13,12 @@ def reset_registry(): reset_global_registry() + # Prevent tests that implicitly depend on Reporter from raising + # Tests that explicitly depend on this behavior should re-register a converter + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + pass + @pytest.yield_fixture(scope="function") def session(): diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 12781cc5..1df28333 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -2,9 +2,11 @@ import enum -from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table +from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table, + func, select) from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import mapper, relationship +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import column_property, composite, mapper, relationship PetKind = Enum("cat", "dog", name="pet_kind") @@ -39,22 +41,39 @@ class Pet(Base): reporter_id = Column(Integer(), ForeignKey("reporters.id")) +class CompositeFullName(object): + def __init__(self, first_name, last_name): + self.first_name = first_name + self.last_name = last_name + + def __composite_values__(self): + return self.first_name, self.last_name + + def __repr__(self): + return "{} {}".format(self.first_name, self.last_name) + + class Reporter(Base): __tablename__ = "reporters" + id = Column(Integer(), primary_key=True) - first_name = Column(String(30)) - last_name = Column(String(30)) - email = Column(String()) + first_name = Column(String(30), doc="First name") + last_name = Column(String(30), doc="Last name") + email = Column(String(), doc="Email") favorite_pet_kind = Column(PetKind) pets = relationship("Pet", secondary=association_table, backref="reporters") articles = relationship("Article", backref="reporter") favorite_article = relationship("Article", uselist=False) - # total = column_property( - # select([ - # func.cast(func.count(PersonInfo.id), Float) - # ]) - # ) + @hybrid_property + def hybrid_prop(self): + return self.first_name + + column_prop = column_property( + select([func.cast(func.count(id), Integer)]), doc="Column property" + ) + + composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite") class Article(Base): diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index f38999d2..f255350d 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,11 +1,11 @@ import enum import pytest -from sqlalchemy import Column, Table, case, func, select, types +from sqlalchemy import Column, func, select, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property, composite -from sqlalchemy.sql.elements import Label from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType import graphene @@ -18,90 +18,79 @@ convert_sqlalchemy_relationship) from ..fields import (UnsortedSQLAlchemyConnectionField, default_connection_field_factory) -from ..registry import Registry +from ..registry import Registry, get_global_registry from ..types import SQLAlchemyObjectType -from .models import Article, Pet, Reporter +from .models import Article, CompositeFullName, Pet, Reporter -def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs): - column = Column(sqlalchemy_type, doc="Custom Help Text", **kwargs) - graphene_type = convert_sqlalchemy_column(column) - assert isinstance(graphene_type, graphene_field) - field = ( - graphene_type - if isinstance(graphene_type, graphene.Field) - else graphene_type.Field() - ) - assert field.description == "Custom Help Text" - return field +def get_field(sqlalchemy_type, **column_kwargs): + class Model(declarative_base()): + __tablename__ = 'model' + id_ = Column(types.Integer, primary_key=True) + column = Column(sqlalchemy_type, doc="Custom Help Text", **column_kwargs) + column_prop = inspect(Model).column_attrs['column'] + return convert_sqlalchemy_column(column_prop, get_global_registry()) -def assert_composite_conversion( - composite_class, composite_columns, graphene_field, registry, **kwargs -): - composite_column = composite( - composite_class, *composite_columns, doc="Custom Help Text", **kwargs - ) - graphene_type = convert_sqlalchemy_composite(composite_column, registry) - assert isinstance(graphene_type, graphene_field) - field = graphene_type.Field() - # SQLAlchemy currently does not persist the doc onto the column, even though - # the documentation says it does.... - # assert field.description == 'Custom Help Text' - return field + +def get_field_from_column(column_): + class Model(declarative_base()): + __tablename__ = 'model' + id_ = Column(types.Integer, primary_key=True) + column = column_ + + column_prop = inspect(Model).column_attrs['column'] + return convert_sqlalchemy_column(column_prop, get_global_registry()) def test_should_unknown_sqlalchemy_field_raise_exception(): re_err = "Don't know how to convert the SQLAlchemy field" with pytest.raises(Exception, match=re_err): - convert_sqlalchemy_column(None) + get_field(types.Binary()) def test_should_date_convert_string(): - assert_column_conversion(types.Date(), graphene.String) + assert get_field(types.Date()).type == graphene.String -def test_should_datetime_convert_string(): - assert_column_conversion(types.DateTime(), DateTime) +def test_should_datetime_convert_datetime(): + assert get_field(types.DateTime()).type == DateTime def test_should_time_convert_string(): - assert_column_conversion(types.Time(), graphene.String) + assert get_field(types.Time()).type == graphene.String def test_should_string_convert_string(): - assert_column_conversion(types.String(), graphene.String) + assert get_field(types.String()).type == graphene.String def test_should_text_convert_string(): - assert_column_conversion(types.Text(), graphene.String) + assert get_field(types.Text()).type == graphene.String def test_should_unicode_convert_string(): - assert_column_conversion(types.Unicode(), graphene.String) + assert get_field(types.Unicode()).type == graphene.String def test_should_unicodetext_convert_string(): - assert_column_conversion(types.UnicodeText(), graphene.String) + assert get_field(types.UnicodeText()).type == graphene.String def test_should_enum_convert_enum(): - field = assert_column_conversion( - types.Enum(enum.Enum("TwoNumbers", ("one", "two"))), graphene.Field - ) + field = get_field(types.Enum(enum.Enum("TwoNumbers", ("one", "two")))) field_type = field.type() assert isinstance(field_type, graphene.Enum) + assert field_type._meta.name == "TwoNumbers" assert hasattr(field_type, "ONE") assert not hasattr(field_type, "one") assert hasattr(field_type, "TWO") assert not hasattr(field_type, "two") - field = assert_column_conversion( - types.Enum("one", "two", name="two_numbers"), graphene.Field - ) + field = get_field(types.Enum("one", "two", name="two_numbers")) field_type = field.type() - assert field_type._meta.name == "TwoNumbers" assert isinstance(field_type, graphene.Enum) + assert field_type._meta.name == "TwoNumbers" assert hasattr(field_type, "ONE") assert not hasattr(field_type, "one") assert hasattr(field_type, "TWO") @@ -109,89 +98,65 @@ def test_should_enum_convert_enum(): def test_should_not_enum_convert_enum_without_name(): - field = assert_column_conversion( - types.Enum("one", "two"), graphene.Field - ) + field = get_field(types.Enum("one", "two")) re_err = r"No type name specified for Enum\('one', 'two'\)" with pytest.raises(TypeError, match=re_err): field.type() def test_should_small_integer_convert_int(): - assert_column_conversion(types.SmallInteger(), graphene.Int) + assert get_field(types.SmallInteger()).type == graphene.Int def test_should_big_integer_convert_int(): - assert_column_conversion(types.BigInteger(), graphene.Float) + assert get_field(types.BigInteger()).type == graphene.Float def test_should_integer_convert_int(): - assert_column_conversion(types.Integer(), graphene.Int) + assert get_field(types.Integer()).type == graphene.Int -def test_should_integer_convert_id(): - assert_column_conversion(types.Integer(), graphene.ID, primary_key=True) +def test_should_primary_integer_convert_id(): + assert get_field(types.Integer(), primary_key=True).type == graphene.NonNull(graphene.ID) def test_should_boolean_convert_boolean(): - assert_column_conversion(types.Boolean(), graphene.Boolean) + assert get_field(types.Boolean()).type == graphene.Boolean def test_should_float_convert_float(): - assert_column_conversion(types.Float(), graphene.Float) + assert get_field(types.Float()).type == graphene.Float def test_should_numeric_convert_float(): - assert_column_conversion(types.Numeric(), graphene.Float) - - -def test_should_label_convert_string(): - label = Label("label_test", case([], else_="foo"), type_=types.Unicode()) - graphene_type = convert_sqlalchemy_column(label) - assert isinstance(graphene_type, graphene.String) - - -def test_should_label_convert_int(): - label = Label("int_label_test", case([], else_="foo"), type_=types.Integer()) - graphene_type = convert_sqlalchemy_column(label) - assert isinstance(graphene_type, graphene.Int) + assert get_field(types.Numeric()).type == graphene.Float def test_should_choice_convert_enum(): - TYPES = [(u"es", u"Spanish"), (u"en", u"English")] - column = Column(ChoiceType(TYPES), doc="Language", name="language") - Base = declarative_base() - - Table("translatedmodel", Base.metadata, column) - graphene_type = convert_sqlalchemy_column(column) + field = get_field(ChoiceType([(u"es", u"Spanish"), (u"en", u"English")])) + graphene_type = field.type assert issubclass(graphene_type, graphene.Enum) - assert graphene_type._meta.name == "TRANSLATEDMODEL_LANGUAGE" - assert graphene_type._meta.description == "Language" + assert graphene_type._meta.name == "MODEL_COLUMN" assert graphene_type._meta.enum.__members__["es"].value == "Spanish" assert graphene_type._meta.enum.__members__["en"].value == "English" def test_should_columproperty_convert(): + field = get_field_from_column(column_property( + select([func.sum(func.cast(id, types.Integer))]).where(id == 1) + )) - Base = declarative_base() - - class Test(Base): - __tablename__ = "test" - id = Column(types.Integer, primary_key=True) - column = column_property( - select([func.sum(func.cast(id, types.Integer))]).where(id == 1) - ) - - graphene_type = convert_sqlalchemy_column(Test.column) - assert not graphene_type.kwargs["required"] + assert field.type == graphene.Int def test_should_scalar_list_convert_list(): - assert_column_conversion(ScalarListType(), graphene.List) + field = get_field(ScalarListType()) + assert isinstance(field.type, graphene.List) + assert field.type.of_type == graphene.String def test_should_jsontype_convert_jsonstring(): - assert_column_conversion(JSONType(), JSONString) + assert get_field(JSONType()).type == JSONString def test_should_manytomany_convert_connectionorlist(): @@ -287,16 +252,14 @@ class Meta: def test_should_postgresql_uuid_convert(): - assert_column_conversion(postgresql.UUID(), graphene.String) + assert get_field(postgresql.UUID()).type == graphene.String def test_should_postgresql_enum_convert(): - field = assert_column_conversion( - postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field - ) + field = get_field(postgresql.ENUM("one", "two", name="two_numbers")) field_type = field.type() - assert field_type._meta.name == "TwoNumbers" assert isinstance(field_type, graphene.Enum) + assert field_type._meta.name == "TwoNumbers" assert hasattr(field_type, "ONE") assert not hasattr(field_type, "one") assert hasattr(field_type, "TWO") @@ -304,10 +267,7 @@ def test_should_postgresql_enum_convert(): def test_should_postgresql_py_enum_convert(): - field = assert_column_conversion( - postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), - graphene.Field, - ) + field = get_field(postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers")) field_type = field.type() assert field_type._meta.name == "TwoNumbers" assert isinstance(field_type, graphene.Enum) @@ -318,55 +278,51 @@ def test_should_postgresql_py_enum_convert(): def test_should_postgresql_array_convert(): - assert_column_conversion(postgresql.ARRAY(types.Integer), graphene.List) + field = get_field(postgresql.ARRAY(types.Integer)) + assert isinstance(field.type, graphene.List) + assert field.type.of_type == graphene.Int def test_should_postgresql_json_convert(): - assert_column_conversion(postgresql.JSON(), JSONString) + assert get_field(postgresql.JSON()).type == graphene.JSONString def test_should_postgresql_jsonb_convert(): - assert_column_conversion(postgresql.JSONB(), JSONString) + assert get_field(postgresql.JSONB()).type == graphene.JSONString def test_should_postgresql_hstore_convert(): - assert_column_conversion(postgresql.HSTORE(), JSONString) + assert get_field(postgresql.HSTORE()).type == graphene.JSONString def test_should_composite_convert(): + registry = Registry() + class CompositeClass: def __init__(self, col1, col2): self.col1 = col1 self.col2 = col2 - registry = Registry() - @convert_sqlalchemy_composite.register(CompositeClass, registry) def convert_composite_class(composite, registry): return graphene.String(description=composite.doc) - assert_composite_conversion( - CompositeClass, - (Column(types.Unicode(50)), Column(types.Unicode(50))), - graphene.String, + field = convert_sqlalchemy_composite( + composite(CompositeClass, (Column(types.Unicode(50)), Column(types.Unicode(50))), doc="Custom Help Text"), registry, ) + assert isinstance(field, graphene.String) def test_should_unknown_sqlalchemy_composite_raise_exception(): - registry = Registry() + class CompositeClass: + def __init__(self, col1, col2): + self.col1 = col1 + self.col2 = col2 re_err = "Don't know how to convert the composite field" with pytest.raises(Exception, match=re_err): - - class CompositeClass(object): - def __init__(self, col1, col2): - self.col1 = col1 - self.col2 = col2 - - assert_composite_conversion( - CompositeClass, - (Column(types.Unicode(50)), Column(types.Unicode(50))), - graphene.String, - registry, + convert_sqlalchemy_composite( + composite(CompositeFullName, (Column(types.Unicode(50)), Column(types.Unicode(50)))), + Registry(), ) diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index 0f8738f0..875b729d 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -1,4 +1,5 @@ import pytest +from promise import Promise from graphene.relay import Connection @@ -18,24 +19,34 @@ class Meta: model = EditorModel -class PetConn(Connection): +class PetConnection(Connection): class Meta: node = Pet +def test_promise_connection_resolver(): + def resolver(_obj, _info): + return Promise.resolve([]) + + result = SQLAlchemyConnectionField.connection_resolver( + resolver, PetConnection, Pet, None, None + ) + assert isinstance(result, Promise) + + def test_sort_added_by_default(): - field = SQLAlchemyConnectionField(PetConn) + field = SQLAlchemyConnectionField(PetConnection) assert "sort" in field.args assert field.args["sort"] == Pet.sort_argument() def test_sort_can_be_removed(): - field = SQLAlchemyConnectionField(PetConn, sort=None) + field = SQLAlchemyConnectionField(PetConnection, sort=None) assert "sort" not in field.args def test_custom_sort(): - field = SQLAlchemyConnectionField(PetConn, sort=Editor.sort_argument()) + field = SQLAlchemyConnectionField(PetConnection, sort=Editor.sort_argument()) assert field.args["sort"] == Editor.sort_argument() diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 5279bd87..269b9bd7 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -1,9 +1,10 @@ import graphene from graphene.relay import Connection, Node +from ..converter import convert_sqlalchemy_composite from ..fields import SQLAlchemyConnectionField -from ..types import SQLAlchemyObjectType -from .models import Article, Editor, HairKind, Pet, Reporter +from ..types import ORMField, SQLAlchemyObjectType +from .models import Article, CompositeFullName, Editor, HairKind, Pet, Reporter def to_std_dicts(value): @@ -37,9 +38,13 @@ def add_test_data(session): session.commit() -def test_should_query_well(session): +def test_query_fields(session): add_test_data(session) + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return graphene.String() + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -55,11 +60,12 @@ def resolve_reporters(self, _info): return session.query(Reporter) query = """ - query ReporterQuery { + query { reporter { firstName - lastName - email + columnProp + hybridProp + compositeProp } reporters { firstName @@ -67,7 +73,12 @@ def resolve_reporters(self, _info): } """ expected = { - "reporter": {"firstName": "John", "lastName": "Doe", "email": None}, + "reporter": { + "firstName": "John", + "hybridProp": "John", + "columnProp": 2, + "compositeProp": "John Doe", + }, "reporters": [{"firstName": "John"}, {"firstName": "Jane"}], } schema = graphene.Schema(query=Query) @@ -77,7 +88,7 @@ def resolve_reporters(self, _info): assert result == expected -def test_should_query_node(session): +def test_query_node(session): add_test_data(session) class ReporterNode(SQLAlchemyObjectType): @@ -101,20 +112,16 @@ class Meta: class Query(graphene.ObjectType): node = Node.Field() reporter = graphene.Field(ReporterNode) - article = graphene.Field(ArticleNode) all_articles = SQLAlchemyConnectionField(ArticleConnection) def resolve_reporter(self, _info): return session.query(Reporter).first() - def resolve_article(self, _info): - return session.query(Article).first() - query = """ - query ReporterQuery { + query { reporter { id - firstName, + firstName articles { edges { node { @@ -122,8 +129,6 @@ def resolve_article(self, _info): } } } - lastName, - email } allArticles { edges { @@ -147,8 +152,6 @@ def resolve_article(self, _info): "reporter": { "id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "John", - "lastName": "Doe", - "email": None, "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, }, "allArticles": {"edges": [{"node": {"headline": "Hi!"}}]}, @@ -161,7 +164,74 @@ def resolve_article(self, _info): assert result == expected -def test_should_custom_identifier(session): +def test_orm_field(session): + add_test_data(session) + + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return graphene.String() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + first_name_v2 = ORMField(prop_name='first_name') + hybrid_prop_v2 = ORMField(prop_name='hybrid_prop') + column_prop_v2 = ORMField(prop_name='column_prop') + composite_prop = ORMField() + favorite_article_v2 = ORMField(prop_name='favorite_article') + articles_v2 = ORMField(prop_name='articles') + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + + class Query(graphene.ObjectType): + reporter = graphene.Field(ReporterType) + + def resolve_reporter(self, _info): + return session.query(Reporter).first() + + query = """ + query { + reporter { + firstNameV2 + hybridPropV2 + columnPropV2 + compositeProp + favoriteArticleV2 { + headline + } + articlesV2(first: 1) { + edges { + node { + headline + } + } + } + } + } + """ + expected = { + "reporter": { + "firstNameV2": "John", + "hybridPropV2": "John", + "columnPropV2": 2, + "compositeProp": "John Doe", + "favoriteArticleV2": {"headline": "Hi!"}, + "articlesV2": {"edges": [{"node": {"headline": "Hi!"}}]}, + }, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +def test_custom_identifier(session): add_test_data(session) class EditorNode(SQLAlchemyObjectType): @@ -178,7 +248,7 @@ class Query(graphene.ObjectType): all_editors = SQLAlchemyConnectionField(EditorConnection) query = """ - query EditorQuery { + query { allEditors { edges { node { @@ -206,7 +276,7 @@ class Query(graphene.ObjectType): assert result == expected -def test_should_mutate_well(session): +def test_mutation(session): add_test_data(session) class EditorNode(SQLAlchemyObjectType): @@ -252,7 +322,7 @@ class Mutation(graphene.ObjectType): create_article = CreateArticle.Field() query = """ - mutation ArticleCreator { + mutation { createArticle( headline: "My Article" reporterId: "1" diff --git a/graphene_sqlalchemy/tests/test_schema.py b/graphene_sqlalchemy/tests/test_schema.py deleted file mode 100644 index 87739bdb..00000000 --- a/graphene_sqlalchemy/tests/test_schema.py +++ /dev/null @@ -1,50 +0,0 @@ -from py.test import raises - -from ..registry import Registry -from ..types import SQLAlchemyObjectType -from .models import Reporter - - -def test_should_raise_if_no_model(): - with raises(Exception) as excinfo: - - class Character1(SQLAlchemyObjectType): - pass - - assert "valid SQLAlchemy Model" in str(excinfo.value) - - -def test_should_raise_if_model_is_invalid(): - with raises(Exception) as excinfo: - - class Character2(SQLAlchemyObjectType): - class Meta: - model = 1 - - assert "valid SQLAlchemy Model" in str(excinfo.value) - - -def test_should_map_fields_correctly(): - class ReporterType2(SQLAlchemyObjectType): - class Meta: - model = Reporter - registry = Registry() - - assert list(ReporterType2._meta.fields.keys()) == [ - "id", - "first_name", - "last_name", - "email", - "favorite_pet_kind", - "pets", - "articles", - "favorite_article", - ] - - -def test_should_map_only_few_fields(): - class Reporter2(SQLAlchemyObjectType): - class Meta: - model = Reporter - only_fields = ("id", "email") - assert list(Reporter2._meta.fields.keys()) == ["id", "email"] diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index b76136fb..1afcb10c 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,196 +1,316 @@ -from collections import OrderedDict - +import mock +import pytest import six # noqa F401 -from promise import Promise -from graphene import (Connection, Field, Int, Interface, Node, ObjectType, - is_node) +from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull, + ObjectType, String) +from ..converter import convert_sqlalchemy_composite from ..fields import (SQLAlchemyConnectionField, UnsortedSQLAlchemyConnectionField, registerConnectionFieldFactory, unregisterConnectionFieldFactory) -from ..registry import Registry -from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions -from .models import Article, Reporter - -registry = Registry() +from ..types import ORMField, SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions +from .models import Article, CompositeFullName, Pet, Reporter -class Character(SQLAlchemyObjectType): - """Character description""" +def test_should_raise_if_no_model(): + re_err = r"valid SQLAlchemy Model" + with pytest.raises(Exception, match=re_err): + class Character1(SQLAlchemyObjectType): + pass - class Meta: - model = Reporter - registry = registry +def test_should_raise_if_model_is_invalid(): + re_err = r"valid SQLAlchemy Model" + with pytest.raises(Exception, match=re_err): + class Character(SQLAlchemyObjectType): + class Meta: + model = 1 -class Human(SQLAlchemyObjectType): - """Human description""" - pub_date = Int() +def test_sqlalchemy_node(session): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) - class Meta: - model = Article - exclude_fields = ("id",) - registry = registry - interfaces = (Node,) + reporter_id_field = ReporterType._meta.fields["id"] + assert isinstance(reporter_id_field, GlobalID) + reporter = Reporter() + session.add(reporter) + session.commit() + info = mock.Mock(context={'session': session}) + reporter_node = ReporterType.get_node(info, reporter.id) + assert reporter == reporter_node -def test_sqlalchemy_interface(): - assert issubclass(Node, Interface) - assert issubclass(Node, Node) +def test_sqlalchemy_default_fields(): + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return String() -# @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1)) -# def test_sqlalchemy_get_node(get): -# human = Human.get_node(1, None) -# get.assert_called_with(id=1) -# assert human.id == 1 + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) -def test_objecttype_registered(): - assert issubclass(Character, ObjectType) - assert Character._meta.model == Reporter - assert list(Character._meta.fields.keys()) == [ + assert list(ReporterType._meta.fields.keys()) == [ + "column_prop", "id", "first_name", "last_name", "email", "favorite_pet_kind", + "composite_prop", + "hybrid_prop", "pets", "articles", "favorite_article", ] + # column + first_name_field = ReporterType._meta.fields['first_name'] + assert first_name_field.type == String + assert first_name_field.description == "First name" + + # column_property + column_prop_field = ReporterType._meta.fields['column_prop'] + assert column_prop_field.type == Int + # "doc" is ignored by column_property + assert column_prop_field.description is None + + # composite + full_name_field = ReporterType._meta.fields['composite_prop'] + assert full_name_field.type == String + # "doc" is ignored by composite + assert full_name_field.description is None + + # hybrid_property + hybrid_prop = ReporterType._meta.fields['hybrid_prop'] + assert hybrid_prop.type == String + # "doc" is ignored by hybrid_property + assert hybrid_prop.description is None + + # relationship + favorite_article_field = ReporterType._meta.fields['favorite_article'] + assert isinstance(favorite_article_field, Dynamic) + assert favorite_article_field.type().type == ArticleType + assert favorite_article_field.type().description is None + + +def test_sqlalchemy_override_fields(): + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return String() -# def test_sqlalchemynode_idfield(): -# idfield = Node._meta.fields_map['id'] -# assert isinstance(idfield, GlobalIDField) - - -# def test_node_idfield(): -# idfield = Human._meta.fields_map['id'] -# assert isinstance(idfield, GlobalIDField) + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + # columns + first_name = ORMField(required=True) + last_name = ORMField(description='Overridden') + email = ORMField(deprecation_reason='Overridden') + email_v2 = ORMField(prop_name='email', type=Int) -def test_node_replacedfield(): - idfield = Human._meta.fields["pub_date"] - assert isinstance(idfield, Field) - assert idfield.type == Int + # column_property + column_prop = ORMField(type=String) + # composite + composite_prop = ORMField() -def test_object_type(): - class Human(SQLAlchemyObjectType): - """Human description""" + # hybrid_property + hybrid_prop = ORMField(description='Overridden') - pub_date = Int() + # relationships + favorite_article = ORMField(description='Overridden') + articles = ORMField(deprecation_reason='Overridden') + pets = ORMField(description='Overridden') + class ArticleType(SQLAlchemyObjectType): class Meta: model = Article - # exclude_fields = ('id', ) - registry = registry interfaces = (Node,) - assert issubclass(Human, ObjectType) - assert list(Human._meta.fields.keys()) == [ + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (Node,) + use_connection = False + + assert list(ReporterType._meta.fields.keys()) == [ + "first_name", + "last_name", + "email", + "email_v2", + "column_prop", + "composite_prop", + "hybrid_prop", + "favorite_article", + "articles", + "pets", "id", - "headline", - "pub_date", - "reporter_id", - "reporter", + "favorite_pet_kind", ] - assert is_node(Human) + first_name_field = ReporterType._meta.fields['first_name'] + assert isinstance(first_name_field.type, NonNull) + assert first_name_field.type.of_type == String + assert first_name_field.description == "First name" + assert first_name_field.deprecation_reason is None + + last_name_field = ReporterType._meta.fields['last_name'] + assert last_name_field.type == String + assert last_name_field.description == "Overridden" + assert last_name_field.deprecation_reason is None + + email_field = ReporterType._meta.fields['email'] + assert email_field.type == String + assert email_field.description == "Email" + assert email_field.deprecation_reason == "Overridden" + + email_field_v2 = ReporterType._meta.fields['email_v2'] + assert email_field_v2.type == Int + assert email_field_v2.description == "Email" + assert email_field_v2.deprecation_reason is None + + hybrid_prop_field = ReporterType._meta.fields['hybrid_prop'] + assert hybrid_prop_field.type == String + assert hybrid_prop_field.description == "Overridden" + assert hybrid_prop_field.deprecation_reason is None + + column_prop_field_v2 = ReporterType._meta.fields['column_prop'] + assert column_prop_field_v2.type == String + assert column_prop_field_v2.description is None + assert column_prop_field_v2.deprecation_reason is None + + composite_prop_field = ReporterType._meta.fields['composite_prop'] + assert composite_prop_field.type == String + assert composite_prop_field.description is None + assert composite_prop_field.deprecation_reason is None + + favorite_article_field = ReporterType._meta.fields['favorite_article'] + assert isinstance(favorite_article_field, Dynamic) + assert favorite_article_field.type().type == ArticleType + assert favorite_article_field.type().description == 'Overridden' + + articles_field = ReporterType._meta.fields['articles'] + assert isinstance(articles_field, Dynamic) + assert isinstance(articles_field.type(), UnsortedSQLAlchemyConnectionField) + assert articles_field.type().deprecation_reason == "Overridden" + + pets_field = ReporterType._meta.fields['pets'] + assert isinstance(pets_field, Dynamic) + assert isinstance(pets_field.type().type, List) + assert pets_field.type().type.of_type == PetType + assert pets_field.type().description == 'Overridden' + + +def test_only_fields(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + only_fields = ("id", "last_name") -# Test Custom SQLAlchemyObjectType Implementation -class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): - class Meta: - abstract = True + first_name = ORMField() # Takes precedence + last_name = ORMField() # Noop + assert list(ReporterType._meta.fields.keys()) == ["first_name", "last_name", "id"] -class CustomCharacter(CustomSQLAlchemyObjectType): - """Character description""" - class Meta: - model = Reporter - registry = registry +def test_exclude_fields(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + exclude_fields = ("id", "first_name") + first_name = ORMField() # Takes precedence + last_name = ORMField() # Noop -def test_custom_objecttype_registered(): - assert issubclass(CustomCharacter, ObjectType) - assert CustomCharacter._meta.model == Reporter - assert list(CustomCharacter._meta.fields.keys()) == [ - "id", + assert list(ReporterType._meta.fields.keys()) == [ "first_name", "last_name", + "column_prop", "email", "favorite_pet_kind", + "hybrid_prop", "pets", "articles", "favorite_article", ] -# Test Custom SQLAlchemyObjectType with Custom Options -class CustomOptions(SQLAlchemyObjectTypeOptions): - custom_option = None - custom_fields = None +def test_only_and_exclude_fields(): + re_err = r"'only_fields' and 'exclude_fields' cannot be both set" + with pytest.raises(Exception, match=re_err): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + only_fields = ("id", "last_name") + exclude_fields = ("id", "last_name") -class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType): - class Meta: - abstract = True +def test_sqlalchemy_redefine_field(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter - @classmethod - def __init_subclass_with_meta__( - cls, custom_option=None, custom_fields=None, **options - ): - _meta = CustomOptions(cls) - _meta.custom_option = custom_option - _meta.fields = custom_fields - super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( - _meta=_meta, **options - ) + first_name = Int() + first_name_field = ReporterType._meta.fields["first_name"] + assert isinstance(first_name_field, Field) + assert first_name_field.type == Int -class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): - class Meta: - model = Reporter - custom_option = "custom_option" - custom_fields = OrderedDict([("custom_field", Field(Int()))]) +# Test Custom SQLAlchemyObjectType Implementation -def test_objecttype_with_custom_options(): - assert issubclass(ReporterWithCustomOptions, ObjectType) - assert ReporterWithCustomOptions._meta.model == Reporter - assert list(ReporterWithCustomOptions._meta.fields.keys()) == [ - "custom_field", - "id", - "first_name", - "last_name", - "email", - "favorite_pet_kind", - "pets", - "articles", - "favorite_article", - ] - assert ReporterWithCustomOptions._meta.custom_option == "custom_option" - assert isinstance(ReporterWithCustomOptions._meta.fields["custom_field"].type, Int) +def test_custom_objecttype_registered(): + class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): + class Meta: + abstract = True + + class CustomReporterType(CustomSQLAlchemyObjectType): + class Meta: + model = Reporter + assert issubclass(CustomReporterType, ObjectType) + assert CustomReporterType._meta.model == Reporter + assert len(CustomReporterType._meta.fields) == 10 -def test_promise_connection_resolver(): - class TestConnection(Connection): + +# Test Custom SQLAlchemyObjectType with Custom Options +def test_objecttype_with_custom_options(): + class CustomOptions(SQLAlchemyObjectTypeOptions): + custom_option = None + + class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType): class Meta: - node = ReporterWithCustomOptions + abstract = True - def resolver(_obj, _info): - return Promise.resolve([]) + @classmethod + def __init_subclass_with_meta__(cls, custom_option=None, **options): + _meta = CustomOptions(cls) + _meta.custom_option = custom_option + super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) - result = SQLAlchemyConnectionField.connection_resolver( - resolver, TestConnection, ReporterWithCustomOptions, None, None - ) - assert result is not None + class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): + class Meta: + model = Reporter + custom_option = "custom_option" + + assert issubclass(ReporterWithCustomOptions, ObjectType) + assert ReporterWithCustomOptions._meta.model == Reporter + assert ReporterWithCustomOptions._meta.custom_option == "custom_option" # Tests for connection_field_factory @@ -200,42 +320,34 @@ class _TestSQLAlchemyConnectionField(SQLAlchemyConnectionField): def test_default_connection_field_factory(): - _registry = Registry() - class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter - registry = _registry interfaces = (Node,) class ArticleType(SQLAlchemyObjectType): class Meta: model = Article - registry = _registry interfaces = (Node,) assert isinstance(ReporterType._meta.fields['articles'].type(), UnsortedSQLAlchemyConnectionField) -def test_register_connection_field_factory(): +def test_custom_connection_field_factory(): def test_connection_field_factory(relationship, registry): model = relationship.mapper.entity _type = registry.get_type_for_model(model) return _TestSQLAlchemyConnectionField(_type._meta.connection) - _registry = Registry() - class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter - registry = _registry interfaces = (Node,) connection_field_factory = test_connection_field_factory class ArticleType(SQLAlchemyObjectType): class Meta: model = Article - registry = _registry interfaces = (Node,) assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) @@ -244,18 +356,14 @@ class Meta: def test_deprecated_registerConnectionFieldFactory(): registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) - _registry = Registry() - class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter - registry = _registry interfaces = (Node,) class ArticleType(SQLAlchemyObjectType): class Meta: model = Article - registry = _registry interfaces = (Node,) assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) @@ -265,18 +373,14 @@ def test_deprecated_unregisterConnectionFieldFactory(): registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) unregisterConnectionFieldFactory() - _registry = Registry() - class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter - registry = _registry interfaces = (Node,) class ArticleType(SQLAlchemyObjectType): class Meta: model = Article - registry = _registry interfaces = (Node,) assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index c20e8cfc..a400cd12 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -3,6 +3,8 @@ import sqlalchemy from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.inspection import inspect as sqlalchemyinspect +from sqlalchemy.orm import (ColumnProperty, CompositeProperty, + RelationshipProperty) from sqlalchemy.orm.exc import NoResultFound from graphene import Field @@ -21,70 +23,88 @@ from .utils import get_query, is_mapped_class, is_mapped_instance +class ORMField(object): + def __init__( + self, + type=None, + prop_name=None, + description=None, + deprecation_reason=None, + required=None, + **field_kwargs + ): + # The is only useful for documentation and auto-completion + common_kwargs = { + 'type': type, + 'prop_name': prop_name, + 'description': description, + 'deprecation_reason': deprecation_reason, + 'required': required, + } + common_kwargs = {kwarg: value for kwarg, value in common_kwargs.items() if value is not None} + self.kwargs = field_kwargs + self.kwargs.update(common_kwargs) + + def construct_fields( obj_type, model, registry, only_fields, exclude_fields, connection_field_factory ): inspected_model = sqlalchemyinspect(model) + all_model_props = OrderedDict( + inspected_model.column_attrs.items() + + inspected_model.composites.items() + + [(name, item) for name, item in inspected_model.all_orm_descriptors.items() + if isinstance(item, hybrid_property)] + + inspected_model.relationships.items() + ) + + auto_orm_field_names = [] + for prop_name, prop in all_model_props.items(): + if (only_fields and prop_name not in only_fields) or (prop_name in exclude_fields): + continue + auto_orm_field_names.append(prop_name) + + # TODO Get ORMField fields defined on parent classes + custom_orm_fields = OrderedDict() + for attname, value in list(obj_type.__dict__.items()): + if isinstance(value, ORMField): + custom_orm_fields[attname] = value + + for orm_field_name, orm_field in custom_orm_fields.items(): + prop_name = orm_field.kwargs.get('prop_name', orm_field_name) + if prop_name not in all_model_props: + raise Exception('Cannot map ORMField "{}" to SQLAlchemy model property'.format(orm_field_name)) + orm_field.kwargs['prop_name'] = prop_name + + orm_fields = custom_orm_fields.copy() + for orm_field_name in auto_orm_field_names: + if orm_field_name in orm_fields: + continue + orm_fields[orm_field_name] = ORMField(prop_name=orm_field_name) fields = OrderedDict() + for orm_field_name, orm_field in orm_fields.items(): + prop_name = orm_field.kwargs.pop('prop_name') + prop = all_model_props[prop_name] + + if isinstance(prop, ColumnProperty): + field = convert_sqlalchemy_column(prop, registry, **orm_field.kwargs) + elif isinstance(prop, RelationshipProperty): + field = convert_sqlalchemy_relationship(prop, registry, connection_field_factory, **orm_field.kwargs) + elif isinstance(prop, CompositeProperty): + if prop_name != orm_field_name or orm_field.kwargs: + # TODO Add a way to override composite property fields + raise ValueError( + "ORMField kwargs for composite fields must be empty. " + "Field: {}.{}".format(obj_type.__name__, orm_field_name)) + field = convert_sqlalchemy_composite(prop, registry) + elif isinstance(prop, hybrid_property): + field = convert_sqlalchemy_hybrid_method(prop, prop_name, **orm_field.kwargs) + else: + raise Exception('Property type is not supported') # Should never happen - for name, column in inspected_model.columns.items(): - is_not_in_only = only_fields and name not in only_fields - # is_already_created = name in options.fields - is_excluded = name in exclude_fields # or is_already_created - if is_not_in_only or is_excluded: - # We skip this field if we specify only_fields and is not - # in there. Or when we exclude this field in exclude_fields - continue - converted_column = convert_sqlalchemy_column(column, registry) - registry.register_orm_field(obj_type, name, column) - fields[name] = converted_column - - for name, composite in inspected_model.composites.items(): - is_not_in_only = only_fields and name not in only_fields - # is_already_created = name in options.fields - is_excluded = name in exclude_fields # or is_already_created - if is_not_in_only or is_excluded: - # We skip this field if we specify only_fields and is not - # in there. Or when we exclude this field in exclude_fields - continue - converted_composite = convert_sqlalchemy_composite(composite, registry) - registry.register_orm_field(obj_type, name, composite) - fields[name] = converted_composite - - for hybrid_item in inspected_model.all_orm_descriptors: - - if type(hybrid_item) == hybrid_property: - name = hybrid_item.__name__ - - is_not_in_only = only_fields and name not in only_fields - # is_already_created = name in options.fields - is_excluded = name in exclude_fields # or is_already_created - - if is_not_in_only or is_excluded: - # We skip this field if we specify only_fields and is not - # in there. Or when we exclude this field in exclude_fields - continue - - converted_hybrid_property = convert_sqlalchemy_hybrid_method(hybrid_item) - registry.register_orm_field(obj_type, name, hybrid_item) - fields[name] = converted_hybrid_property - - # Get all the columns for the relationships on the model - for relationship in inspected_model.relationships: - is_not_in_only = only_fields and relationship.key not in only_fields - # is_already_created = relationship.key in options.fields - is_excluded = relationship.key in exclude_fields # or is_already_created - if is_not_in_only or is_excluded: - # We skip this field if we specify only_fields and is not - # in there. Or when we exclude this field in exclude_fields - continue - converted_relationship = convert_sqlalchemy_relationship( - relationship, registry, connection_field_factory - ) - name = relationship.key - registry.register_orm_field(obj_type, name, relationship) - fields[name] = converted_relationship + registry.register_orm_field(obj_type, orm_field_name, prop) + fields[orm_field_name] = field return fields @@ -126,6 +146,9 @@ def __init_subclass_with_meta__( 'Registry, received "{}".' ).format(cls.__name__, registry) + if only_fields and exclude_fields: + raise ValueError("The options 'only_fields' and 'exclude_fields' cannot be both set on the same type.") + sqla_fields = yank_fields_from_attrs( construct_fields( obj_type=cls, diff --git a/setup.cfg b/setup.cfg index 39a48fd2..8b51a676 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,7 @@ max-line-length = 120 [isort] known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme known_first_party=graphene_sqlalchemy -known_third_party=database,flask,models,nameko,promise,py,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils +known_third_party=database,flask,mock,models,nameko,promise,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils sections=FUTURE,STDLIB,THIRDPARTY,GRAPHENE,FIRSTPARTY,LOCALFOLDER no_lines_before=FIRSTPARTY From 1e3817f81045442648f1c13129df0e33a40d34d8 Mon Sep 17 00:00:00 2001 From: Julien Nakache Date: Wed, 15 May 2019 11:29:22 -0400 Subject: [PATCH 2/4] fix field ordering in python version < 3.5 --- graphene_sqlalchemy/tests/test_types.py | 8 +++++++- graphene_sqlalchemy/types.py | 15 ++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 1afcb10c..954e5026 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -62,14 +62,18 @@ class Meta: interfaces = (Node,) assert list(ReporterType._meta.fields.keys()) == [ - "column_prop", + # Columns + "column_prop", # SQLAlchemy retuns column properties first "id", "first_name", "last_name", "email", "favorite_pet_kind", + # Composite "composite_prop", + # Hybrid "hybrid_prop", + # Relationship "pets", "articles", "favorite_article", @@ -147,6 +151,7 @@ class Meta: use_connection = False assert list(ReporterType._meta.fields.keys()) == [ + # First the ORMField in the order they were defined "first_name", "last_name", "email", @@ -157,6 +162,7 @@ class Meta: "favorite_article", "articles", "pets", + # Then the automatic SQLAlchemy fields "id", "favorite_pet_kind", ] diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index a400cd12..f64d3ba9 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -11,6 +11,7 @@ from graphene.relay import Connection, Node from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.utils import yank_fields_from_attrs +from graphene.utils.orderedtype import OrderedType from .converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, @@ -23,7 +24,7 @@ from .utils import get_query, is_mapped_class, is_mapped_instance -class ORMField(object): +class ORMField(OrderedType): def __init__( self, type=None, @@ -31,8 +32,10 @@ def __init__( description=None, deprecation_reason=None, required=None, + _creation_counter=None, **field_kwargs ): + super(ORMField, self).__init__(_creation_counter=_creation_counter) # The is only useful for documentation and auto-completion common_kwargs = { 'type': type, @@ -65,18 +68,19 @@ def construct_fields( auto_orm_field_names.append(prop_name) # TODO Get ORMField fields defined on parent classes - custom_orm_fields = OrderedDict() + custom_orm_fields_items = [] for attname, value in list(obj_type.__dict__.items()): if isinstance(value, ORMField): - custom_orm_fields[attname] = value + custom_orm_fields_items.append((attname, value)) + custom_orm_fields_items = sorted(custom_orm_fields_items, key=lambda item: item[1]) - for orm_field_name, orm_field in custom_orm_fields.items(): + for orm_field_name, orm_field in custom_orm_fields_items: prop_name = orm_field.kwargs.get('prop_name', orm_field_name) if prop_name not in all_model_props: raise Exception('Cannot map ORMField "{}" to SQLAlchemy model property'.format(orm_field_name)) orm_field.kwargs['prop_name'] = prop_name - orm_fields = custom_orm_fields.copy() + orm_fields = OrderedDict(custom_orm_fields_items) for orm_field_name in auto_orm_field_names: if orm_field_name in orm_fields: continue @@ -159,6 +163,7 @@ def __init_subclass_with_meta__( connection_field_factory=connection_field_factory, ), _as=Field, + sort=False, ) if use_connection is None and interfaces: From 17536fcf5a0d5ed2de34351450532634eaeda44d Mon Sep 17 00:00:00 2001 From: "@jnak" Date: Wed, 22 May 2019 22:01:33 -0400 Subject: [PATCH 3/4] address @wyattanderson comments --- graphene_sqlalchemy/converter.py | 11 +--- graphene_sqlalchemy/tests/test_types.py | 12 ++-- graphene_sqlalchemy/types.py | 87 +++++++++++++++++++++---- 3 files changed, 86 insertions(+), 24 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index fc6d1b39..2bf32f5e 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -102,14 +102,9 @@ def inner(fn): def convert_sqlalchemy_column(column_prop, registry, **field_kwargs): column = column_prop.columns[0] - if 'type' not in field_kwargs: - field_kwargs['type'] = convert_sqlalchemy_type(getattr(column, "type", None), column, registry) - - if 'required' not in field_kwargs: - field_kwargs['required'] = not is_column_nullable(column) - - if 'description' not in field_kwargs: - field_kwargs['description'] = get_column_doc(column) + field_kwargs.setdefault('type', convert_sqlalchemy_type(getattr(column, "type", None), column, registry)) + field_kwargs.setdefault('required', not is_column_nullable(column)) + field_kwargs.setdefault('description', get_column_doc(column)) return Field( resolver=_get_attr_resolver(column_prop.key), diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 954e5026..d0e70d80 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -114,14 +114,17 @@ def test_sqlalchemy_override_fields(): def convert_composite_class(composite, registry): return String() - class ReporterType(SQLAlchemyObjectType): + class ReporterMixin(object): + # columns + first_name = ORMField(required=True) + last_name = ORMField(description='Overridden') + + class ReporterType(SQLAlchemyObjectType, ReporterMixin): class Meta: model = Reporter interfaces = (Node,) # columns - first_name = ORMField(required=True) - last_name = ORMField(description='Overridden') email = ORMField(deprecation_reason='Overridden') email_v2 = ORMField(prop_name='email', type=Int) @@ -151,9 +154,10 @@ class Meta: use_connection = False assert list(ReporterType._meta.fields.keys()) == [ - # First the ORMField in the order they were defined + # Fields from ReporterMixin "first_name", "last_name", + # Fields from ReporterType "email", "email_v2", "column_prop", diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index f64d3ba9..a1f30e06 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -27,22 +27,59 @@ class ORMField(OrderedType): def __init__( self, - type=None, prop_name=None, + type=None, + required=None, description=None, deprecation_reason=None, - required=None, _creation_counter=None, **field_kwargs ): + """ + Use this to override fields automatically generated by SQLAlchemyObjectType. + Unless specified, options will default to SQLAlchemyObjectType usual behavior + for the given SQLAlchemy model property. + + Usage: + class MyModel(Base): + id = Column(Integer(), primary_key=True) + name = Column(String) + + class MyType(SQLAlchemyObjectType): + class Meta: + model = MyModel + + id = ORMField(type=graphene.Int) + name = ORMField(required=True) + + -> MyType.id will be of type Int (vs ID). + -> MyType.name will be of type NonNull(String) (vs String). + + Parameters + - prop_name : str, optional + Name of the SQLAlchemy property used to resolve this field. + Default to the name of the attribute referencing the ORMField. + - type: optional + Default to the type mapping in converter.py. + - description: str, optional + Default to the `doc` attribute of the SQLAlchemy column property. + - required: bool, optional + Default to the opposite of the `nullable` attribute of the SQLAlchemy column property. + - description: str, optional + Same behavior as in graphene.Field. Defaults to None. + - deprecation_reason: str, optional + Same behavior as in graphene.Field. Defaults to None. + - _creation_counter: int, optional + Same behavior as in graphene.Field. + """ super(ORMField, self).__init__(_creation_counter=_creation_counter) # The is only useful for documentation and auto-completion common_kwargs = { - 'type': type, - 'prop_name': prop_name, - 'description': description, - 'deprecation_reason': deprecation_reason, - 'required': required, + 'prop_name': prop_name, + 'type': type, + 'required': required, + 'description': description, + 'deprecation_reason': deprecation_reason, } common_kwargs = {kwarg: value for kwarg, value in common_kwargs.items() if value is not None} self.kwargs = field_kwargs @@ -52,7 +89,27 @@ def __init__( def construct_fields( obj_type, model, registry, only_fields, exclude_fields, connection_field_factory ): + """ + Construct all the fields for a SQLAlchemyObjectType. + The main steps are: + - Gather all the relevant attributes from the SQLAlchemy model + - Gather all the ORM fields defined on the type + - Merge in overrides and build up all the fields + + Parameters + - obj_type : SQLAlchemyObjectType + - model : the SQLAlchemy model + - registry : Registry + - only_fields : tuple[string] + - exclude_fields : tuple[string] + - connection_field_factory : function + + Returns + - fields + An OrderedDict of field names to graphene.Field + """ inspected_model = sqlalchemyinspect(model) + # Gather all the relevant attributes from the SQLAlchemy model all_model_props = OrderedDict( inspected_model.column_attrs.items() + inspected_model.composites.items() + @@ -61,31 +118,37 @@ def construct_fields( inspected_model.relationships.items() ) + # Filter out excluded fields auto_orm_field_names = [] for prop_name, prop in all_model_props.items(): if (only_fields and prop_name not in only_fields) or (prop_name in exclude_fields): continue auto_orm_field_names.append(prop_name) - # TODO Get ORMField fields defined on parent classes - custom_orm_fields_items = [] - for attname, value in list(obj_type.__dict__.items()): - if isinstance(value, ORMField): - custom_orm_fields_items.append((attname, value)) + # Gather all the ORM fields defined on the type + custom_orm_fields_items = [ + (attname, value) + for base in reversed(obj_type.__mro__) + for attname, value in base.__dict__.items() + if isinstance(value, ORMField) + ] custom_orm_fields_items = sorted(custom_orm_fields_items, key=lambda item: item[1]) + # Set the prop_name if not set for orm_field_name, orm_field in custom_orm_fields_items: prop_name = orm_field.kwargs.get('prop_name', orm_field_name) if prop_name not in all_model_props: raise Exception('Cannot map ORMField "{}" to SQLAlchemy model property'.format(orm_field_name)) orm_field.kwargs['prop_name'] = prop_name + # Merge automatic fields with custom ORM fields orm_fields = OrderedDict(custom_orm_fields_items) for orm_field_name in auto_orm_field_names: if orm_field_name in orm_fields: continue orm_fields[orm_field_name] = ORMField(prop_name=orm_field_name) + # Build all the field dictionary fields = OrderedDict() for orm_field_name, orm_field in orm_fields.items(): prop_name = orm_field.kwargs.pop('prop_name') From 87f5fff85b8497182f184bf528ab1e5b6be6cb1f Mon Sep 17 00:00:00 2001 From: "@jnak" Date: Wed, 5 Jun 2019 10:12:05 -0400 Subject: [PATCH 4/4] Rename param to model_attr --- graphene_sqlalchemy/tests/test_query.py | 10 +-- graphene_sqlalchemy/tests/test_types.py | 2 +- graphene_sqlalchemy/types.py | 99 ++++++++++++------------- 3 files changed, 53 insertions(+), 58 deletions(-) diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 269b9bd7..74a7249a 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -176,12 +176,12 @@ class Meta: model = Reporter interfaces = (Node,) - first_name_v2 = ORMField(prop_name='first_name') - hybrid_prop_v2 = ORMField(prop_name='hybrid_prop') - column_prop_v2 = ORMField(prop_name='column_prop') + first_name_v2 = ORMField(model_attr='first_name') + hybrid_prop_v2 = ORMField(model_attr='hybrid_prop') + column_prop_v2 = ORMField(model_attr='column_prop') composite_prop = ORMField() - favorite_article_v2 = ORMField(prop_name='favorite_article') - articles_v2 = ORMField(prop_name='articles') + favorite_article_v2 = ORMField(model_attr='favorite_article') + articles_v2 = ORMField(model_attr='articles') class ArticleType(SQLAlchemyObjectType): class Meta: diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index d0e70d80..bd5d5ae3 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -126,7 +126,7 @@ class Meta: # columns email = ORMField(deprecation_reason='Overridden') - email_v2 = ORMField(prop_name='email', type=Int) + email_v2 = ORMField(model_attr='email', type=Int) # column_property column_prop = ORMField(type=String) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index a1f30e06..f77cbc86 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -27,7 +27,7 @@ class ORMField(OrderedType): def __init__( self, - prop_name=None, + model_attr=None, type=None, required=None, description=None, @@ -55,27 +55,26 @@ class Meta: -> MyType.id will be of type Int (vs ID). -> MyType.name will be of type NonNull(String) (vs String). - Parameters - - prop_name : str, optional - Name of the SQLAlchemy property used to resolve this field. + :param str model_attr: + Name of the SQLAlchemy model attribute used to resolve this field. Default to the name of the attribute referencing the ORMField. - - type: optional + :param type: Default to the type mapping in converter.py. - - description: str, optional + :param str description: Default to the `doc` attribute of the SQLAlchemy column property. - - required: bool, optional + :param bool required: Default to the opposite of the `nullable` attribute of the SQLAlchemy column property. - - description: str, optional + :param str description: Same behavior as in graphene.Field. Defaults to None. - - deprecation_reason: str, optional + :param str deprecation_reason: Same behavior as in graphene.Field. Defaults to None. - - _creation_counter: int, optional + :param int _creation_counter: Same behavior as in graphene.Field. """ super(ORMField, self).__init__(_creation_counter=_creation_counter) # The is only useful for documentation and auto-completion common_kwargs = { - 'prop_name': prop_name, + 'model_attr': model_attr, 'type': type, 'required': required, 'description': description, @@ -92,25 +91,21 @@ def construct_fields( """ Construct all the fields for a SQLAlchemyObjectType. The main steps are: - - Gather all the relevant attributes from the SQLAlchemy model - - Gather all the ORM fields defined on the type - - Merge in overrides and build up all the fields - - Parameters - - obj_type : SQLAlchemyObjectType - - model : the SQLAlchemy model - - registry : Registry - - only_fields : tuple[string] - - exclude_fields : tuple[string] - - connection_field_factory : function - - Returns - - fields - An OrderedDict of field names to graphene.Field + - Gather all the relevant attributes from the SQLAlchemy model + - Gather all the ORM fields defined on the type + - Merge in overrides and build up all the fields + + :param SQLAlchemyObjectType obj_type: + :param model: the SQLAlchemy model + :param Registry registry: + :param tuple[string] only_fields: + :param tuple[string] exclude_fields: + :param function connection_field_factory: + :rtype: OrderedDict[str, graphene.Field] """ inspected_model = sqlalchemyinspect(model) - # Gather all the relevant attributes from the SQLAlchemy model - all_model_props = OrderedDict( + # Gather all the relevant attributes from the SQLAlchemy model in order + all_model_attrs = OrderedDict( inspected_model.column_attrs.items() + inspected_model.composites.items() + [(name, item) for name, item in inspected_model.all_orm_descriptors.items() @@ -120,57 +115,57 @@ def construct_fields( # Filter out excluded fields auto_orm_field_names = [] - for prop_name, prop in all_model_props.items(): - if (only_fields and prop_name not in only_fields) or (prop_name in exclude_fields): + for attr_name, attr in all_model_attrs.items(): + if (only_fields and attr_name not in only_fields) or (attr_name in exclude_fields): continue - auto_orm_field_names.append(prop_name) + auto_orm_field_names.append(attr_name) # Gather all the ORM fields defined on the type custom_orm_fields_items = [ - (attname, value) + (attn_name, attr) for base in reversed(obj_type.__mro__) - for attname, value in base.__dict__.items() - if isinstance(value, ORMField) + for attn_name, attr in base.__dict__.items() + if isinstance(attr, ORMField) ] custom_orm_fields_items = sorted(custom_orm_fields_items, key=lambda item: item[1]) - # Set the prop_name if not set + # Set the model_attr if not set for orm_field_name, orm_field in custom_orm_fields_items: - prop_name = orm_field.kwargs.get('prop_name', orm_field_name) - if prop_name not in all_model_props: + attr_name = orm_field.kwargs.get('model_attr', orm_field_name) + if attr_name not in all_model_attrs: raise Exception('Cannot map ORMField "{}" to SQLAlchemy model property'.format(orm_field_name)) - orm_field.kwargs['prop_name'] = prop_name + orm_field.kwargs['model_attr'] = attr_name # Merge automatic fields with custom ORM fields orm_fields = OrderedDict(custom_orm_fields_items) for orm_field_name in auto_orm_field_names: if orm_field_name in orm_fields: continue - orm_fields[orm_field_name] = ORMField(prop_name=orm_field_name) + orm_fields[orm_field_name] = ORMField(model_attr=orm_field_name) # Build all the field dictionary fields = OrderedDict() for orm_field_name, orm_field in orm_fields.items(): - prop_name = orm_field.kwargs.pop('prop_name') - prop = all_model_props[prop_name] - - if isinstance(prop, ColumnProperty): - field = convert_sqlalchemy_column(prop, registry, **orm_field.kwargs) - elif isinstance(prop, RelationshipProperty): - field = convert_sqlalchemy_relationship(prop, registry, connection_field_factory, **orm_field.kwargs) - elif isinstance(prop, CompositeProperty): - if prop_name != orm_field_name or orm_field.kwargs: + attr_name = orm_field.kwargs.pop('model_attr') + attr = all_model_attrs[attr_name] + + if isinstance(attr, ColumnProperty): + field = convert_sqlalchemy_column(attr, registry, **orm_field.kwargs) + elif isinstance(attr, RelationshipProperty): + field = convert_sqlalchemy_relationship(attr, registry, connection_field_factory, **orm_field.kwargs) + elif isinstance(attr, CompositeProperty): + if attr_name != orm_field_name or orm_field.kwargs: # TODO Add a way to override composite property fields raise ValueError( "ORMField kwargs for composite fields must be empty. " "Field: {}.{}".format(obj_type.__name__, orm_field_name)) - field = convert_sqlalchemy_composite(prop, registry) - elif isinstance(prop, hybrid_property): - field = convert_sqlalchemy_hybrid_method(prop, prop_name, **orm_field.kwargs) + field = convert_sqlalchemy_composite(attr, registry) + elif isinstance(attr, hybrid_property): + field = convert_sqlalchemy_hybrid_method(attr, attr_name, **orm_field.kwargs) else: raise Exception('Property type is not supported') # Should never happen - registry.register_orm_field(obj_type, orm_field_name, prop) + registry.register_orm_field(obj_type, orm_field_name, attr) fields[orm_field_name] = field return fields