Skip to content

Commit e27f19b

Browse files
committed
Allow chaining together multiple dataloader calls
This PR changes how the dataloader batching functionality works. Rather than trying to "detect" if a future has a "deferred_callback" in the execution context, we create a singleton `DataLoaderBatchCallbacks` which all dataloaders add their `dispatch_queue` functions to when needed. We can then run all the callbacks in the execution context to complete the SyncFuture's. This allows us to chain dataloaders together. Fixes #6 Diff-Id: daffd
1 parent 6cb666b commit e27f19b

File tree

4 files changed

+134
-21
lines changed

4 files changed

+134
-21
lines changed

README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,31 @@ query {
162162
}
163163
```
164164

165+
## Chaining dataloaders
166+
167+
The `SyncDataLoader.load` function returns a `SyncFuture` object which, similar to
168+
a JavaScript Promise, allows you to chain results together using the
169+
`then(on_success: Callable)` function.
170+
171+
For example:
172+
173+
```python
174+
def get_user_name(userId: str) -> str:
175+
return user_loader.load(userId).then(lambda user: user["name"])
176+
```
177+
178+
You can also chain together multiple DataLoader calls:
179+
180+
```python
181+
def get_best_friend_name(userId: str) -> str:
182+
return (
183+
user_loader.load(userId)
184+
.then(lambda user: user_loader.load(user["best_friend"]))
185+
.then(lambda best_friend: best_friend["name"])
186+
)
187+
```
188+
189+
165190
## How it works
166191

167192
This library implements a custom version of the graphql-core

graphql_sync_dataloaders/execution_context.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import (
22
Any,
33
AsyncIterable,
4-
Callable,
54
Dict,
65
Optional,
76
List,
@@ -33,6 +32,7 @@
3332
from graphql.execution.values import get_argument_values
3433

3534
from .sync_future import SyncFuture
35+
from .sync_dataloader import dataloader_batch_callbacks
3636

3737

3838
PENDING_FUTURE = object()
@@ -46,17 +46,12 @@ class DeferredExecutionContext(ExecutionContext):
4646
is executed and before the result is returned.
4747
"""
4848

49-
_deferred_callbacks: List[Callable]
50-
5149
def execute_operation(
5250
self, operation: OperationDefinitionNode, root_value: Any
5351
) -> Optional[AwaitableOrValue[Any]]:
54-
self._deferred_callbacks = []
5552
result = super().execute_operation(operation, root_value)
5653

57-
callbacks = self._deferred_callbacks
58-
while callbacks:
59-
callbacks.pop(0)()
54+
dataloader_batch_callbacks.run_all_callbacks()
6055

6156
if isinstance(result, SyncFuture):
6257
if not result.done():
@@ -147,10 +142,6 @@ def execute_field(
147142

148143
else:
149144

150-
callback = result.deferred_callback
151-
if callback:
152-
self._deferred_callbacks.append(callback)
153-
154145
# noinspection PyShadowingNames
155146
def process_result(_: Any):
156147
try:
@@ -261,10 +252,6 @@ def process_result(_: Any):
261252
item_type, field_nodes, info, item_path, item.result()
262253
)
263254
else:
264-
callback = item.deferred_callback
265-
if callback:
266-
self._deferred_callbacks.append(callback)
267-
268255
# noinspection PyShadowingNames
269256
def process_item(
270257
index: int,
@@ -339,10 +326,6 @@ def process_completed(
339326
if completed.done():
340327
results[index] = completed.result()
341328
else:
342-
callback = completed.deferred_callback
343-
if callback:
344-
self._deferred_callbacks.append(callback)
345-
346329
# noinspection PyShadowingNames
347330
def process_completed(
348331
index: int,

graphql_sync_dataloaders/sync_dataloader.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,32 @@
1+
from typing import List, Callable
12
from graphql.pyutils import is_collection
23

34
from .sync_future import SyncFuture
45

56

7+
class DataloaderBatchCallbacks:
8+
"""
9+
Singleton that stores all the batched callbacks for all dataloaders. This is
10+
equivalent to the async `loop.call_soon` functionality and enables the
11+
batching functionality of dataloaders.
12+
"""
13+
_callbacks: List[Callable]
14+
15+
def __init__(self) -> None:
16+
self._callbacks = []
17+
18+
def add_callback(self, callback: Callable):
19+
self._callbacks.append(callback)
20+
21+
def run_all_callbacks(self):
22+
callbacks = self._callbacks
23+
while callbacks:
24+
callbacks.pop(0)()
25+
26+
27+
dataloader_batch_callbacks = DataloaderBatchCallbacks()
28+
29+
630
class SyncDataLoader:
731
def __init__(self, batch_load_fn):
832
self._batch_load_fn = batch_load_fn
@@ -17,7 +41,7 @@ def load(self, key):
1741
needs_dispatch = not self._queue
1842
self._queue.append((key, future))
1943
if needs_dispatch:
20-
future.deferred_callback = self.dispatch_queue
44+
dataloader_batch_callbacks.add_callback(self.dispatch_queue)
2145
self._cache[key] = future
2246
return future
2347

tests/test_dataloader.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from unittest import mock
12
from unittest.mock import Mock
23
from functools import partial
34

@@ -374,7 +375,7 @@ def resolve_hello(_, __, name):
374375
"name": GraphQLArgument(GraphQLString),
375376
},
376377
resolve=resolve_hello,
377-
)
378+
),
378379
},
379380
)
380381
)
@@ -402,3 +403,83 @@ def resolve_hello(_, __, name):
402403
keys = list(result.data.keys())
403404
assert keys == ["name1", "hello1", "name2", "hello2"]
404405
assert mock_load_fn.call_count == 1
406+
407+
408+
def test_chaining_dataloader():
409+
USERS = {
410+
"1": {
411+
"name": "Sarah",
412+
"best_friend": "2",
413+
},
414+
"2": {
415+
"name": "Lucy",
416+
"best_friend": "3",
417+
},
418+
"3": {
419+
"name": "Geoff",
420+
},
421+
"5": {
422+
"name": "Dave",
423+
},
424+
}
425+
426+
def load_fn(keys):
427+
return [USERS[key] if key in USERS else None for key in keys]
428+
429+
mock_load_fn = Mock(wraps=load_fn)
430+
dataloader = SyncDataLoader(mock_load_fn)
431+
432+
def resolve_name(_, __, userId):
433+
return dataloader.load(userId).then(lambda user: user["name"])
434+
435+
def resolve_best_friend_name(_, __, userId):
436+
return (
437+
dataloader.load(userId)
438+
.then(lambda user: dataloader.load(user["best_friend"]))
439+
.then(lambda user: user["name"])
440+
)
441+
442+
schema = GraphQLSchema(
443+
query=GraphQLObjectType(
444+
name="Query",
445+
fields={
446+
"name": GraphQLField(
447+
GraphQLString,
448+
args={
449+
"userId": GraphQLArgument(GraphQLString),
450+
},
451+
resolve=resolve_name,
452+
),
453+
"bestFriendName": GraphQLField(
454+
GraphQLString,
455+
args={
456+
"userId": GraphQLArgument(GraphQLString),
457+
},
458+
resolve=resolve_best_friend_name,
459+
),
460+
},
461+
)
462+
)
463+
464+
result = graphql_sync_deferred(
465+
schema,
466+
"""
467+
query {
468+
name1: name(userId: "1")
469+
name2: name(userId: "2")
470+
bestFriend1: bestFriendName(userId: "1")
471+
bestFriend2: bestFriendName(userId: "2")
472+
}
473+
""",
474+
)
475+
476+
assert not result.errors
477+
assert result.data == {
478+
"name1": "Sarah",
479+
"name2": "Lucy",
480+
"bestFriend1": "Lucy",
481+
"bestFriend2": "Geoff",
482+
}
483+
assert mock_load_fn.call_count == 2
484+
assert mock_load_fn.call_args_list[0].args[0] == ["1", "2"]
485+
assert mock_load_fn.call_args_list[1].args[0] == ["3"]

0 commit comments

Comments
 (0)