Skip to content

Allow chaining together multiple dataloader calls #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,31 @@ query {
}
```

## Chaining dataloaders

The `SyncDataLoader.load` function returns a `SyncFuture` object which, similar to
a JavaScript Promise, allows you to chain results together using the
`then(on_success: Callable)` function.

For example:

```python
def get_user_name(userId: str) -> str:
return user_loader.load(userId).then(lambda user: user["name"])
```

You can also chain together multiple DataLoader calls:

```python
def get_best_friend_name(userId: str) -> str:
return (
user_loader.load(userId)
.then(lambda user: user_loader.load(user["best_friend"]))
.then(lambda best_friend: best_friend["name"])
)
```


## How it works

This library implements a custom version of the graphql-core
Expand Down
21 changes: 2 additions & 19 deletions graphql_sync_dataloaders/execution_context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import (
Any,
AsyncIterable,
Callable,
Dict,
Optional,
List,
Expand Down Expand Up @@ -33,6 +32,7 @@
from graphql.execution.values import get_argument_values

from .sync_future import SyncFuture
from .sync_dataloader import dataloader_batch_callbacks


PENDING_FUTURE = object()
Expand All @@ -46,17 +46,12 @@ class DeferredExecutionContext(ExecutionContext):
is executed and before the result is returned.
"""

_deferred_callbacks: List[Callable]

def execute_operation(
self, operation: OperationDefinitionNode, root_value: Any
) -> Optional[AwaitableOrValue[Any]]:
self._deferred_callbacks = []
result = super().execute_operation(operation, root_value)

callbacks = self._deferred_callbacks
while callbacks:
callbacks.pop(0)()
dataloader_batch_callbacks.run_all_callbacks()

if isinstance(result, SyncFuture):
if not result.done():
Expand Down Expand Up @@ -147,10 +142,6 @@ def execute_field(

else:

callback = result.deferred_callback
if callback:
self._deferred_callbacks.append(callback)

# noinspection PyShadowingNames
def process_result(_: Any):
try:
Expand Down Expand Up @@ -261,10 +252,6 @@ def process_result(_: Any):
item_type, field_nodes, info, item_path, item.result()
)
else:
callback = item.deferred_callback
if callback:
self._deferred_callbacks.append(callback)

# noinspection PyShadowingNames
def process_item(
index: int,
Expand Down Expand Up @@ -339,10 +326,6 @@ def process_completed(
if completed.done():
results[index] = completed.result()
else:
callback = completed.deferred_callback
if callback:
self._deferred_callbacks.append(callback)

# noinspection PyShadowingNames
def process_completed(
index: int,
Expand Down
26 changes: 25 additions & 1 deletion graphql_sync_dataloaders/sync_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,32 @@
from typing import List, Callable
from graphql.pyutils import is_collection

from .sync_future import SyncFuture


class DataloaderBatchCallbacks:
"""
Singleton that stores all the batched callbacks for all dataloaders. This is
equivalent to the async `loop.call_soon` functionality and enables the
batching functionality of dataloaders.
"""
_callbacks: List[Callable]

def __init__(self) -> None:
self._callbacks = []

def add_callback(self, callback: Callable):
self._callbacks.append(callback)

def run_all_callbacks(self):
callbacks = self._callbacks
while callbacks:
callbacks.pop(0)()


dataloader_batch_callbacks = DataloaderBatchCallbacks()


class SyncDataLoader:
def __init__(self, batch_load_fn):
self._batch_load_fn = batch_load_fn
Expand All @@ -17,7 +41,7 @@ def load(self, key):
needs_dispatch = not self._queue
self._queue.append((key, future))
if needs_dispatch:
future.deferred_callback = self.dispatch_queue
dataloader_batch_callbacks.add_callback(self.dispatch_queue)
self._cache[key] = future
return future

Expand Down
83 changes: 82 additions & 1 deletion tests/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from unittest import mock
from unittest.mock import Mock
from functools import partial

Expand Down Expand Up @@ -374,7 +375,7 @@ def resolve_hello(_, __, name):
"name": GraphQLArgument(GraphQLString),
},
resolve=resolve_hello,
)
),
},
)
)
Expand Down Expand Up @@ -402,3 +403,83 @@ def resolve_hello(_, __, name):
keys = list(result.data.keys())
assert keys == ["name1", "hello1", "name2", "hello2"]
assert mock_load_fn.call_count == 1


def test_chaining_dataloader():
USERS = {
"1": {
"name": "Sarah",
"best_friend": "2",
},
"2": {
"name": "Lucy",
"best_friend": "3",
},
"3": {
"name": "Geoff",
},
"5": {
"name": "Dave",
},
}

def load_fn(keys):
return [USERS[key] if key in USERS else None for key in keys]

mock_load_fn = Mock(wraps=load_fn)
dataloader = SyncDataLoader(mock_load_fn)

def resolve_name(_, __, userId):
return dataloader.load(userId).then(lambda user: user["name"])

def resolve_best_friend_name(_, __, userId):
return (
dataloader.load(userId)
.then(lambda user: dataloader.load(user["best_friend"]))
.then(lambda user: user["name"])
)

schema = GraphQLSchema(
query=GraphQLObjectType(
name="Query",
fields={
"name": GraphQLField(
GraphQLString,
args={
"userId": GraphQLArgument(GraphQLString),
},
resolve=resolve_name,
),
"bestFriendName": GraphQLField(
GraphQLString,
args={
"userId": GraphQLArgument(GraphQLString),
},
resolve=resolve_best_friend_name,
),
},
)
)

result = graphql_sync_deferred(
schema,
"""
query {
name1: name(userId: "1")
name2: name(userId: "2")
bestFriend1: bestFriendName(userId: "1")
bestFriend2: bestFriendName(userId: "2")
}
""",
)

assert not result.errors
assert result.data == {
"name1": "Sarah",
"name2": "Lucy",
"bestFriend1": "Lucy",
"bestFriend2": "Geoff",
}
assert mock_load_fn.call_count == 2
assert mock_load_fn.call_args_list[0].args[0] == ["1", "2"]
assert mock_load_fn.call_args_list[1].args[0] == ["3"]