Skip to content

Commit 32d0d18

Browse files
jendrikjoeJendrikerikwrede
authored
feat: support for async sessions (#350)
* feat(async): add support for async sessions This PR brings experimental support for async sessions in SQLAlchemyConnectionFields. Batching is not yet supported and will be subject to a later PR. Co-authored-by: Jendrik <[email protected]> Co-authored-by: Erik Wrede <[email protected]>
1 parent 2edeae9 commit 32d0d18

18 files changed

+931
-332
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: Tests
22

3-
on:
3+
on:
44
push:
55
branches:
66
- 'master'

docs/inheritance.rst

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Inheritance Examples
33

44
Create interfaces from inheritance relationships
55
------------------------------------------------
6-
6+
.. note:: If you're using `AsyncSession`, please check the chapter `Eager Loading & Using with AsyncSession`_.
77
SQLAlchemy has excellent support for class inheritance hierarchies.
88
These hierarchies can be represented in your GraphQL schema by means
99
of interfaces_. Much like ObjectTypes, Interfaces in
@@ -40,7 +40,7 @@ from the attributes of their underlying SQLAlchemy model:
4040
__mapper_args__ = {
4141
"polymorphic_identity": "employee",
4242
}
43-
43+
4444
class Customer(Person):
4545
first_purchase_date = Column(Date())
4646
@@ -56,17 +56,17 @@ from the attributes of their underlying SQLAlchemy model:
5656
class Meta:
5757
model = Employee
5858
interfaces = (relay.Node, PersonType)
59-
59+
6060
class CustomerType(SQLAlchemyObjectType):
6161
class Meta:
6262
model = Customer
6363
interfaces = (relay.Node, PersonType)
6464
65-
Keep in mind that `PersonType` is a `SQLAlchemyInterface`. Interfaces must
66-
be linked to an abstract Model that does not specify a `polymorphic_identity`,
67-
because we cannot return instances of interfaces from a GraphQL query.
68-
If Person specified a `polymorphic_identity`, instances of Person could
69-
be inserted into and returned by the database, potentially causing
65+
Keep in mind that `PersonType` is a `SQLAlchemyInterface`. Interfaces must
66+
be linked to an abstract Model that does not specify a `polymorphic_identity`,
67+
because we cannot return instances of interfaces from a GraphQL query.
68+
If Person specified a `polymorphic_identity`, instances of Person could
69+
be inserted into and returned by the database, potentially causing
7070
Persons to be returned to the resolvers.
7171

7272
When querying on the base type, you can refer directly to common fields,
@@ -85,15 +85,19 @@ and fields on concrete implementations using the `... on` syntax:
8585
firstPurchaseDate
8686
}
8787
}
88-
89-
88+
89+
90+
.. danger::
91+
When using joined table inheritance, this style of querying may lead to unbatched implicit IO with negative performance implications.
92+
See the chapter `Eager Loading & Using with AsyncSession`_ for more information on eager loading all possible types of a `SQLAlchemyInterface`.
93+
9094
Please note that by default, the "polymorphic_on" column is *not*
9195
generated as a field on types that use polymorphic inheritance, as
92-
this is considered an implentation detail. The idiomatic way to
96+
this is considered an implementation detail. The idiomatic way to
9397
retrieve the concrete GraphQL type of an object is to query for the
94-
`__typename` field.
98+
`__typename` field.
9599
To override this behavior, an `ORMField` needs to be created
96-
for the custom type field on the corresponding `SQLAlchemyInterface`. This is *not recommended*
100+
for the custom type field on the corresponding `SQLAlchemyInterface`. This is *not recommended*
97101
as it promotes abiguous schema design
98102

99103
If your SQLAlchemy model only specifies a relationship to the
@@ -103,5 +107,39 @@ class to the Schema constructor via the `types=` argument:
103107
.. code:: python
104108
105109
schema = graphene.Schema(..., types=[PersonType, EmployeeType, CustomerType])
106-
110+
111+
107112
See also: `Graphene Interfaces <https://docs.graphene-python.org/en/latest/types/interfaces/>`_
113+
114+
Eager Loading & Using with AsyncSession
115+
--------------------
116+
When querying the base type in multi-table inheritance or joined table inheritance, you can only directly refer to polymorphic fields when they are loaded eagerly.
117+
This restricting is in place because AsyncSessions don't allow implicit async operations such as the loads of the joined tables.
118+
To load the polymorphic fields eagerly, you can use the `with_polymorphic` attribute of the mapper args in the base model:
119+
120+
.. code:: python
121+
class Person(Base):
122+
id = Column(Integer(), primary_key=True)
123+
type = Column(String())
124+
name = Column(String())
125+
birth_date = Column(Date())
126+
127+
__tablename__ = "person"
128+
__mapper_args__ = {
129+
"polymorphic_on": type,
130+
"with_polymorphic": "*", # needed for eager loading in async session
131+
}
132+
133+
Alternatively, the specific polymorphic fields can be loaded explicitly in resolvers:
134+
135+
.. code:: python
136+
137+
class Query(graphene.ObjectType):
138+
people = graphene.Field(graphene.List(PersonType))
139+
140+
async def resolve_people(self, _info):
141+
return (await session.scalars(with_polymorphic(Person, [Engineer, Customer]))).all()
142+
143+
Dynamic batching of the types based on the query to avoid eager is currently not supported, but could be implemented in a future PR.
144+
145+
For more information on loading techniques for polymorphic models, please check out the `SQLAlchemy docs <https://docs.sqlalchemy.org/en/20/orm/queryguide/inheritance.html>`_.

