diff --git a/graphql_server/aiohttp/__init__.py b/graphql_server/aiohttp/__init__.py new file mode 100644 index 0000000..8f5beaf --- /dev/null +++ b/graphql_server/aiohttp/__init__.py @@ -0,0 +1,3 @@ +from .graphqlview import GraphQLView + +__all__ = ["GraphQLView"] diff --git a/graphql_server/aiohttp/graphqlview.py b/graphql_server/aiohttp/graphqlview.py new file mode 100644 index 0000000..9581e12 --- /dev/null +++ b/graphql_server/aiohttp/graphqlview.py @@ -0,0 +1,217 @@ +import copy +from collections.abc import MutableMapping +from functools import partial + +from aiohttp import web +from graphql import GraphQLError +from graphql.type.schema import GraphQLSchema + +from graphql_server import ( + HttpQueryError, + encode_execution_results, + format_error_default, + json_encode, + load_json_body, + run_http_query, +) + +from .render_graphiql import render_graphiql + + +class GraphQLView: + schema = None + root_value = None + context = None + pretty = False + graphiql = False + graphiql_version = None + graphiql_template = None + middleware = None + batch = False + jinja_env = None + max_age = 86400 + enable_async = False + subscriptions = None + + accepted_methods = ["GET", "POST", "PUT", "DELETE"] + + format_error = staticmethod(format_error_default) + encode = staticmethod(json_encode) + + def __init__(self, **kwargs): + super(GraphQLView, self).__init__() + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + + assert isinstance( + self.schema, GraphQLSchema + ), "A Schema is required to be provided to GraphQLView." + + def get_root_value(self): + return self.root_value + + def get_context(self, request): + context = ( + copy.copy(self.context) + if self.context and isinstance(self.context, MutableMapping) + else {} + ) + if isinstance(context, MutableMapping) and "request" not in context: + context.update({"request": request}) + return context + + def get_middleware(self): + return self.middleware + + # This method can be static + async def parse_body(self, request): + content_type = request.content_type + # request.text() is the aiohttp equivalent to + # request.body.decode("utf8") + if content_type == "application/graphql": + r_text = await request.text() + return {"query": r_text} + + if content_type == "application/json": + text = await request.text() + return load_json_body(text) + + if content_type in ( + "application/x-www-form-urlencoded", + "multipart/form-data", + ): + # TODO: seems like a multidict would be more appropriate + # than casting it and de-duping variables. Alas, it's what + # graphql-python wants. + return dict(await request.post()) + + return {} + + def render_graphiql(self, params, result): + return render_graphiql( + jinja_env=self.jinja_env, + params=params, + result=result, + graphiql_version=self.graphiql_version, + graphiql_template=self.graphiql_template, + subscriptions=self.subscriptions, + ) + + # TODO: + # use this method to replace flask and sanic + # checks as this is equivalent to `should_display_graphiql` and + # `request_wants_html` methods. + def is_graphiql(self, request): + return all( + [ + self.graphiql, + request.method.lower() == "get", + "raw" not in request.query, + any( + [ + "text/html" in request.headers.get("accept", {}), + "*/*" in request.headers.get("accept", {}), + ] + ), + ] + ) + + # TODO: Same stuff as above method. + def is_pretty(self, request): + return any( + [self.pretty, self.is_graphiql(request), request.query.get("pretty")] + ) + + async def __call__(self, request): + try: + data = await self.parse_body(request) + request_method = request.method.lower() + is_graphiql = self.is_graphiql(request) + is_pretty = self.is_pretty(request) + + # TODO: way better than if-else so better + # implement this too on flask and sanic + if request_method == "options": + return self.process_preflight(request) + + execution_results, all_params = run_http_query( + self.schema, + request_method, + data, + query_data=request.query, + batch_enabled=self.batch, + catch=is_graphiql, + # Execute options + run_sync=not self.enable_async, + root_value=self.get_root_value(), + context_value=self.get_context(request), + middleware=self.get_middleware(), + ) + + exec_res = ( + [await ex for ex in execution_results] + if self.enable_async + else execution_results + ) + result, status_code = encode_execution_results( + exec_res, + is_batch=isinstance(data, list), + format_error=self.format_error, + encode=partial(self.encode, pretty=is_pretty), # noqa: ignore + ) + + if is_graphiql: + return await self.render_graphiql(params=all_params[0], result=result) + + return web.Response( + text=result, status=status_code, content_type="application/json", + ) + + except HttpQueryError as err: + parsed_error = GraphQLError(err.message) + return web.Response( + body=self.encode(dict(errors=[self.format_error(parsed_error)])), + status=err.status_code, + headers=err.headers, + content_type="application/json", + ) + + def process_preflight(self, request): + """ + Preflight request support for apollo-client + https://www.w3.org/TR/cors/#resource-preflight-requests + """ + headers = request.headers + origin = headers.get("Origin", "") + method = headers.get("Access-Control-Request-Method", "").upper() + + if method and method in self.accepted_methods: + return web.Response( + status=200, + headers={ + "Access-Control-Allow-Origin": origin, + "Access-Control-Allow-Methods": ", ".join(self.accepted_methods), + "Access-Control-Max-Age": str(self.max_age), + }, + ) + return web.Response(status=400) + + @classmethod + def attach(cls, app, *, route_path="/graphql", route_name="graphql", **kwargs): + view = cls(**kwargs) + app.router.add_route("*", route_path, _asyncify(view), name=route_name) + + +def _asyncify(handler): + """Return an async version of the given handler. + + This is mainly here because ``aiohttp`` can't infer the async definition of + :py:meth:`.GraphQLView.__call__` and raises a :py:class:`DeprecationWarning` + in tests. Wrapping it into an async function avoids the noisy warning. + """ + + async def _dispatch(request): + return await handler(request) + + return _dispatch diff --git a/graphql_server/aiohttp/render_graphiql.py b/graphql_server/aiohttp/render_graphiql.py new file mode 100644 index 0000000..9da47d3 --- /dev/null +++ b/graphql_server/aiohttp/render_graphiql.py @@ -0,0 +1,208 @@ +import json +import re + +from aiohttp import web + +GRAPHIQL_VERSION = "0.17.5" + +TEMPLATE = """ + + + + + + + + + + + + + + + + +""" + + +def escape_js_value(value): + quotation = False + if value.startswith('"') and value.endswith('"'): + quotation = True + value = value[1:-1] + + value = value.replace("\\\\n", "\\\\\\n").replace("\\n", "\\\\n") + if quotation: + value = '"' + value.replace('\\\\"', '"').replace('"', '\\"') + '"' + + return value + + +def process_var(template, name, value, jsonify=False): + pattern = r"{{\s*" + name + r"(\s*|[^}]+)*\s*}}" + if jsonify and value not in ["null", "undefined"]: + value = json.dumps(value) + value = escape_js_value(value) + + return re.sub(pattern, value, template) + + +def simple_renderer(template, **values): + replace = ["graphiql_version", "subscriptions"] + replace_jsonify = ["query", "result", "variables", "operation_name"] + + for rep in replace: + template = process_var(template, rep, values.get(rep, "")) + + for rep in replace_jsonify: + template = process_var(template, rep, values.get(rep, ""), True) + + return template + + +async def render_graphiql( + jinja_env=None, + graphiql_version=None, + graphiql_template=None, + params=None, + result=None, + subscriptions=None, +): + graphiql_version = graphiql_version or GRAPHIQL_VERSION + template = graphiql_template or TEMPLATE + template_vars = { + "graphiql_version": graphiql_version, + "query": params and params.query, + "variables": params and params.variables, + "operation_name": params and params.operation_name, + "result": result, + "subscriptions": subscriptions or "", + } + + if jinja_env: + template = jinja_env.from_string(template) + if jinja_env.is_async: + source = await template.render_async(**template_vars) + else: + source = template.render(**template_vars) + else: + source = simple_renderer(template, **template_vars) + + return web.Response(text=source, content_type="text/html") diff --git a/setup.py b/setup.py index fbf8637..6135166 100644 --- a/setup.py +++ b/setup.py @@ -27,10 +27,15 @@ "sanic>=19.9.0,<20", ] +install_aiohttp_requires = [ + "aiohttp>=3.5.0,<4", +] + install_all_requires = \ install_requires + \ install_flask_requires + \ - install_sanic_requires + install_sanic_requires + \ + install_aiohttp_requires setup( name="graphql-server-core", @@ -62,6 +67,7 @@ "dev": install_all_requires + dev_requires, "flask": install_flask_requires, "sanic": install_sanic_requires, + "aiohttp": install_aiohttp_requires, }, include_package_data=True, zip_safe=False, diff --git a/tests/aiohttp/__init__.py b/tests/aiohttp/__init__.py new file mode 100644 index 0000000..943d58f --- /dev/null +++ b/tests/aiohttp/__init__.py @@ -0,0 +1 @@ +# aiohttp-graphql tests diff --git a/tests/aiohttp/app.py b/tests/aiohttp/app.py new file mode 100644 index 0000000..36d7de6 --- /dev/null +++ b/tests/aiohttp/app.py @@ -0,0 +1,22 @@ +from urllib.parse import urlencode + +from aiohttp import web + +from graphql_server.aiohttp import GraphQLView +from tests.aiohttp.schema import Schema + + +def create_app(schema=Schema, **kwargs): + app = web.Application() + # Only needed to silence aiohttp deprecation warnings + GraphQLView.attach(app, schema=schema, **kwargs) + return app + + +def url_string(**url_params): + base_url = "/graphql" + + if url_params: + return f"{base_url}?{urlencode(url_params)}" + + return base_url diff --git a/tests/aiohttp/schema.py b/tests/aiohttp/schema.py new file mode 100644 index 0000000..9198b12 --- /dev/null +++ b/tests/aiohttp/schema.py @@ -0,0 +1,85 @@ +import asyncio + +from graphql.type.definition import ( + GraphQLArgument, + GraphQLField, + GraphQLNonNull, + GraphQLObjectType, +) +from graphql.type.scalars import GraphQLString +from graphql.type.schema import GraphQLSchema + + +def resolve_raises(*_): + raise Exception("Throws!") + + +# Sync schema +QueryRootType = GraphQLObjectType( + name="QueryRoot", + fields={ + "thrower": GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_raises,), + "request": GraphQLField( + GraphQLNonNull(GraphQLString), + resolve=lambda obj, info, *args: info.context["request"].query.get("q"), + ), + "context": GraphQLField( + GraphQLNonNull(GraphQLString), + resolve=lambda obj, info, *args: info.context["request"], + ), + "test": GraphQLField( + type_=GraphQLString, + args={"who": GraphQLArgument(GraphQLString)}, + resolve=lambda obj, info, who=None: "Hello %s" % (who or "World"), + ), + }, +) + + +MutationRootType = GraphQLObjectType( + name="MutationRoot", + fields={ + "writeTest": GraphQLField( + type_=QueryRootType, resolve=lambda *args: QueryRootType + ) + }, +) + +SubscriptionsRootType = GraphQLObjectType( + name="SubscriptionsRoot", + fields={ + "subscriptionsTest": GraphQLField( + type_=QueryRootType, resolve=lambda *args: QueryRootType + ) + }, +) + +Schema = GraphQLSchema(QueryRootType, MutationRootType, SubscriptionsRootType) + + +# Schema with async methods +async def resolver_field_async_1(_obj, info): + await asyncio.sleep(0.001) + return "hey" + + +async def resolver_field_async_2(_obj, info): + await asyncio.sleep(0.003) + return "hey2" + + +def resolver_field_sync(_obj, info): + return "hey3" + + +AsyncQueryType = GraphQLObjectType( + "AsyncQueryType", + { + "a": GraphQLField(GraphQLString, resolve=resolver_field_async_1), + "b": GraphQLField(GraphQLString, resolve=resolver_field_async_2), + "c": GraphQLField(GraphQLString, resolve=resolver_field_sync), + }, +) + + +AsyncSchema = GraphQLSchema(AsyncQueryType) diff --git a/tests/aiohttp/test_graphiqlview.py b/tests/aiohttp/test_graphiqlview.py new file mode 100644 index 0000000..04a9b50 --- /dev/null +++ b/tests/aiohttp/test_graphiqlview.py @@ -0,0 +1,112 @@ +import pytest +from aiohttp.test_utils import TestClient, TestServer +from jinja2 import Environment + +from tests.aiohttp.app import create_app, url_string +from tests.aiohttp.schema import AsyncSchema, Schema + + +@pytest.fixture +def app(): + app = create_app() + return app + + +@pytest.fixture +async def client(app): + client = TestClient(TestServer(app)) + await client.start_server() + yield client + await client.close() + + +@pytest.fixture +def view_kwargs(): + return { + "schema": Schema, + "graphiql": True, + } + + +@pytest.fixture +def pretty_response(): + return ( + "{\n" + ' "data": {\n' + ' "test": "Hello World"\n' + " }\n" + "}".replace('"', '\\"').replace("\n", "\\n") + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(graphiql=True)]) +async def test_graphiql_is_enabled(app, client): + response = await client.get( + url_string(query="{test}"), headers={"Accept": "text/html"} + ) + assert response.status == 200 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(graphiql=True)]) +async def test_graphiql_simple_renderer(app, client, pretty_response): + response = await client.get( + url_string(query="{test}"), headers={"Accept": "text/html"}, + ) + assert response.status == 200 + assert pretty_response in await response.text() + + +class TestJinjaEnv: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "app", [create_app(graphiql=True, jinja_env=Environment())] + ) + async def test_graphiql_jinja_renderer(self, app, client, pretty_response): + response = await client.get( + url_string(query="{test}"), headers={"Accept": "text/html"}, + ) + assert response.status == 200 + assert pretty_response in await response.text() + + +@pytest.mark.asyncio +async def test_graphiql_html_is_not_accepted(client): + response = await client.get("/graphql", headers={"Accept": "application/json"},) + assert response.status == 400 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(graphiql=True)]) +async def test_graphiql_get_mutation(app, client): + response = await client.get( + url_string(query="mutation TestMutation { writeTest { test } }"), + headers={"Accept": "text/html"}, + ) + assert response.status == 200 + assert "response: null" in await response.text() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(graphiql=True)]) +async def test_graphiql_get_subscriptions(client): + response = await client.get( + url_string( + query="subscription TestSubscriptions { subscriptionsTest { test } }" + ), + headers={"Accept": "text/html"}, + ) + assert response.status == 200 + assert "response: null" in await response.text() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(schema=AsyncSchema, enable_async=True)]) +async def test_graphiql_async_schema(app, client): + response = await client.get( + url_string(query="{a,b,c}"), headers={"Accept": "text/html"}, + ) + + assert response.status == 200 + assert await response.json() == {"data": {"a": "hey", "b": "hey2", "c": "hey3"}} diff --git a/tests/aiohttp/test_graphqlview.py b/tests/aiohttp/test_graphqlview.py new file mode 100644 index 0000000..0f6becb --- /dev/null +++ b/tests/aiohttp/test_graphqlview.py @@ -0,0 +1,675 @@ +import json +from urllib.parse import urlencode + +import pytest +from aiohttp import FormData +from aiohttp.test_utils import TestClient, TestServer + +from .app import create_app, url_string +from .schema import AsyncSchema + + +@pytest.fixture +def app(): + app = create_app() + return app + + +@pytest.fixture +async def client(app): + client = TestClient(TestServer(app)) + await client.start_server() + yield client + await client.close() + + +@pytest.mark.asyncio +async def test_allows_get_with_query_param(client): + response = await client.get(url_string(query="{test}")) + + assert response.status == 200 + assert await response.json() == {"data": {"test": "Hello World"}} + + +@pytest.mark.asyncio +async def test_allows_get_with_variable_values(client): + response = await client.get( + url_string( + query="query helloWho($who: String) { test(who: $who) }", + variables=json.dumps({"who": "Dolly"}), + ) + ) + + assert response.status == 200 + assert await response.json() == {"data": {"test": "Hello Dolly"}} + + +@pytest.mark.asyncio +async def test_allows_get_with_operation_name(client): + response = await client.get( + url_string( + query=""" + query helloYou { test(who: "You"), ...shared } + query helloWorld { test(who: "World"), ...shared } + query helloDolly { test(who: "Dolly"), ...shared } + fragment shared on QueryRoot { + shared: test(who: "Everyone") + } + """, + operationName="helloWorld", + ) + ) + + assert response.status == 200 + assert await response.json() == { + "data": {"test": "Hello World", "shared": "Hello Everyone"} + } + + +@pytest.mark.asyncio +async def test_reports_validation_errors(client): + response = await client.get(url_string(query="{ test, unknownOne, unknownTwo }")) + + assert response.status == 400 + assert await response.json() == { + "errors": [ + { + "message": "Cannot query field 'unknownOne' on type 'QueryRoot'.", + "locations": [{"line": 1, "column": 9}], + "path": None, + }, + { + "message": "Cannot query field 'unknownTwo' on type 'QueryRoot'.", + "locations": [{"line": 1, "column": 21}], + "path": None, + }, + ], + } + + +@pytest.mark.asyncio +async def test_errors_when_missing_operation_name(client): + response = await client.get( + url_string( + query=""" + query TestQuery { test } + mutation TestMutation { writeTest { test } } + subscription TestSubscriptions { subscriptionsTest { test } } + """ + ) + ) + + assert response.status == 400 + assert await response.json() == { + "errors": [ + { + "message": ( + "Must provide operation name if query contains multiple " + "operations." + ), + "locations": None, + "path": None, + }, + ] + } + + +@pytest.mark.asyncio +async def test_errors_when_sending_a_mutation_via_get(client): + response = await client.get( + url_string( + query=""" + mutation TestMutation { writeTest { test } } + """ + ) + ) + assert response.status == 405 + assert await response.json() == { + "errors": [ + { + "message": "Can only perform a mutation operation from a POST request.", + "locations": None, + "path": None, + }, + ], + } + + +@pytest.mark.asyncio +async def test_errors_when_selecting_a_mutation_within_a_get(client): + response = await client.get( + url_string( + query=""" + query TestQuery { test } + mutation TestMutation { writeTest { test } } + """, + operationName="TestMutation", + ) + ) + + assert response.status == 405 + assert await response.json() == { + "errors": [ + { + "message": "Can only perform a mutation operation from a POST request.", + "locations": None, + "path": None, + }, + ], + } + + +@pytest.mark.asyncio +async def test_errors_when_selecting_a_subscription_within_a_get(client): + response = await client.get( + url_string( + query=""" + subscription TestSubscriptions { subscriptionsTest { test } } + """, + operationName="TestSubscriptions", + ) + ) + + assert response.status == 405 + assert await response.json() == { + "errors": [ + { + "message": "Can only perform a subscription operation from a POST " + "request.", + "locations": None, + "path": None, + }, + ], + } + + +@pytest.mark.asyncio +async def test_allows_mutation_to_exist_within_a_get(client): + response = await client.get( + url_string( + query=""" + query TestQuery { test } + mutation TestMutation { writeTest { test } } + """, + operationName="TestQuery", + ) + ) + + assert response.status == 200 + assert await response.json() == {"data": {"test": "Hello World"}} + + +@pytest.mark.asyncio +async def test_allows_post_with_json_encoding(client): + response = await client.post( + "/graphql", + data=json.dumps(dict(query="{test}")), + headers={"content-type": "application/json"}, + ) + + assert await response.json() == {"data": {"test": "Hello World"}} + assert response.status == 200 + + +@pytest.mark.asyncio +async def test_allows_sending_a_mutation_via_post(client): + response = await client.post( + "/graphql", + data=json.dumps(dict(query="mutation TestMutation { writeTest { test } }",)), + headers={"content-type": "application/json"}, + ) + + assert response.status == 200 + assert await response.json() == {"data": {"writeTest": {"test": "Hello World"}}} + + +@pytest.mark.asyncio +async def test_allows_post_with_url_encoding(client): + data = FormData() + data.add_field("query", "{test}") + response = await client.post( + "/graphql", + data=data(), + headers={"content-type": "application/x-www-form-urlencoded"}, + ) + + assert await response.json() == {"data": {"test": "Hello World"}} + assert response.status == 200 + + +@pytest.mark.asyncio +async def test_supports_post_json_query_with_string_variables(client): + response = await client.post( + "/graphql", + data=json.dumps( + dict( + query="query helloWho($who: String){ test(who: $who) }", + variables=json.dumps({"who": "Dolly"}), + ) + ), + headers={"content-type": "application/json"}, + ) + + assert response.status == 200 + assert await response.json() == {"data": {"test": "Hello Dolly"}} + + +@pytest.mark.asyncio +async def test_supports_post_json_query_with_json_variables(client): + response = await client.post( + "/graphql", + data=json.dumps( + dict( + query="query helloWho($who: String){ test(who: $who) }", + variables={"who": "Dolly"}, + ) + ), + headers={"content-type": "application/json"}, + ) + + assert response.status == 200 + assert await response.json() == {"data": {"test": "Hello Dolly"}} + + +@pytest.mark.asyncio +async def test_supports_post_url_encoded_query_with_string_variables(client): + response = await client.post( + "/graphql", + data=urlencode( + dict( + query="query helloWho($who: String){ test(who: $who) }", + variables=json.dumps({"who": "Dolly"}), + ), + ), + headers={"content-type": "application/x-www-form-urlencoded"}, + ) + + assert response.status == 200 + assert await response.json() == {"data": {"test": "Hello Dolly"}} + + +@pytest.mark.asyncio +async def test_supports_post_json_quey_with_get_variable_values(client): + response = await client.post( + url_string(variables=json.dumps({"who": "Dolly"})), + data=json.dumps(dict(query="query helloWho($who: String){ test(who: $who) }",)), + headers={"content-type": "application/json"}, + ) + + assert response.status == 200 + assert await response.json() == {"data": {"test": "Hello Dolly"}} + + +@pytest.mark.asyncio +async def test_post_url_encoded_query_with_get_variable_values(client): + response = await client.post( + url_string(variables=json.dumps({"who": "Dolly"})), + data=urlencode(dict(query="query helloWho($who: String){ test(who: $who) }",)), + headers={"content-type": "application/x-www-form-urlencoded"}, + ) + + assert response.status == 200 + assert await response.json() == {"data": {"test": "Hello Dolly"}} + + +@pytest.mark.asyncio +async def test_supports_post_raw_text_query_with_get_variable_values(client): + response = await client.post( + url_string(variables=json.dumps({"who": "Dolly"})), + data="query helloWho($who: String){ test(who: $who) }", + headers={"content-type": "application/graphql"}, + ) + + assert response.status == 200 + assert await response.json() == {"data": {"test": "Hello Dolly"}} + + +@pytest.mark.asyncio +async def test_allows_post_with_operation_name(client): + response = await client.post( + "/graphql", + data=json.dumps( + dict( + query=""" + query helloYou { test(who: "You"), ...shared } + query helloWorld { test(who: "World"), ...shared } + query helloDolly { test(who: "Dolly"), ...shared } + fragment shared on QueryRoot { + shared: test(who: "Everyone") + } + """, + operationName="helloWorld", + ) + ), + headers={"content-type": "application/json"}, + ) + + assert response.status == 200 + assert await response.json() == { + "data": {"test": "Hello World", "shared": "Hello Everyone"} + } + + +@pytest.mark.asyncio +async def test_allows_post_with_get_operation_name(client): + response = await client.post( + url_string(operationName="helloWorld"), + data=""" + query helloYou { test(who: "You"), ...shared } + query helloWorld { test(who: "World"), ...shared } + query helloDolly { test(who: "Dolly"), ...shared } + fragment shared on QueryRoot { + shared: test(who: "Everyone") + } + """, + headers={"content-type": "application/graphql"}, + ) + + assert response.status == 200 + assert await response.json() == { + "data": {"test": "Hello World", "shared": "Hello Everyone"} + } + + +@pytest.mark.asyncio +async def test_supports_pretty_printing(client): + response = await client.get(url_string(query="{test}", pretty="1")) + + text = await response.text() + assert text == "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}" + + +@pytest.mark.asyncio +async def test_not_pretty_by_default(client): + response = await client.get(url_string(query="{test}")) + + assert await response.text() == '{"data":{"test":"Hello World"}}' + + +@pytest.mark.asyncio +async def test_supports_pretty_printing_by_request(client): + response = await client.get(url_string(query="{test}", pretty="1")) + + assert await response.text() == ( + "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}" + ) + + +@pytest.mark.asyncio +async def test_handles_field_errors_caught_by_graphql(client): + response = await client.get(url_string(query="{thrower}")) + assert response.status == 200 + assert await response.json() == { + "data": None, + "errors": [ + { + "locations": [{"column": 2, "line": 1}], + "message": "Throws!", + "path": ["thrower"], + } + ], + } + + +@pytest.mark.asyncio +async def test_handles_syntax_errors_caught_by_graphql(client): + response = await client.get(url_string(query="syntaxerror")) + + assert response.status == 400 + assert await response.json() == { + "errors": [ + { + "locations": [{"column": 1, "line": 1}], + "message": "Syntax Error: Unexpected Name 'syntaxerror'.", + "path": None, + }, + ], + } + + +@pytest.mark.asyncio +async def test_handles_errors_caused_by_a_lack_of_query(client): + response = await client.get("/graphql") + + assert response.status == 400 + assert await response.json() == { + "errors": [ + {"message": "Must provide query string.", "locations": None, "path": None} + ] + } + + +@pytest.mark.asyncio +async def test_handles_batch_correctly_if_is_disabled(client): + response = await client.post( + "/graphql", data="[]", headers={"content-type": "application/json"}, + ) + + assert response.status == 400 + assert await response.json() == { + "errors": [ + { + "message": "Batch GraphQL requests are not enabled.", + "locations": None, + "path": None, + } + ] + } + + +@pytest.mark.asyncio +async def test_handles_incomplete_json_bodies(client): + response = await client.post( + "/graphql", data='{"query":', headers={"content-type": "application/json"}, + ) + + assert response.status == 400 + assert await response.json() == { + "errors": [ + { + "message": "POST body sent invalid JSON.", + "locations": None, + "path": None, + } + ] + } + + +@pytest.mark.asyncio +async def test_handles_plain_post_text(client): + response = await client.post( + url_string(variables=json.dumps({"who": "Dolly"})), + data="query helloWho($who: String){ test(who: $who) }", + headers={"content-type": "text/plain"}, + ) + assert response.status == 400 + assert await response.json() == { + "errors": [ + {"message": "Must provide query string.", "locations": None, "path": None} + ] + } + + +@pytest.mark.asyncio +async def test_handles_poorly_formed_variables(client): + response = await client.get( + url_string( + query="query helloWho($who: String){ test(who: $who) }", variables="who:You" + ), + ) + assert response.status == 400 + assert await response.json() == { + "errors": [ + {"message": "Variables are invalid JSON.", "locations": None, "path": None} + ] + } + + +@pytest.mark.asyncio +async def test_handles_unsupported_http_methods(client): + response = await client.put(url_string(query="{test}")) + assert response.status == 405 + assert response.headers["Allow"] in ["GET, POST", "HEAD, GET, POST, OPTIONS"] + assert await response.json() == { + "errors": [ + { + "message": "GraphQL only supports GET and POST requests.", + "locations": None, + "path": None, + } + ] + } + + +@pytest.mark.parametrize("app", [create_app()]) +@pytest.mark.asyncio +async def test_passes_request_into_request_context(app, client): + response = await client.get(url_string(query="{request}", q="testing")) + + assert response.status == 200 + assert await response.json() == { + "data": {"request": "testing"}, + } + + +class TestCustomContext: + @pytest.mark.parametrize( + "app", [create_app(context="CUSTOM CONTEXT")], + ) + @pytest.mark.asyncio + async def test_context_remapped(self, app, client): + response = await client.get(url_string(query="{context}")) + + _json = await response.json() + assert response.status == 200 + assert "Request" in _json["data"]["context"] + assert "CUSTOM CONTEXT" not in _json["data"]["context"] + + @pytest.mark.parametrize("app", [create_app(context={"request": "test"})]) + @pytest.mark.asyncio + async def test_request_not_replaced(self, app, client): + response = await client.get(url_string(query="{context}")) + + _json = await response.json() + assert response.status == 200 + assert _json["data"]["context"] == "test" + + +@pytest.mark.asyncio +async def test_post_multipart_data(client): + query = "mutation TestMutation { writeTest { test } }" + + data = ( + "------aiohttpgraphql\r\n" + + 'Content-Disposition: form-data; name="query"\r\n' + + "\r\n" + + query + + "\r\n" + + "------aiohttpgraphql--\r\n" + + "Content-Type: text/plain; charset=utf-8\r\n" + + 'Content-Disposition: form-data; name="file"; filename="text1.txt"; filename*=utf-8\'\'text1.txt\r\n' # noqa: ignore + + "\r\n" + + "\r\n" + + "------aiohttpgraphql--\r\n" + ) + + response = await client.post( + "/graphql", + data=data, + headers={"content-type": "multipart/form-data; boundary=----aiohttpgraphql"}, + ) + + assert response.status == 200 + assert await response.json() == {"data": {u"writeTest": {u"test": u"Hello World"}}} + + +class TestBatchExecutor: + @pytest.mark.asyncio + @pytest.mark.parametrize("app", [create_app(batch=True)]) + async def test_batch_allows_post_with_json_encoding(self, app, client): + response = await client.post( + "/graphql", + data=json.dumps([dict(id=1, query="{test}")]), + headers={"content-type": "application/json"}, + ) + + assert response.status == 200 + assert await response.json() == [{"data": {"test": "Hello World"}}] + + @pytest.mark.asyncio + @pytest.mark.parametrize("app", [create_app(batch=True)]) + async def test_batch_supports_post_json_query_with_json_variables( + self, app, client + ): + response = await client.post( + "/graphql", + data=json.dumps( + [ + dict( + id=1, + query="query helloWho($who: String){ test(who: $who) }", + variables={"who": "Dolly"}, + ) + ] + ), + headers={"content-type": "application/json"}, + ) + + assert response.status == 200 + assert await response.json() == [{"data": {"test": "Hello Dolly"}}] + + @pytest.mark.asyncio + @pytest.mark.parametrize("app", [create_app(batch=True)]) + async def test_batch_allows_post_with_operation_name(self, app, client): + response = await client.post( + "/graphql", + data=json.dumps( + [ + dict( + id=1, + query=""" + query helloYou { test(who: "You"), ...shared } + query helloWorld { test(who: "World"), ...shared } + query helloDolly { test(who: "Dolly"), ...shared } + fragment shared on QueryRoot { + shared: test(who: "Everyone") + } + """, + operationName="helloWorld", + ) + ] + ), + headers={"content-type": "application/json"}, + ) + + assert response.status == 200 + assert await response.json() == [ + {"data": {"test": "Hello World", "shared": "Hello Everyone"}} + ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(schema=AsyncSchema, enable_async=True)]) +async def test_async_schema(app, client): + response = await client.get(url_string(query="{a,b,c}")) + + assert response.status == 200 + assert await response.json() == {"data": {"a": "hey", "b": "hey2", "c": "hey3"}} + + +@pytest.mark.asyncio +async def test_preflight_request(client): + response = await client.options( + "/graphql", headers={"Access-Control-Request-Method": "POST"}, + ) + + assert response.status == 200 + + +@pytest.mark.asyncio +async def test_preflight_incorrect_request(client): + response = await client.options( + "/graphql", headers={"Access-Control-Request-Method": "OPTIONS"}, + ) + + assert response.status == 400