diff --git a/graphql/execution/base.py b/graphql/execution/base.py index d71151a5..33f449d0 100644 --- a/graphql/execution/base.py +++ b/graphql/execution/base.py @@ -60,6 +60,9 @@ def to_dict(self, format_error=None, dict_class=OrderedDict): if not self.invalid: response["data"] = self.data + if self.extensions: + response["extensions"] = self.extensions + return response @@ -76,6 +79,7 @@ class ResolveInfo(object): "variable_values", "context", "path", + "extensions", ) def __init__( @@ -91,6 +95,7 @@ def __init__( variable_values, # type: Dict context, # type: Optional[Any] path=None, # type: Union[List[Union[int, str]], List[str]] + extensions=None, # type: Dict ): # type: (...) -> None self.field_name = field_name @@ -104,6 +109,7 @@ def __init__( self.variable_values = variable_values self.context = context self.path = path + self.extensions = extensions __all__ = [ diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index 1b5e884e..e6b75842 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -133,11 +133,17 @@ def on_resolve(data): if isinstance(data, Observable): return data - if not exe_context.errors: + if exe_context.errors and exe_context.extensions: + return ExecutionResult( + data=data, errors=exe_context.errors, extensions=exe_context.extensions + ) + elif exe_context.errors: + return ExecutionResult(data=data, errors=exe_context.errors) + elif exe_context.extensions: + return ExecutionResult(data=data, extensions=exe_context.extensions) + else: return ExecutionResult(data=data) - return ExecutionResult(data=data, errors=exe_context.errors) - promise = ( Promise.resolve(None).then(promise_executor).catch(on_rejected).then(on_resolve) ) @@ -354,6 +360,7 @@ def resolve_field( variable_values=exe_context.variable_values, context=context, path=field_path, + extensions=exe_context.extensions, ) executor = exe_context.executor @@ -408,6 +415,7 @@ def subscribe_field( variable_values=exe_context.variable_values, context=context, path=path, + extensions=exe_context.extensions, ) executor = exe_context.executor @@ -531,6 +539,20 @@ def complete_value( ), ) + # If result is ExecutionResult, update exe_context and complete for data field + if isinstance(result, ExecutionResult): + data = getattr(result, "data", None) + extensions = getattr(result, "extensions", None) + errors = getattr(result, "errors", None) + + if extensions: + exe_context.update_extensions(extensions) + if errors: + for error in errors: + exe_context.report_error(error) + + return complete_value(exe_context, return_type, field_asts, info, path, data) + # print return_type, type(result) if isinstance(result, Exception): raise GraphQLLocatedError(field_asts, original_error=result, path=path) diff --git a/graphql/execution/tests/test_resolve.py b/graphql/execution/tests/test_resolve.py index d6788f38..e66660d0 100644 --- a/graphql/execution/tests/test_resolve.py +++ b/graphql/execution/tests/test_resolve.py @@ -15,6 +15,7 @@ GraphQLSchema, GraphQLString, ) +from graphql.execution import ExecutionResult from promise import Promise # from graphql.execution.base import ResolveInfo @@ -112,6 +113,40 @@ def resolver(source, info, **args): ] +def test_handles_resolved_extensions_with_data(): + # type: () -> None + def resolver(source, info, **args): + # type: (Optional[str], ResolveInfo, **Any) -> ExecutionResult + extensions = info.extensions or {} + extensions["test_extensions"] = extensions.get("test_extensions", {}) + extensions["test_extensions"].update({"foo": "bar"}) + return ExecutionResult(data="foobar", extensions=extensions) + + schema = _test_schema(GraphQLField(GraphQLString, resolver=resolver)) + + result = graphql(schema, "{ test }", None) + assert not result.errors + assert result.data == {"test": "foobar"} + assert result.extensions == {"test_extensions": {"foo": "bar"}} + + +def test_handles_resolved_extensions_with_errors(): + # type: () -> None + def resolver(source, info, **args): + # type: (Optional[str], ResolveInfo, **Any) -> ExecutionResult + extensions = info.extensions or {} + extensions["errors"] = extensions.get("errors", {}) + extensions["errors"].update({"test": {"foo": "bar"}}) + return ExecutionResult(errors=[Exception()], extensions=extensions) + + schema = _test_schema(GraphQLField(GraphQLString, resolver=resolver)) + + result = graphql(schema, "{ test }", None) + assert len(result.errors) == 1 + assert result.data == {"test": None} + assert result.extensions == {"errors": {"test": {"foo": "bar"}}} + + def test_handles_resolved_promises(): # type: () -> None def resolver(source, info, **args): @@ -125,6 +160,23 @@ def resolver(source, info, **args): assert result.data == {"test": "foo"} +def test_handles_resolved_promises_extensions(): + # type: () -> None + def resolver(source, info, **args): + # type: (Optional[Any], ResolveInfo, **Any) -> Promise + extensions = info.extensions or {} + extensions["test_extensions"] = extensions.get("test_extensions", {}) + extensions["test_extensions"].update({"foo": "bar"}) + return Promise.resolve(ExecutionResult(data="foobar", extensions=extensions)) + + schema = _test_schema(GraphQLField(GraphQLString, resolver=resolver)) + + result = graphql(schema, "{ test }", None) + assert not result.errors + assert result.data == {"test": "foobar"} + assert result.extensions == {"test_extensions": {"foo": "bar"}} + + def test_handles_resolved_custom_promises(): # type: () -> None def resolver(source, info, **args): diff --git a/graphql/execution/utils.py b/graphql/execution/utils.py index b1e7ff25..13622e39 100644 --- a/graphql/execution/utils.py +++ b/graphql/execution/utils.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from copy import deepcopy import logging from traceback import format_exception @@ -54,6 +55,7 @@ class ExecutionContext(object): "middleware", "allow_subscriptions", "_subfields_cache", + "extensions", ) def __init__( @@ -67,6 +69,7 @@ def __init__( executor, # type: Any middleware, # type: Optional[Any] allow_subscriptions, # type: bool + extensions=None, # type: Dict ): # type: (...) -> None """Constructs a ExecutionContext object from the arguments passed @@ -126,6 +129,7 @@ def __init__( self.middleware = middleware self.allow_subscriptions = allow_subscriptions self._subfields_cache = {} # type: Dict[Tuple[GraphQLObjectType, Tuple[Field, ...]], DefaultOrderedDict] + self.extensions = extensions def get_field_resolver(self, field_resolver): # type: (Callable) -> Callable @@ -151,6 +155,12 @@ def report_error(self, error, traceback=None): logger.error("".join(exception)) self.errors.append(error) + def update_extensions(self, extensions): + # type: (Dict[str, Any]) -> None + if extensions: + self.extensions = self.extensions or {} + self.extensions.update(extensions) + def get_sub_fields(self, return_type, field_asts): # type: (GraphQLObjectType, List[Field]) -> DefaultOrderedDict k = return_type, tuple(field_asts)