graphene_sqlalchemy/batching.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sqlalchemy.orm import Session, strategies
77
from sqlalchemy.orm.query import QueryContext
88

9-
from .utils import is_graphene_version_less_than, is_sqlalchemy_version_less_than
9+
from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, is_graphene_version_less_than
1010

1111

1212
def get_data_loader_impl() -> Any: # pragma: no cover
@@ -71,19 +71,19 @@ async def batch_load_fn(self, parents):
7171

7272
# For our purposes, the query_context will only used to get the session
7373
query_context = None
74-
if is_sqlalchemy_version_less_than("1.4"):
75-
query_context = QueryContext(session.query(parent_mapper.entity))
76-
else:
74+
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
7775
parent_mapper_query = session.query(parent_mapper.entity)
7876
query_context = parent_mapper_query._compile_context()
79-
80-
if is_sqlalchemy_version_less_than("1.4"):
77+
else:
78+
query_context = QueryContext(session.query(parent_mapper.entity))
79+
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
8180
self.selectin_loader._load_for_path(
8281
query_context,
8382
parent_mapper._path_registry,
8483
states,
8584
None,
8685
child_mapper,
86+
None,
8787
)
8888
else:
8989
self.selectin_loader._load_for_path(
@@ -92,7 +92,6 @@ async def batch_load_fn(self, parents):
9292
states,
9393
None,
9494
child_mapper,
95-
None,
9695
)
9796
return [getattr(parent, self.relationship_prop.key) for parent in parents]
9897

graphene_sqlalchemy/fields.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
from graphql_relay import connection_from_array_slice
1212

1313
from .batching import get_batch_resolver
14-
from .utils import EnumValue, get_query
14+
from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, EnumValue, get_query, get_session
15+
16+
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
17+
from sqlalchemy.ext.asyncio import AsyncSession
1518

1619

1720
class SQLAlchemyConnectionField(ConnectionField):
@@ -81,8 +84,49 @@ def get_query(cls, model, info, sort=None, **args):
8184

8285
@classmethod
8386
def resolve_connection(cls, connection_type, model, info, args, resolved):
87+
session = get_session(info.context)
88+
if resolved is None:
89+
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
90+
91+
async def get_result():
92+
return await cls.resolve_connection_async(
93+
connection_type, model, info, args, resolved
94+
)
95+
96+
return get_result()
97+
98+
else:
99+
resolved = cls.get_query(model, info, **args)
100+
if isinstance(resolved, Query):
101+
_len = resolved.count()
102+
else:
103+
_len = len(resolved)
104+
105+
def adjusted_connection_adapter(edges, pageInfo):
106+
return connection_adapter(connection_type, edges, pageInfo)
107+
108+
connection = connection_from_array_slice(
109+
array_slice=resolved,
110+
args=args,
111+
slice_start=0,
112+
array_length=_len,
113+
array_slice_length=_len,
114+
connection_type=adjusted_connection_adapter,
115+
edge_type=connection_type.Edge,
116+
page_info_type=page_info_adapter,
117+
)
118+
connection.iterable = resolved
119+
connection.length = _len
120+
return connection
121+
122+
@classmethod
123+
async def resolve_connection_async(
124+
cls, connection_type, model, info, args, resolved
125+
):
126+
session = get_session(info.context)
84127
if resolved is None:
85-
resolved = cls.get_query(model, info, **args)
128+
query = cls.get_query(model, info, **args)
129+
resolved = (await session.scalars(query)).all()
86130
if isinstance(resolved, Query):
87131
_len = resolved.count()
88132
else:
@@ -179,7 +223,7 @@ def from_relationship(cls, relationship, registry, **field_kwargs):
179223
return cls(
180224
model_type.connection,
181225
resolver=get_batch_resolver(relationship),
182-
**field_kwargs
226+
**field_kwargs,
183227
)
184228

