diff --git a/README.md b/README.md index 6246ea9..3e8f3d3 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,8 @@ from flask_graphql import GraphQLView app.add_url_rule('/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=True)) +# Optional, for adding batch query support (used in Apollo-Client) +app.add_url_rule('/graphql/batch', view_func=GraphQLView.as_view('graphql', schema=schema, batch=True)) ``` This will add `/graphql` and `/graphiql` endpoints to your app. diff --git a/flask_graphql/graphqlview.py b/flask_graphql/graphqlview.py index b227662..cbd8b0f 100644 --- a/flask_graphql/graphqlview.py +++ b/flask_graphql/graphqlview.py @@ -32,6 +32,7 @@ class GraphQLView(View): graphiql_version = None graphiql_template = None middleware = None + batch = False methods = ['GET', 'POST', 'PUT', 'DELETE'] @@ -41,6 +42,7 @@ def __init__(self, **kwargs): if hasattr(self, key): setattr(self, key, value) + assert not all((self.graphiql, self.batch)), 'Use either graphiql or batch processing' assert isinstance(self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.' # noinspection PyUnusedLocal @@ -66,33 +68,15 @@ def dispatch_request(self): data = self.parse_body(request) show_graphiql = self.graphiql and self.can_display_graphiql(data) - query, variables, operation_name = self.get_graphql_params(request, data) - - execution_result = self.execute_graphql_request( - data, - query, - variables, - operation_name, - show_graphiql - ) - - if execution_result: - response = {} - - if execution_result.errors: - response['errors'] = [self.format_error(e) for e in execution_result.errors] - - if execution_result.invalid: - status_code = 400 - else: - status_code = 200 - response['data'] = execution_result.data - - result = self.json_encode(request, response) + if self.batch: + responses = [self.get_response(request, entry) for entry in data] + result = '[{}]'.format(','.join([response[0] for response in responses])) + status_code = max(responses, key=lambda response: response[1])[1] else: - result = None + result, status_code = self.get_response(request, data, show_graphiql) if show_graphiql: + query, variables, operation_name, id = self.get_graphql_params(request, data) return render_graphiql( graphiql_version=self.graphiql_version, graphiql_template=self.graphiql_template, @@ -118,6 +102,43 @@ def dispatch_request(self): content_type='application/json' ) + def get_response(self, request, data, show_graphiql=False): + query, variables, operation_name, id = self.get_graphql_params(request, data) + + execution_result = self.execute_graphql_request( + data, + query, + variables, + operation_name, + show_graphiql + ) + + status_code = 200 + if execution_result: + response = {} + + if execution_result.errors: + response['errors'] = [self.format_error(e) for e in execution_result.errors] + + if execution_result.invalid: + status_code = 400 + else: + status_code = 200 + response['data'] = execution_result.data + + if self.batch: + response = { + 'id': id, + 'payload': response, + 'status': status_code, + } + + result = self.json_encode(request, response) + else: + result = None + + return result, status_code + def json_encode(self, request, d): if not self.pretty and not request.args.get('pretty'): return json.dumps(d, separators=(',', ':')) @@ -134,7 +155,10 @@ def parse_body(self, request): elif content_type == 'application/json': try: request_json = json.loads(request.data.decode('utf8')) - assert isinstance(request_json, dict) + if self.batch: + assert isinstance(request_json, list) + else: + assert isinstance(request_json, dict) return request_json except: raise HttpError(BadRequest('POST body sent invalid JSON.')) @@ -207,6 +231,7 @@ def request_wants_html(cls, request): def get_graphql_params(request, data): query = request.args.get('query') or data.get('query') variables = request.args.get('variables') or data.get('variables') + id = request.args.get('id') or data.get('id') if variables and isinstance(variables, six.text_type): try: @@ -216,7 +241,7 @@ def get_graphql_params(request, data): operation_name = request.args.get('operationName') or data.get('operationName') - return query, variables, operation_name + return query, variables, operation_name, id @staticmethod def format_error(error): diff --git a/tests/app.py b/tests/app.py index 13299f1..9f11aee 100644 --- a/tests/app.py +++ b/tests/app.py @@ -3,10 +3,10 @@ from .schema import Schema -def create_app(**kwargs): +def create_app(path='/graphql', **kwargs): app = Flask(__name__) app.debug = True - app.add_url_rule('/graphql', view_func=GraphQLView.as_view('graphql', schema=Schema, **kwargs)) + app.add_url_rule(path, view_func=GraphQLView.as_view('graphql', schema=Schema, **kwargs)) return app diff --git a/tests/test_graphqlview.py b/tests/test_graphqlview.py index d25d4a8..eff759f 100644 --- a/tests/test_graphqlview.py +++ b/tests/test_graphqlview.py @@ -33,6 +33,8 @@ def response_json(response): j = lambda **kwargs: json.dumps(kwargs) +jl = lambda **kwargs: json.dumps([kwargs]) + def test_allows_get_with_query_param(client): response = client.get(url_string(query='{test}')) @@ -453,3 +455,71 @@ def test_post_multipart_data(client): assert response.status_code == 200 assert response_json(response) == {'data': {u'writeTest': {u'test': u'Hello World'}}} + + +@pytest.mark.parametrize('app', [create_app(batch=True)]) +def test_batch_allows_post_with_json_encoding(client): + response = client.post( + url_string(), + data=jl(id=1, query='{test}'), + content_type='application/json' + ) + + assert response.status_code == 200 + assert response_json(response) == [{ + 'id': 1, + 'payload': { 'data': {'test': "Hello World"} }, + 'status': 200, + }] + + +@pytest.mark.parametrize('app', [create_app(batch=True)]) +def test_batch_supports_post_json_query_with_json_variables(client): + response = client.post( + url_string(), + data=jl( + id=1, + query='query helloWho($who: String){ test(who: $who) }', + variables={'who': "Dolly"} + ), + content_type='application/json' + ) + + assert response.status_code == 200 + assert response_json(response) == [{ + 'id': 1, + 'payload': { 'data': {'test': "Hello Dolly"} }, + 'status': 200, + }] + + +@pytest.mark.parametrize('app', [create_app(batch=True)]) +def test_batch_allows_post_with_operation_name(client): + response = client.post( + url_string(), + data=jl( + id=1, + query=''' + query helloYou { test(who: "You"), ...shared } + query helloWorld { test(who: "World"), ...shared } + query helloDolly { test(who: "Dolly"), ...shared } + fragment shared on QueryRoot { + shared: test(who: "Everyone") + } + ''', + operationName='helloWorld' + ), + content_type='application/json' + ) + + assert response.status_code == 200 + assert response_json(response) == [{ + 'id': 1, + 'payload': { + 'data': { + 'test': 'Hello World', + 'shared': 'Hello Everyone' + } + }, + 'status': 200, + }]