Skip to content

Added support for middleware #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[flake8]
exclude = .git,.mypy_cache,.pytest_cache,.tox,.venv,__pycache__,build,dist,docs
max-line-length = 88
20 changes: 15 additions & 5 deletions graphql/execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,21 @@
"""

from .execute import (
execute, default_field_resolver, response_path_as_list,
ExecutionContext, ExecutionResult)
execute,
default_field_resolver,
response_path_as_list,
ExecutionContext,
ExecutionResult,
)
from .middleware import MiddlewareManager
from .values import get_directive_values

__all__ = [
'execute', 'default_field_resolver', 'response_path_as_list',
'ExecutionContext', 'ExecutionResult',
'get_directive_values']
"execute",
"default_field_resolver",
"response_path_as_list",
"ExecutionContext",
"ExecutionResult",
"MiddlewareManager",
"get_directive_values",
]
31 changes: 27 additions & 4 deletions graphql/execution/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
is_non_null_type, is_object_type)
from .values import (
get_argument_values, get_directive_values, get_variable_values)
from .middleware import MiddlewareManager


__all__ = [
'add_path', 'assert_valid_execution_arguments', 'default_field_resolver',
Expand Down Expand Up @@ -64,7 +66,8 @@ def execute(
schema: GraphQLSchema, document: DocumentNode,
root_value: Any=None, context_value: Any=None,
variable_values: Dict[str, Any]=None,
operation_name: str=None, field_resolver: GraphQLFieldResolver=None
operation_name: str=None, field_resolver: GraphQLFieldResolver=None,
middleware: Optional[Union[Iterable[Any], MiddlewareManager]]=None
) -> MaybeAwaitable[ExecutionResult]:
"""Execute a GraphQL operation.

Expand All @@ -84,7 +87,7 @@ def execute(
# arguments, a "Response" with only errors is returned.
exe_context = ExecutionContext.build(
schema, document, root_value, context_value,
variable_values, operation_name, field_resolver)
variable_values, operation_name, field_resolver, middleware)

# Return early errors if execution context failed.
if isinstance(exe_context, list):
Expand Down Expand Up @@ -116,6 +119,7 @@ class ExecutionContext:
operation: OperationDefinitionNode
variable_values: Dict[str, Any]
field_resolver: GraphQLFieldResolver
middleware_manager: Optional[MiddlewareManager]
errors: List[GraphQLError]