185229

graphene_sqlalchemy/tests/conftest.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import pytest
2+
import pytest_asyncio
23
from sqlalchemy import create_engine
34
from sqlalchemy.orm import sessionmaker
45

56
import graphene
7+
from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4
68

79
from ..converter import convert_sqlalchemy_composite
810
from ..registry import reset_global_registry
911
from .models import Base, CompositeFullName
1012

11-
test_db_url = "sqlite://" # use in-memory database for tests
13+
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
14+
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
1215

1316

1417
@pytest.fixture(autouse=True)
@@ -22,18 +25,49 @@ def convert_composite_class(composite, registry):
2225
return graphene.Field(graphene.Int)
2326

2427

25-
@pytest.fixture(scope="function")
26-
def session_factory():
27-
engine = create_engine(test_db_url)
28-
Base.metadata.create_all(engine)
28+
@pytest.fixture(params=[False, True])
29+
def async_session(request):
30+
return request.param
31+
32+
33+
@pytest.fixture
34+
def test_db_url(async_session: bool):
35+
if async_session:
36+
return "sqlite+aiosqlite://"
37+
else:
38+
return "sqlite://"
2939

30-
yield sessionmaker(bind=engine)
3140

41+
@pytest.mark.asyncio
42+
@pytest_asyncio.fixture(scope="function")
43+
async def session_factory(async_session: bool, test_db_url: str):
44+
if async_session:
45+
if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
46+
pytest.skip("Async Sessions only work in sql alchemy 1.4 and above")
47+
engine = create_async_engine(test_db_url)
48+
async with engine.begin() as conn:
49+
await conn.run_sync(Base.metadata.create_all)
50+
yield sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)
51+
await engine.dispose()
52+
else:
53+
engine = create_engine(test_db_url)
54+
Base.metadata.create_all(engine)
55+
yield sessionmaker(bind=engine, expire_on_commit=False)
56+
# SQLite in-memory db is deleted when its connection is closed.
57+
# https://www.sqlite.org/inmemorydb.html
58+
engine.dispose()
59+
60+
61+
@pytest_asyncio.fixture(scope="function")
62+
async def sync_session_factory():
63+
engine = create_engine("sqlite://")
64+
Base.metadata.create_all(engine)
65+
yield sessionmaker(bind=engine, expire_on_commit=False)
3266
# SQLite in-memory db is deleted when its connection is closed.
3367
# https://www.sqlite.org/inmemorydb.html
3468
engine.dispose()
3569

3670

37-
@pytest.fixture(scope="function")
71+
@pytest_asyncio.fixture(scope="function")
3872
def session(session_factory):
3973
return session_factory()

graphene_sqlalchemy/tests/models.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
)
2121
from sqlalchemy.ext.declarative import declarative_base
2222
from sqlalchemy.ext.hybrid import hybrid_property
23-
from sqlalchemy.orm import column_property, composite, mapper, relationship
23+
from sqlalchemy.orm import backref, column_property, composite, mapper, relationship
2424

2525
PetKind = Enum("cat", "dog", name="pet_kind")
2626

@@ -76,10 +76,16 @@ class Reporter(Base):
7676
email = Column(String(), doc="Email")
7777
favorite_pet_kind = Column(PetKind)
7878
pets = relationship(
79-
"Pet", secondary=association_table, backref="reporters", order_by="Pet.id"
79+
"Pet",
80+
secondary=association_table,
81+
backref="reporters",
82+
order_by="Pet.id",
83+
lazy="selectin",
8084
)
81-
articles = relationship("Article", backref="reporter")
82-
favorite_article = relationship("Article", uselist=False)
85+
articles = relationship(
86+
"Article", backref=backref("reporter", lazy="selectin"), lazy="selectin"
87+
)
88+
favorite_article = relationship("Article", uselist=False, lazy="selectin")
8389

8490
@hybrid_property
8591
def hybrid_prop_with_doc(self):
@@ -304,8 +310,10 @@ class Person(Base):
304310
__tablename__ = "person"
305311
__mapper_args__ = {
306312
"polymorphic_on": type,
313+
"with_polymorphic": "*", # needed for eager loading in async session
307314
}
308315

316+
309317
class NonAbstractPerson(Base):
310318
id = Column(Integer(), primary_key=True)
311319
type = Column(String())
@@ -318,6 +326,7 @@ class NonAbstractPerson(Base):
318326
"polymorphic_identity": "person",
319327
}
320328

329+
321330
class Employee(Person):
322331
hire_date = Column(Date())
323332

0 commit comments

Comments
 (0)