def __init__(
Expand All @@ -125,6 +129,7 @@ def __init__(
operation: OperationDefinitionNode,
variable_values: Dict[str, Any],
field_resolver: GraphQLFieldResolver,
middleware_manager: Optional[MiddlewareManager],
errors: List[GraphQLError]) -> None:
self.schema = schema
self.fragments = fragments
Expand All @@ -133,6 +138,7 @@ def __init__(
self.operation = operation
self.variable_values = variable_values
self.field_resolver = field_resolver # type: ignore
self.middleware_manager = middleware_manager
self.errors = errors
self._subfields_cache: Dict[
Tuple[GraphQLObjectType, Tuple[FieldNode, ...]],
Expand All @@ -144,7 +150,8 @@ def build(
root_value: Any=None, context_value: Any=None,
raw_variable_values: Dict[str, Any]=None,
operation_name: str=None,
field_resolver: GraphQLFieldResolver=None
field_resolver: GraphQLFieldResolver=None,
middleware: Optional[Union[Iterable[Any], MiddlewareManager]]=None
) -> Union[List[GraphQLError], 'ExecutionContext']:
"""Build an execution context

Expand All @@ -157,6 +164,18 @@ def build(
operation: Optional[OperationDefinitionNode] = None
has_multiple_assumed_operations = False
fragments: Dict[str, FragmentDefinitionNode] = {}
middleware_manager: Optional[MiddlewareManager] = None
if middleware:
if isinstance(middleware, Iterable):
middleware_manager = MiddlewareManager(*middleware)
elif isinstance(middleware, MiddlewareManager):
middleware_manager = middleware
else:
raise TypeError(
f"middlewares have to be an instance"
"of MiddlewareManager. Received \"{middleware}\""
)

for definition in document.definitions:
if isinstance(definition, OperationDefinitionNode):
if not operation_name and operation:
Expand Down Expand Up @@ -201,7 +220,8 @@ def build(

return cls(
schema, fragments, root_value, context_value, operation,
variable_values, field_resolver or default_field_resolver, errors)
variable_values, field_resolver or default_field_resolver,
middleware_manager, errors)

def build_response(
self, data: MaybeAwaitable[Optional[Dict[str, Any]]]
Expand Down Expand Up @@ -447,6 +467,9 @@ def resolve_field(

resolve_fn = field_def.resolve or self.field_resolver

if self.middleware_manager:
resolve_fn = self.middleware_manager.get_field_resolver(resolve_fn)

info = self.build_resolve_info(
field_def, field_nodes, parent_type, path)

Expand Down
76 changes: 76 additions & 0 deletions graphql/execution/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Callable, Iterator, Dict, Tuple, Any, Iterable, Optional, cast

from inspect import isfunction
from functools import partial
from itertools import chain


from ..type import GraphQLFieldResolver


__all__ = ["MiddlewareManager", "middlewares"]

# If the provided middleware is a class, this is the attribute we will look at
MIDDLEWARE_RESOLVER_FUNCTION = "resolve"


class MiddlewareManager:
"""MiddlewareManager helps to chain resolver functions with the provided
middleware functions and classes
"""

__slots__ = ("middlewares", "_middleware_resolvers", "_cached_resolvers")

_cached_resolvers: Dict[GraphQLFieldResolver, GraphQLFieldResolver]
_middleware_resolvers: Optional[Tuple[Callable, ...]]

def __init__(self, *middlewares: Any) -> None:
self.middlewares = middlewares
if middlewares:
self._middleware_resolvers = tuple(get_middleware_resolvers(middlewares))
else:
self.__middleware_resolvers = None
self._cached_resolvers = {}

def get_field_resolver(
self, field_resolver: GraphQLFieldResolver
) -> GraphQLFieldResolver:
"""Wraps the provided resolver returning a function that
executes chains the middleware functions with the resolver function"""
if self._middleware_resolvers is None:
return field_resolver
if field_resolver not in self._cached_resolvers:
self._cached_resolvers[field_resolver] = middleware_chain(
field_resolver, self._middleware_resolvers
)

return self._cached_resolvers[field_resolver]


middlewares = MiddlewareManager


def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]:
"""Returns the functions related to the middleware classes or functions"""
for middleware in middlewares:
# If the middleware is a function instead of a class
if isfunction(middleware):
yield middleware
resolver_func = getattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION, None)
if resolver_func is not None:
yield resolver_func


def middleware_chain(
func: GraphQLFieldResolver, middlewares: Iterable[Callable]
) -> GraphQLFieldResolver:
"""Reduces the current function with the provided middlewares,
returning a new resolver function"""
if not middlewares:
return func
middlewares = chain((func,), middlewares)
last_func: Optional[GraphQLFieldResolver] = None
for middleware in middlewares:
last_func = partial(middleware, last_func) if last_func else middleware

return cast(GraphQLFieldResolver, last_func)
73 changes: 43 additions & 30 deletions graphql/graphql.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
from asyncio import ensure_future
from inspect import isawaitable
from typing import Any, Awaitable, Callable, Dict, Union, cast
from typing import Any, Awaitable, Callable, Dict, Union, Optional, Iterable, cast

from .error import GraphQLError
from .execution import execute
from .language import parse, Source
from .pyutils import MaybeAwaitable
from .type import GraphQLSchema, validate_schema
from .execution.execute import ExecutionResult
from .execution import ExecutionResult, MiddlewareManager

__all__ = ['graphql', 'graphql_sync']
__all__ = ["graphql", "graphql_sync"]


async def graphql(
schema: GraphQLSchema,
source: Union[str, Source],
root_value: Any=None,
context_value: Any=None,
variable_values: Dict[str, Any]=None,
operation_name: str=None,
field_resolver: Callable=None) -> ExecutionResult:
schema: GraphQLSchema,
source: Union[str, Source],
root_value: Any = None,
context_value: Any = None,
variable_values: Dict[str, Any] = None,
operation_name: str = None,
field_resolver: Callable = None,
middleware: Optional[Union[Iterable[Any], MiddlewareManager]] = None,
) -> ExecutionResult:
"""Execute a GraphQL operation asynchronously.

This is the primary entry point function for fulfilling GraphQL operations
Expand Down Expand Up @@ -56,6 +58,8 @@ async def graphql(
A resolver function to use when one is not provided by the schema.
If not provided, the default field resolver is used (which looks for
a value or method on the source value with the field's name).
:arg middleware:
The middleware to wrap the resolvers with
"""
# Always return asynchronously for a consistent API.
result = graphql_impl(
Expand All @@ -65,7 +69,9 @@ async def graphql(
context_value,
variable_values,
operation_name,
field_resolver)
field_resolver,
middleware,
)

if isawaitable(result):
return await cast(Awaitable[ExecutionResult], result)
Expand All @@ -74,13 +80,15 @@ async def graphql(


def graphql_sync(
schema: GraphQLSchema,
source: Union[str, Source],
root_value: Any = None,
context_value: Any = None,
variable_values: Dict[str, Any] = None,
operation_name: str = None,
field_resolver: Callable = None) -> ExecutionResult:
schema: GraphQLSchema,
source: Union[str, Source],
root_value: Any = None,
context_value: Any = None,
variable_values: Dict[str, Any] = None,
operation_name: str = None,
field_resolver: Callable = None,
middleware: Optional[Union[Iterable[Any], MiddlewareManager]] = None,
) -> ExecutionResult:
"""Execute a GraphQL operation synchronously.

The graphql_sync function also fulfills GraphQL operations by parsing,
Expand All @@ -95,26 +103,28 @@ def graphql_sync(
context_value,
variable_values,
operation_name,
field_resolver)
field_resolver,
middleware,
)

# Assert that the execution was synchronous.
if isawaitable(result):
ensure_future(cast(Awaitable[ExecutionResult], result)).cancel()
raise RuntimeError(
'GraphQL execution failed to complete synchronously.')
raise RuntimeError("GraphQL execution failed to complete synchronously.")

return cast(ExecutionResult, result)


def graphql_impl(
schema,
source,
root_value,
context_value,
variable_values,
operation_name,
field_resolver
) -> MaybeAwaitable[ExecutionResult]:
schema,
source,
root_value,
context_value,
variable_values,
operation_name,
field_resolver,
middleware,
) -> MaybeAwaitable[ExecutionResult]:
"""Execute a query, return asynchronously only if necessary."""
# Validate Schema
schema_validation_errors = validate_schema(schema)
Expand All @@ -132,6 +142,7 @@ def graphql_impl(

# Validate
from .validation import validate

validation_errors = validate(schema, document)
if validation_errors:
return ExecutionResult(data=None, errors=validation_errors)
Expand All @@ -144,4 +155,6 @@ def graphql_impl(
context_value,
variable_values,
operation_name,
field_resolver)
field_resolver,
middleware,
)
Loading