diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index 3f9dda9f..2f7743da 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -55,6 +55,9 @@ jobs: PYTEST_ADDOPTS: "--color=yes" run: poetry run pytest + - name: Static type check + run: poetry run mypy + - name: Upload coverage uses: codecov/codecov-action@v1 diff --git a/openapi_core/casting/schemas/casters.py b/openapi_core/casting/schemas/casters.py index f6e912b9..14794067 100644 --- a/openapi_core/casting/schemas/casters.py +++ b/openapi_core/casting/schemas/casters.py @@ -1,26 +1,36 @@ +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import List + +from openapi_core.casting.schemas.datatypes import CasterCallable from openapi_core.casting.schemas.exceptions import CastError +from openapi_core.spec import Spec + +if TYPE_CHECKING: + from openapi_core.casting.schemas.factories import SchemaCastersFactory class BaseSchemaCaster: - def __init__(self, schema): + def __init__(self, schema: Spec): self.schema = schema - def __call__(self, value): + def __call__(self, value: Any) -> Any: if value is None: return value return self.cast(value) - def cast(self, value): + def cast(self, value: Any) -> Any: raise NotImplementedError class CallableSchemaCaster(BaseSchemaCaster): - def __init__(self, schema, caster_callable): + def __init__(self, schema: Spec, caster_callable: CasterCallable): super().__init__(schema) self.caster_callable = caster_callable - def cast(self, value): + def cast(self, value: Any) -> Any: try: return self.caster_callable(value) except (ValueError, TypeError): @@ -28,22 +38,22 @@ def cast(self, value): class DummyCaster(BaseSchemaCaster): - def cast(self, value): + def cast(self, value: Any) -> Any: return value class ComplexCaster(BaseSchemaCaster): - def __init__(self, schema, casters_factory): + def __init__(self, schema: Spec, casters_factory: "SchemaCastersFactory"): super().__init__(schema) self.casters_factory = casters_factory class ArrayCaster(ComplexCaster): @property - def items_caster(self): + def items_caster(self) -> BaseSchemaCaster: return self.casters_factory.create(self.schema / "items") - def cast(self, value): + def cast(self, value: Any) -> List[Any]: try: return list(map(self.items_caster, value)) except (ValueError, TypeError): diff --git a/openapi_core/casting/schemas/datatypes.py b/openapi_core/casting/schemas/datatypes.py new file mode 100644 index 00000000..1014bf63 --- /dev/null +++ b/openapi_core/casting/schemas/datatypes.py @@ -0,0 +1,4 @@ +from typing import Any +from typing import Callable + +CasterCallable = Callable[[Any], Any] diff --git a/openapi_core/casting/schemas/exceptions.py b/openapi_core/casting/schemas/exceptions.py index 1f3f8bc4..0c4d25b1 100644 --- a/openapi_core/casting/schemas/exceptions.py +++ b/openapi_core/casting/schemas/exceptions.py @@ -10,5 +10,5 @@ class CastError(OpenAPIError): value: str type: str - def __str__(self): + def __str__(self) -> str: return f"Failed to cast value to {self.type} type: {self.value}" diff --git a/openapi_core/casting/schemas/factories.py b/openapi_core/casting/schemas/factories.py index 3c9b0f21..e0ccfebb 100644 --- a/openapi_core/casting/schemas/factories.py +++ b/openapi_core/casting/schemas/factories.py @@ -1,6 +1,11 @@ +from typing import Dict + from openapi_core.casting.schemas.casters import ArrayCaster +from openapi_core.casting.schemas.casters import BaseSchemaCaster from openapi_core.casting.schemas.casters import CallableSchemaCaster from openapi_core.casting.schemas.casters import DummyCaster +from openapi_core.casting.schemas.datatypes import CasterCallable +from openapi_core.spec import Spec from openapi_core.util import forcebool @@ -11,7 +16,7 @@ class SchemaCastersFactory: "object", "any", ] - PRIMITIVE_CASTERS = { + PRIMITIVE_CASTERS: Dict[str, CasterCallable] = { "integer": int, "number": float, "boolean": forcebool, @@ -20,7 +25,7 @@ class SchemaCastersFactory: "array": ArrayCaster, } - def create(self, schema): + def create(self, schema: Spec) -> BaseSchemaCaster: schema_type = schema.getkey("type", "any") if schema_type in self.DUMMY_CASTERS: diff --git a/openapi_core/contrib/django/handlers.py b/openapi_core/contrib/django/handlers.py index 6d20c340..05bbb742 100644 --- a/openapi_core/contrib/django/handlers.py +++ b/openapi_core/contrib/django/handlers.py @@ -1,5 +1,13 @@ """OpenAPI core contrib django handlers module""" +from typing import Any +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import Type + from django.http import JsonResponse +from django.http.request import HttpRequest +from django.http.response import HttpResponse from openapi_core.templating.media_types.exceptions import MediaTypeNotFound from openapi_core.templating.paths.exceptions import OperationNotFound @@ -11,7 +19,7 @@ class DjangoOpenAPIErrorsHandler: - OPENAPI_ERROR_STATUS = { + OPENAPI_ERROR_STATUS: Dict[Type[Exception], int] = { MissingRequiredParameter: 400, ServerNotFound: 400, InvalidSecurity: 403, @@ -21,7 +29,12 @@ class DjangoOpenAPIErrorsHandler: } @classmethod - def handle(cls, errors, req, resp=None): + def handle( + cls, + errors: Iterable[Exception], + req: HttpRequest, + resp: Optional[HttpResponse] = None, + ) -> JsonResponse: data_errors = [cls.format_openapi_error(err) for err in errors] data = { "errors": data_errors, @@ -30,7 +43,7 @@ def handle(cls, errors, req, resp=None): return JsonResponse(data, status=data_error_max["status"]) @classmethod - def format_openapi_error(cls, error): + def format_openapi_error(cls, error: Exception) -> Dict[str, Any]: return { "title": str(error), "status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400), @@ -38,5 +51,5 @@ def format_openapi_error(cls, error): } @classmethod - def get_error_status(cls, error): - return error["status"] + def get_error_status(cls, error: Dict[str, Any]) -> str: + return str(error["status"]) diff --git a/openapi_core/contrib/django/middlewares.py b/openapi_core/contrib/django/middlewares.py index 08de5f71..570b7632 100644 --- a/openapi_core/contrib/django/middlewares.py +++ b/openapi_core/contrib/django/middlewares.py @@ -1,13 +1,22 @@ """OpenAPI core contrib django middlewares module""" +from typing import Callable + from django.conf import settings from django.core.exceptions import ImproperlyConfigured +from django.http import JsonResponse +from django.http.request import HttpRequest +from django.http.response import HttpResponse from openapi_core.contrib.django.handlers import DjangoOpenAPIErrorsHandler from openapi_core.contrib.django.requests import DjangoOpenAPIRequest from openapi_core.contrib.django.responses import DjangoOpenAPIResponse from openapi_core.validation.processors import OpenAPIProcessor from openapi_core.validation.request import openapi_request_validator +from openapi_core.validation.request.datatypes import RequestValidationResult +from openapi_core.validation.request.protocols import Request from openapi_core.validation.response import openapi_response_validator +from openapi_core.validation.response.datatypes import ResponseValidationResult +from openapi_core.validation.response.protocols import Response class DjangoOpenAPIMiddleware: @@ -16,7 +25,7 @@ class DjangoOpenAPIMiddleware: response_class = DjangoOpenAPIResponse errors_handler = DjangoOpenAPIErrorsHandler() - def __init__(self, get_response): + def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): self.get_response = get_response if not hasattr(settings, "OPENAPI_SPEC"): @@ -26,7 +35,7 @@ def __init__(self, get_response): openapi_request_validator, openapi_response_validator ) - def __call__(self, request): + def __call__(self, request: HttpRequest) -> HttpResponse: openapi_request = self._get_openapi_request(request) req_result = self.validation_processor.process_request( settings.OPENAPI_SPEC, openapi_request @@ -46,14 +55,25 @@ def __call__(self, request): return response - def _handle_request_errors(self, request_result, req): + def _handle_request_errors( + self, request_result: RequestValidationResult, req: HttpRequest + ) -> JsonResponse: return self.errors_handler.handle(request_result.errors, req, None) - def _handle_response_errors(self, response_result, req, resp): + def _handle_response_errors( + self, + response_result: ResponseValidationResult, + req: HttpRequest, + resp: HttpResponse, + ) -> JsonResponse: return self.errors_handler.handle(response_result.errors, req, resp) - def _get_openapi_request(self, request): + def _get_openapi_request( + self, request: HttpRequest + ) -> DjangoOpenAPIRequest: return self.request_class(request) - def _get_openapi_response(self, response): + def _get_openapi_response( + self, response: HttpResponse + ) -> DjangoOpenAPIResponse: return self.response_class(response) diff --git a/openapi_core/contrib/django/requests.py b/openapi_core/contrib/django/requests.py index be5bed87..b894063b 100644 --- a/openapi_core/contrib/django/requests.py +++ b/openapi_core/contrib/django/requests.py @@ -1,7 +1,8 @@ """OpenAPI core contrib django requests module""" import re -from urllib.parse import urljoin +from typing import Optional +from django.http.request import HttpRequest from werkzeug.datastructures import Headers from werkzeug.datastructures import ImmutableMultiDict @@ -24,28 +25,33 @@ class DjangoOpenAPIRequest: path_regex = re.compile(PATH_PARAMETER_PATTERN) - def __init__(self, request): + def __init__(self, request: HttpRequest): self.request = request - self.parameters = RequestParameters( - path=self.request.resolver_match + path = ( + self.request.resolver_match and self.request.resolver_match.kwargs - or {}, + or {} + ) + self.parameters = RequestParameters( + path=path, query=ImmutableMultiDict(self.request.GET), header=Headers(self.request.headers.items()), cookie=ImmutableMultiDict(dict(self.request.COOKIES)), ) @property - def host_url(self): + def host_url(self) -> str: + assert isinstance(self.request._current_scheme_host, str) return self.request._current_scheme_host @property - def path(self): + def path(self) -> str: + assert isinstance(self.request.path, str) return self.request.path @property - def path_pattern(self): + def path_pattern(self) -> Optional[str]: if self.request.resolver_match is None: return None @@ -58,13 +64,17 @@ def path_pattern(self): return "/" + route @property - def method(self): + def method(self) -> str: + if self.request.method is None: + return "" + assert isinstance(self.request.method, str) return self.request.method.lower() @property - def body(self): - return self.request.body + def body(self) -> str: + assert isinstance(self.request.body, bytes) + return self.request.body.decode("utf-8") @property - def mimetype(self): - return self.request.content_type + def mimetype(self) -> str: + return self.request.content_type or "" diff --git a/openapi_core/contrib/django/responses.py b/openapi_core/contrib/django/responses.py index 212fad2e..838eff06 100644 --- a/openapi_core/contrib/django/responses.py +++ b/openapi_core/contrib/django/responses.py @@ -1,23 +1,28 @@ """OpenAPI core contrib django responses module""" +from django.http.response import HttpResponse from werkzeug.datastructures import Headers class DjangoOpenAPIResponse: - def __init__(self, response): + def __init__(self, response: HttpResponse): self.response = response @property - def data(self): - return self.response.content + def data(self) -> str: + assert isinstance(self.response.content, bytes) + return self.response.content.decode("utf-8") @property - def status_code(self): + def status_code(self) -> int: + assert isinstance(self.response.status_code, int) return self.response.status_code @property - def headers(self): + def headers(self) -> Headers: return Headers(self.response.headers.items()) @property - def mimetype(self): - return self.response["Content-Type"] + def mimetype(self) -> str: + content_type = self.response.get("Content-Type", "") + assert isinstance(content_type, str) + return content_type diff --git a/openapi_core/contrib/falcon/handlers.py b/openapi_core/contrib/falcon/handlers.py index 77d2e63f..6bd59f25 100644 --- a/openapi_core/contrib/falcon/handlers.py +++ b/openapi_core/contrib/falcon/handlers.py @@ -1,8 +1,14 @@ """OpenAPI core contrib falcon handlers module""" from json import dumps +from typing import Any +from typing import Dict +from typing import Iterable +from typing import Type from falcon import status_codes from falcon.constants import MEDIA_JSON +from falcon.request import Request +from falcon.response import Response from openapi_core.templating.media_types.exceptions import MediaTypeNotFound from openapi_core.templating.paths.exceptions import OperationNotFound @@ -14,7 +20,7 @@ class FalconOpenAPIErrorsHandler: - OPENAPI_ERROR_STATUS = { + OPENAPI_ERROR_STATUS: Dict[Type[Exception], int] = { MissingRequiredParameter: 400, ServerNotFound: 400, InvalidSecurity: 403, @@ -24,7 +30,9 @@ class FalconOpenAPIErrorsHandler: } @classmethod - def handle(cls, req, resp, errors): + def handle( + cls, req: Request, resp: Response, errors: Iterable[Exception] + ) -> None: data_errors = [cls.format_openapi_error(err) for err in errors] data = { "errors": data_errors, @@ -41,7 +49,7 @@ def handle(cls, req, resp, errors): resp.complete = True @classmethod - def format_openapi_error(cls, error): + def format_openapi_error(cls, error: Exception) -> Dict[str, Any]: return { "title": str(error), "status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400), @@ -49,5 +57,5 @@ def format_openapi_error(cls, error): } @classmethod - def get_error_status(cls, error): - return error["status"] + def get_error_status(cls, error: Dict[str, Any]) -> int: + return int(error["status"]) diff --git a/openapi_core/contrib/falcon/middlewares.py b/openapi_core/contrib/falcon/middlewares.py index eac38a24..c2d509f7 100644 --- a/openapi_core/contrib/falcon/middlewares.py +++ b/openapi_core/contrib/falcon/middlewares.py @@ -1,11 +1,20 @@ """OpenAPI core contrib falcon middlewares module""" +from typing import Any +from typing import Optional +from typing import Type + +from falcon.request import Request +from falcon.response import Response from openapi_core.contrib.falcon.handlers import FalconOpenAPIErrorsHandler from openapi_core.contrib.falcon.requests import FalconOpenAPIRequest from openapi_core.contrib.falcon.responses import FalconOpenAPIResponse +from openapi_core.spec import Spec from openapi_core.validation.processors import OpenAPIProcessor from openapi_core.validation.request import openapi_request_validator +from openapi_core.validation.request.datatypes import RequestValidationResult from openapi_core.validation.response import openapi_response_validator +from openapi_core.validation.response.datatypes import ResponseValidationResult class FalconOpenAPIMiddleware: @@ -16,11 +25,11 @@ class FalconOpenAPIMiddleware: def __init__( self, - spec, - validation_processor, - request_class=None, - response_class=None, - errors_handler=None, + spec: Spec, + validation_processor: OpenAPIProcessor, + request_class: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, + response_class: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, + errors_handler: Optional[FalconOpenAPIErrorsHandler] = None, ): self.spec = spec self.validation_processor = validation_processor @@ -31,11 +40,11 @@ def __init__( @classmethod def from_spec( cls, - spec, - request_class=None, - response_class=None, - errors_handler=None, - ): + spec: Spec, + request_class: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, + response_class: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, + errors_handler: Optional[FalconOpenAPIErrorsHandler] = None, + ) -> "FalconOpenAPIMiddleware": validation_processor = OpenAPIProcessor( openapi_request_validator, openapi_response_validator ) @@ -47,13 +56,15 @@ def from_spec( errors_handler=errors_handler, ) - def process_request(self, req, resp): + def process_request(self, req: Request, resp: Response) -> None: openapi_req = self._get_openapi_request(req) req.context.openapi = self._process_openapi_request(openapi_req) if req.context.openapi.errors: return self._handle_request_errors(req, resp, req.context.openapi) - def process_response(self, req, resp, resource, req_succeeded): + def process_response( + self, req: Request, resp: Response, resource: Any, req_succeeded: bool + ) -> None: openapi_req = self._get_openapi_request(req) openapi_resp = self._get_openapi_response(resp) resp.context.openapi = self._process_openapi_response( @@ -64,24 +75,42 @@ def process_response(self, req, resp, resource, req_succeeded): req, resp, resp.context.openapi ) - def _handle_request_errors(self, req, resp, request_result): + def _handle_request_errors( + self, + req: Request, + resp: Response, + request_result: RequestValidationResult, + ) -> None: return self.errors_handler.handle(req, resp, request_result.errors) - def _handle_response_errors(self, req, resp, response_result): + def _handle_response_errors( + self, + req: Request, + resp: Response, + response_result: ResponseValidationResult, + ) -> None: return self.errors_handler.handle(req, resp, response_result.errors) - def _get_openapi_request(self, request): + def _get_openapi_request(self, request: Request) -> FalconOpenAPIRequest: return self.request_class(request) - def _get_openapi_response(self, response): + def _get_openapi_response( + self, response: Response + ) -> FalconOpenAPIResponse: return self.response_class(response) - def _process_openapi_request(self, openapi_request): + def _process_openapi_request( + self, openapi_request: FalconOpenAPIRequest + ) -> RequestValidationResult: return self.validation_processor.process_request( self.spec, openapi_request ) - def _process_openapi_response(self, opneapi_request, openapi_response): + def _process_openapi_response( + self, + opneapi_request: FalconOpenAPIRequest, + openapi_response: FalconOpenAPIResponse, + ) -> ResponseValidationResult: return self.validation_processor.process_response( self.spec, opneapi_request, openapi_response ) diff --git a/openapi_core/contrib/falcon/requests.py b/openapi_core/contrib/falcon/requests.py index 28833c95..c078e8bf 100644 --- a/openapi_core/contrib/falcon/requests.py +++ b/openapi_core/contrib/falcon/requests.py @@ -1,6 +1,11 @@ """OpenAPI core contrib falcon responses module""" from json import dumps +from typing import Any +from typing import Dict +from typing import Optional +from falcon.request import Request +from falcon.request import RequestOptions from werkzeug.datastructures import Headers from werkzeug.datastructures import ImmutableMultiDict @@ -8,7 +13,11 @@ class FalconOpenAPIRequest: - def __init__(self, request, default_when_empty=None): + def __init__( + self, + request: Request, + default_when_empty: Optional[Dict[Any, Any]] = None, + ): self.request = request if default_when_empty is None: default_when_empty = {} @@ -22,19 +31,22 @@ def __init__(self, request, default_when_empty=None): ) @property - def host_url(self): + def host_url(self) -> str: + assert isinstance(self.request.prefix, str) return self.request.prefix @property - def path(self): + def path(self) -> str: + assert isinstance(self.request.path, str) return self.request.path @property - def method(self): + def method(self) -> str: + assert isinstance(self.request.method, str) return self.request.method.lower() @property - def body(self): + def body(self) -> Optional[str]: media = self.request.get_media( default_when_empty=self.default_when_empty ) @@ -42,7 +54,11 @@ def body(self): return dumps(getattr(self.request, "json", media)) @property - def mimetype(self): + def mimetype(self) -> str: if self.request.content_type: + assert isinstance(self.request.content_type, str) return self.request.content_type.partition(";")[0] + + assert isinstance(self.request.options, RequestOptions) + assert isinstance(self.request.options.default_media_type, str) return self.request.options.default_media_type diff --git a/openapi_core/contrib/falcon/responses.py b/openapi_core/contrib/falcon/responses.py index 18374b80..efeb6d3c 100644 --- a/openapi_core/contrib/falcon/responses.py +++ b/openapi_core/contrib/falcon/responses.py @@ -1,21 +1,23 @@ """OpenAPI core contrib falcon responses module""" +from falcon.response import Response from werkzeug.datastructures import Headers class FalconOpenAPIResponse: - def __init__(self, response): + def __init__(self, response: Response): self.response = response @property - def data(self): + def data(self) -> str: + assert isinstance(self.response.text, str) return self.response.text @property - def status_code(self): + def status_code(self) -> int: return int(self.response.status[:3]) @property - def mimetype(self): + def mimetype(self) -> str: mimetype = "" if self.response.content_type: mimetype = self.response.content_type.partition(";")[0] @@ -24,5 +26,5 @@ def mimetype(self): return mimetype @property - def headers(self): + def headers(self) -> Headers: return Headers(self.response.headers) diff --git a/openapi_core/contrib/flask/decorators.py b/openapi_core/contrib/flask/decorators.py index 45025808..b30f41d8 100644 --- a/openapi_core/contrib/flask/decorators.py +++ b/openapi_core/contrib/flask/decorators.py @@ -1,50 +1,111 @@ """OpenAPI core contrib flask decorators module""" +from functools import wraps +from typing import Any +from typing import Callable +from typing import Type + +from flask.globals import request +from flask.wrappers import Request +from flask.wrappers import Response + from openapi_core.contrib.flask.handlers import FlaskOpenAPIErrorsHandler from openapi_core.contrib.flask.providers import FlaskRequestProvider from openapi_core.contrib.flask.requests import FlaskOpenAPIRequest from openapi_core.contrib.flask.responses import FlaskOpenAPIResponse -from openapi_core.validation.decorators import OpenAPIDecorator +from openapi_core.spec import Spec +from openapi_core.validation.processors import OpenAPIProcessor from openapi_core.validation.request import openapi_request_validator +from openapi_core.validation.request.datatypes import RequestValidationResult +from openapi_core.validation.request.validators import RequestValidator from openapi_core.validation.response import openapi_response_validator +from openapi_core.validation.response.datatypes import ResponseValidationResult +from openapi_core.validation.response.validators import ResponseValidator -class FlaskOpenAPIViewDecorator(OpenAPIDecorator): +class FlaskOpenAPIViewDecorator(OpenAPIProcessor): def __init__( self, - spec, - request_validator, - response_validator, - request_class=FlaskOpenAPIRequest, - response_class=FlaskOpenAPIResponse, - request_provider=FlaskRequestProvider, - openapi_errors_handler=FlaskOpenAPIErrorsHandler, + spec: Spec, + request_validator: RequestValidator, + response_validator: ResponseValidator, + request_class: Type[FlaskOpenAPIRequest] = FlaskOpenAPIRequest, + response_class: Type[FlaskOpenAPIResponse] = FlaskOpenAPIResponse, + request_provider: Type[FlaskRequestProvider] = FlaskRequestProvider, + openapi_errors_handler: Type[ + FlaskOpenAPIErrorsHandler + ] = FlaskOpenAPIErrorsHandler, ): - super().__init__( - spec, - request_validator, - response_validator, - request_class, - response_class, - request_provider, - openapi_errors_handler, - ) + super().__init__(request_validator, response_validator) + self.spec = spec + self.request_class = request_class + self.response_class = response_class + self.request_provider = request_provider + self.openapi_errors_handler = openapi_errors_handler - def _handle_request_view(self, request_result, view, *args, **kwargs): - request = self._get_request(*args, **kwargs) - request.openapi = request_result - return super()._handle_request_view( - request_result, view, *args, **kwargs - ) + def __call__(self, view: Callable[..., Any]) -> Callable[..., Any]: + @wraps(view) + def decorated(*args: Any, **kwargs: Any) -> Response: + request = self._get_request() + openapi_request = self._get_openapi_request(request) + request_result = self.process_request(self.spec, openapi_request) + if request_result.errors: + return self._handle_request_errors(request_result) + response = self._handle_request_view( + request_result, view, *args, **kwargs + ) + openapi_response = self._get_openapi_response(response) + response_result = self.process_response( + self.spec, openapi_request, openapi_response + ) + if response_result.errors: + return self._handle_response_errors(response_result) + return response + + return decorated + + def _handle_request_view( + self, + request_result: RequestValidationResult, + view: Callable[[Any], Response], + *args: Any, + **kwargs: Any + ) -> Response: + request = self._get_request() + request.openapi = request_result # type: ignore + return view(*args, **kwargs) + + def _handle_request_errors( + self, request_result: RequestValidationResult + ) -> Response: + return self.openapi_errors_handler.handle(request_result.errors) + + def _handle_response_errors( + self, response_result: ResponseValidationResult + ) -> Response: + return self.openapi_errors_handler.handle(response_result.errors) + + def _get_request(self) -> Request: + return request + + def _get_openapi_request(self, request: Request) -> FlaskOpenAPIRequest: + return self.request_class(request) + + def _get_openapi_response( + self, response: Response + ) -> FlaskOpenAPIResponse: + return self.response_class(response) @classmethod def from_spec( cls, - spec, - request_class=FlaskOpenAPIRequest, - response_class=FlaskOpenAPIResponse, - request_provider=FlaskRequestProvider, - openapi_errors_handler=FlaskOpenAPIErrorsHandler, - ): + spec: Spec, + request_class: Type[FlaskOpenAPIRequest] = FlaskOpenAPIRequest, + response_class: Type[FlaskOpenAPIResponse] = FlaskOpenAPIResponse, + request_provider: Type[FlaskRequestProvider] = FlaskRequestProvider, + openapi_errors_handler: Type[ + FlaskOpenAPIErrorsHandler + ] = FlaskOpenAPIErrorsHandler, + ) -> "FlaskOpenAPIViewDecorator": return cls( spec, request_validator=openapi_request_validator, diff --git a/openapi_core/contrib/flask/handlers.py b/openapi_core/contrib/flask/handlers.py index 1f15d2be..02befc3f 100644 --- a/openapi_core/contrib/flask/handlers.py +++ b/openapi_core/contrib/flask/handlers.py @@ -1,6 +1,12 @@ """OpenAPI core contrib flask handlers module""" +from typing import Any +from typing import Dict +from typing import Iterable +from typing import Type + from flask.globals import current_app from flask.json import dumps +from flask.wrappers import Response from openapi_core.templating.media_types.exceptions import MediaTypeNotFound from openapi_core.templating.paths.exceptions import OperationNotFound @@ -10,7 +16,7 @@ class FlaskOpenAPIErrorsHandler: - OPENAPI_ERROR_STATUS = { + OPENAPI_ERROR_STATUS: Dict[Type[Exception], int] = { ServerNotFound: 400, OperationNotFound: 405, PathNotFound: 404, @@ -18,7 +24,7 @@ class FlaskOpenAPIErrorsHandler: } @classmethod - def handle(cls, errors): + def handle(cls, errors: Iterable[Exception]) -> Response: data_errors = [cls.format_openapi_error(err) for err in errors] data = { "errors": data_errors, @@ -30,7 +36,7 @@ def handle(cls, errors): ) @classmethod - def format_openapi_error(cls, error): + def format_openapi_error(cls, error: Exception) -> Dict[str, Any]: return { "title": str(error), "status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400), @@ -38,5 +44,5 @@ def format_openapi_error(cls, error): } @classmethod - def get_error_status(cls, error): - return error["status"] + def get_error_status(cls, error: Dict[str, Any]) -> int: + return int(error["status"]) diff --git a/openapi_core/contrib/flask/providers.py b/openapi_core/contrib/flask/providers.py index f45784ad..47729d25 100644 --- a/openapi_core/contrib/flask/providers.py +++ b/openapi_core/contrib/flask/providers.py @@ -1,8 +1,11 @@ """OpenAPI core contrib flask providers module""" +from typing import Any + from flask.globals import request +from flask.wrappers import Request class FlaskRequestProvider: @classmethod - def provide(self, *args, **kwargs): + def provide(self, *args: Any, **kwargs: Any) -> Request: return request diff --git a/openapi_core/contrib/flask/requests.py b/openapi_core/contrib/flask/requests.py index b211bf66..7e04447e 100644 --- a/openapi_core/contrib/flask/requests.py +++ b/openapi_core/contrib/flask/requests.py @@ -1,7 +1,10 @@ """OpenAPI core contrib flask requests module""" import re +from typing import Optional +from flask.wrappers import Request from werkzeug.datastructures import Headers +from werkzeug.datastructures import ImmutableMultiDict from openapi_core.validation.request.datatypes import RequestParameters @@ -13,39 +16,39 @@ class FlaskOpenAPIRequest: path_regex = re.compile(PATH_PARAMETER_PATTERN) - def __init__(self, request): + def __init__(self, request: Request): self.request = request self.parameters = RequestParameters( - path=self.request.view_args, - query=self.request.args, + path=self.request.view_args or {}, + query=ImmutableMultiDict(self.request.args), header=Headers(self.request.headers), cookie=self.request.cookies, ) @property - def host_url(self): + def host_url(self) -> str: return self.request.host_url @property - def path(self): + def path(self) -> str: return self.request.path @property - def path_pattern(self): + def path_pattern(self) -> str: if self.request.url_rule is None: return self.request.path else: return self.path_regex.sub(r"{\1}", self.request.url_rule.rule) @property - def method(self): + def method(self) -> str: return self.request.method.lower() @property - def body(self): - return self.request.data + def body(self) -> Optional[str]: + return self.request.get_data(as_text=True) @property - def mimetype(self): + def mimetype(self) -> str: return self.request.mimetype diff --git a/openapi_core/contrib/flask/responses.py b/openapi_core/contrib/flask/responses.py index 4ea37137..27a03005 100644 --- a/openapi_core/contrib/flask/responses.py +++ b/openapi_core/contrib/flask/responses.py @@ -1,23 +1,24 @@ """OpenAPI core contrib flask responses module""" +from flask.wrappers import Response from werkzeug.datastructures import Headers class FlaskOpenAPIResponse: - def __init__(self, response): + def __init__(self, response: Response): self.response = response @property - def data(self): - return self.response.data + def data(self) -> str: + return self.response.get_data(as_text=True) @property - def status_code(self): + def status_code(self) -> int: return self.response._status_code @property - def mimetype(self): - return self.response.mimetype + def mimetype(self) -> str: + return str(self.response.mimetype) @property - def headers(self): + def headers(self) -> Headers: return Headers(self.response.headers) diff --git a/openapi_core/contrib/flask/views.py b/openapi_core/contrib/flask/views.py index 5bb58778..499a37ba 100644 --- a/openapi_core/contrib/flask/views.py +++ b/openapi_core/contrib/flask/views.py @@ -1,8 +1,11 @@ """OpenAPI core contrib flask views module""" +from typing import Any + from flask.views import MethodView from openapi_core.contrib.flask.decorators import FlaskOpenAPIViewDecorator from openapi_core.contrib.flask.handlers import FlaskOpenAPIErrorsHandler +from openapi_core.spec import Spec from openapi_core.validation.request import openapi_request_validator from openapi_core.validation.response import openapi_response_validator @@ -12,11 +15,11 @@ class FlaskOpenAPIView(MethodView): openapi_errors_handler = FlaskOpenAPIErrorsHandler - def __init__(self, spec): + def __init__(self, spec: Spec): super().__init__() self.spec = spec - def dispatch_request(self, *args, **kwargs): + def dispatch_request(self, *args: Any, **kwargs: Any) -> Any: decorator = FlaskOpenAPIViewDecorator( self.spec, request_validator=openapi_request_validator, diff --git a/openapi_core/contrib/requests/protocols.py b/openapi_core/contrib/requests/protocols.py new file mode 100644 index 00000000..043c5a28 --- /dev/null +++ b/openapi_core/contrib/requests/protocols.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Protocol + from typing_extensions import runtime_checkable +else: + try: + from typing import Protocol + from typing import runtime_checkable + except ImportError: + from typing_extensions import Protocol + from typing_extensions import runtime_checkable + +from requests.cookies import RequestsCookieJar + + +@runtime_checkable +class SupportsCookieJar(Protocol): + _cookies: RequestsCookieJar diff --git a/openapi_core/contrib/requests/requests.py b/openapi_core/contrib/requests/requests.py index af62a79a..57a9eafd 100644 --- a/openapi_core/contrib/requests/requests.py +++ b/openapi_core/contrib/requests/requests.py @@ -1,12 +1,16 @@ """OpenAPI core contrib requests requests module""" - +from typing import Optional +from typing import Union from urllib.parse import parse_qs from urllib.parse import urlparse +from requests import PreparedRequest from requests import Request +from requests.cookies import RequestsCookieJar from werkzeug.datastructures import Headers from werkzeug.datastructures import ImmutableMultiDict +from openapi_core.contrib.requests.protocols import SupportsCookieJar from openapi_core.validation.request.datatypes import RequestParameters @@ -18,45 +22,57 @@ class RequestsOpenAPIRequest: payload being sent """ - def __init__(self, request): + def __init__(self, request: Union[Request, PreparedRequest]): if isinstance(request, Request): request = request.prepare() self.request = request + if request.url is None: + raise RuntimeError("Request URL is missing") self._url_parsed = urlparse(request.url) cookie = {} - if self.request._cookies is not None: + if isinstance(self.request, SupportsCookieJar) and isinstance( + self.request._cookies, RequestsCookieJar + ): # cookies are stored in a cookiejar object cookie = self.request._cookies.get_dict() self.parameters = RequestParameters( query=ImmutableMultiDict(parse_qs(self._url_parsed.query)), header=Headers(dict(self.request.headers)), - cookie=cookie, + cookie=ImmutableMultiDict(cookie), ) @property - def host_url(self): + def host_url(self) -> str: return f"{self._url_parsed.scheme}://{self._url_parsed.netloc}" @property - def path(self): + def path(self) -> str: + assert isinstance(self._url_parsed.path, str) return self._url_parsed.path @property - def method(self): - return self.request.method.lower() + def method(self) -> str: + method = self.request.method + return method and method.lower() or "" @property - def body(self): + def body(self) -> Optional[str]: + if self.request.body is None: + return None + if isinstance(self.request.body, bytes): + return self.request.body.decode("utf-8") + assert isinstance(self.request.body, str) # TODO: figure out if request._body_position is relevant return self.request.body @property - def mimetype(self): + def mimetype(self) -> str: # Order matters because all python requests issued from a session # include Accept */* which does not necessarily match the content type - return self.request.headers.get( - "Content-Type" - ) or self.request.headers.get("Accept") + return str( + self.request.headers.get("Content-Type") + or self.request.headers.get("Accept") + ) diff --git a/openapi_core/contrib/requests/responses.py b/openapi_core/contrib/requests/responses.py index 05d68d6d..149012af 100644 --- a/openapi_core/contrib/requests/responses.py +++ b/openapi_core/contrib/requests/responses.py @@ -1,23 +1,25 @@ """OpenAPI core contrib requests responses module""" +from requests import Response from werkzeug.datastructures import Headers class RequestsOpenAPIResponse: - def __init__(self, response): + def __init__(self, response: Response): self.response = response @property - def data(self): - return self.response.content + def data(self) -> str: + assert isinstance(self.response.content, bytes) + return self.response.content.decode("utf-8") @property - def status_code(self): - return self.response.status_code + def status_code(self) -> int: + return int(self.response.status_code) @property - def mimetype(self): - return self.response.headers.get("Content-Type") + def mimetype(self) -> str: + return str(self.response.headers.get("Content-Type", "")) @property - def headers(self): + def headers(self) -> Headers: return Headers(dict(self.response.headers)) diff --git a/openapi_core/deserializing/media_types/datatypes.py b/openapi_core/deserializing/media_types/datatypes.py new file mode 100644 index 00000000..3d45ab69 --- /dev/null +++ b/openapi_core/deserializing/media_types/datatypes.py @@ -0,0 +1,4 @@ +from typing import Any +from typing import Callable + +DeserializerCallable = Callable[[Any], Any] diff --git a/openapi_core/deserializing/media_types/deserializers.py b/openapi_core/deserializing/media_types/deserializers.py index 2d62cfcd..bac900d4 100644 --- a/openapi_core/deserializing/media_types/deserializers.py +++ b/openapi_core/deserializing/media_types/deserializers.py @@ -1,30 +1,37 @@ import warnings +from typing import Any +from typing import Callable +from openapi_core.deserializing.media_types.datatypes import ( + DeserializerCallable, +) from openapi_core.deserializing.media_types.exceptions import ( MediaTypeDeserializeError, ) class BaseMediaTypeDeserializer: - def __init__(self, mimetype): + def __init__(self, mimetype: str): self.mimetype = mimetype - def __call__(self, value): + def __call__(self, value: Any) -> Any: raise NotImplementedError class UnsupportedMimetypeDeserializer(BaseMediaTypeDeserializer): - def __call__(self, value): + def __call__(self, value: Any) -> Any: warnings.warn(f"Unsupported {self.mimetype} mimetype") return value class CallableMediaTypeDeserializer(BaseMediaTypeDeserializer): - def __init__(self, mimetype, deserializer_callable): + def __init__( + self, mimetype: str, deserializer_callable: DeserializerCallable + ): self.mimetype = mimetype self.deserializer_callable = deserializer_callable - def __call__(self, value): + def __call__(self, value: Any) -> Any: try: return self.deserializer_callable(value) except (ValueError, TypeError, AttributeError): diff --git a/openapi_core/deserializing/media_types/exceptions.py b/openapi_core/deserializing/media_types/exceptions.py index 87def336..66dd904d 100644 --- a/openapi_core/deserializing/media_types/exceptions.py +++ b/openapi_core/deserializing/media_types/exceptions.py @@ -10,7 +10,7 @@ class MediaTypeDeserializeError(DeserializeError): mimetype: str value: str - def __str__(self): + def __str__(self) -> str: return ( "Failed to deserialize value with {mimetype} mimetype: {value}" ).format(value=self.value, mimetype=self.mimetype) diff --git a/openapi_core/deserializing/media_types/factories.py b/openapi_core/deserializing/media_types/factories.py index 3b0aa547..208976fd 100644 --- a/openapi_core/deserializing/media_types/factories.py +++ b/openapi_core/deserializing/media_types/factories.py @@ -1,5 +1,15 @@ from json import loads +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from openapi_core.deserializing.media_types.datatypes import ( + DeserializerCallable, +) +from openapi_core.deserializing.media_types.deserializers import ( + BaseMediaTypeDeserializer, +) from openapi_core.deserializing.media_types.deserializers import ( CallableMediaTypeDeserializer, ) @@ -12,18 +22,21 @@ class MediaTypeDeserializersFactory: - MEDIA_TYPE_DESERIALIZERS = { + MEDIA_TYPE_DESERIALIZERS: Dict[str, DeserializerCallable] = { "application/json": loads, "application/x-www-form-urlencoded": urlencoded_form_loads, "multipart/form-data": data_form_loads, } - def __init__(self, custom_deserializers=None): + def __init__( + self, + custom_deserializers: Optional[Dict[str, DeserializerCallable]] = None, + ): if custom_deserializers is None: custom_deserializers = {} self.custom_deserializers = custom_deserializers - def create(self, mimetype): + def create(self, mimetype: str) -> BaseMediaTypeDeserializer: deserialize_callable = self.get_deserializer_callable(mimetype) if deserialize_callable is None: @@ -31,7 +44,9 @@ def create(self, mimetype): return CallableMediaTypeDeserializer(mimetype, deserialize_callable) - def get_deserializer_callable(self, mimetype): + def get_deserializer_callable( + self, mimetype: str + ) -> Optional[DeserializerCallable]: if mimetype in self.custom_deserializers: return self.custom_deserializers[mimetype] return self.MEDIA_TYPE_DESERIALIZERS.get(mimetype) diff --git a/openapi_core/deserializing/media_types/util.py b/openapi_core/deserializing/media_types/util.py index 22d9f345..4179cad0 100644 --- a/openapi_core/deserializing/media_types/util.py +++ b/openapi_core/deserializing/media_types/util.py @@ -1,13 +1,16 @@ from email.parser import Parser +from typing import Any +from typing import Dict +from typing import Union from urllib.parse import parse_qsl -def urlencoded_form_loads(value): +def urlencoded_form_loads(value: Any) -> Dict[str, Any]: return dict(parse_qsl(value)) -def data_form_loads(value): - if issubclass(type(value), bytes): +def data_form_loads(value: Union[str, bytes]) -> Dict[str, Any]: + if isinstance(value, bytes): value = value.decode("ASCII", errors="surrogateescape") parser = Parser() parts = parser.parsestr(value, headersonly=False) diff --git a/openapi_core/deserializing/parameters/datatypes.py b/openapi_core/deserializing/parameters/datatypes.py new file mode 100644 index 00000000..f2a47c29 --- /dev/null +++ b/openapi_core/deserializing/parameters/datatypes.py @@ -0,0 +1,4 @@ +from typing import Callable +from typing import List + +DeserializerCallable = Callable[[str], List[str]] diff --git a/openapi_core/deserializing/parameters/deserializers.py b/openapi_core/deserializing/parameters/deserializers.py index 9565d02d..22906c0e 100644 --- a/openapi_core/deserializing/parameters/deserializers.py +++ b/openapi_core/deserializing/parameters/deserializers.py @@ -1,37 +1,49 @@ import warnings +from typing import Any +from typing import Callable +from typing import List from openapi_core.deserializing.exceptions import DeserializeError +from openapi_core.deserializing.parameters.datatypes import ( + DeserializerCallable, +) from openapi_core.deserializing.parameters.exceptions import ( EmptyQueryParameterValue, ) from openapi_core.schema.parameters import get_aslist from openapi_core.schema.parameters import get_explode +from openapi_core.spec import Spec class BaseParameterDeserializer: - def __init__(self, param_or_header, style): + def __init__(self, param_or_header: Spec, style: str): self.param_or_header = param_or_header self.style = style - def __call__(self, value): + def __call__(self, value: Any) -> Any: raise NotImplementedError class UnsupportedStyleDeserializer(BaseParameterDeserializer): - def __call__(self, value): + def __call__(self, value: Any) -> Any: warnings.warn(f"Unsupported {self.style} style") return value class CallableParameterDeserializer(BaseParameterDeserializer): - def __init__(self, param_or_header, style, deserializer_callable): + def __init__( + self, + param_or_header: Spec, + style: str, + deserializer_callable: DeserializerCallable, + ): super().__init__(param_or_header, style) self.deserializer_callable = deserializer_callable self.aslist = get_aslist(self.param_or_header) self.explode = get_explode(self.param_or_header) - def __call__(self, value): + def __call__(self, value: Any) -> Any: # if "in" not defined then it's a Header if "allowEmptyValue" in self.param_or_header: warnings.warn( diff --git a/openapi_core/deserializing/parameters/exceptions.py b/openapi_core/deserializing/parameters/exceptions.py index 64dbe910..146d60a1 100644 --- a/openapi_core/deserializing/parameters/exceptions.py +++ b/openapi_core/deserializing/parameters/exceptions.py @@ -17,7 +17,7 @@ class ParameterDeserializeError(BaseParameterDeserializeError): style: str value: str - def __str__(self): + def __str__(self) -> str: return ( "Failed to deserialize value of " f"{self.location} parameter with style {self.style}: {self.value}" @@ -28,11 +28,11 @@ def __str__(self): class EmptyQueryParameterValue(BaseParameterDeserializeError): name: str - def __init__(self, name): + def __init__(self, name: str): super().__init__(location="query") self.name = name - def __str__(self): + def __str__(self) -> str: return ( f"Value of {self.name} {self.location} parameter cannot be empty" ) diff --git a/openapi_core/deserializing/parameters/factories.py b/openapi_core/deserializing/parameters/factories.py index f72825b2..f937446f 100644 --- a/openapi_core/deserializing/parameters/factories.py +++ b/openapi_core/deserializing/parameters/factories.py @@ -1,5 +1,12 @@ from functools import partial +from typing import Dict +from openapi_core.deserializing.parameters.datatypes import ( + DeserializerCallable, +) +from openapi_core.deserializing.parameters.deserializers import ( + BaseParameterDeserializer, +) from openapi_core.deserializing.parameters.deserializers import ( CallableParameterDeserializer, ) @@ -8,18 +15,19 @@ ) from openapi_core.deserializing.parameters.util import split from openapi_core.schema.parameters import get_style +from openapi_core.spec import Spec class ParameterDeserializersFactory: - PARAMETER_STYLE_DESERIALIZERS = { + PARAMETER_STYLE_DESERIALIZERS: Dict[str, DeserializerCallable] = { "form": partial(split, separator=","), "simple": partial(split, separator=","), "spaceDelimited": partial(split, separator=" "), "pipeDelimited": partial(split, separator="|"), } - def create(self, param_or_header): + def create(self, param_or_header: Spec) -> BaseParameterDeserializer: style = get_style(param_or_header) if style not in self.PARAMETER_STYLE_DESERIALIZERS: diff --git a/openapi_core/deserializing/parameters/util.py b/openapi_core/deserializing/parameters/util.py index e9cc4db0..1f484f21 100644 --- a/openapi_core/deserializing/parameters/util.py +++ b/openapi_core/deserializing/parameters/util.py @@ -1,2 +1,5 @@ -def split(value, separator=","): +from typing import List + + +def split(value: str, separator: str = ",") -> List[str]: return value.split(separator) diff --git a/openapi_core/extensions/models/factories.py b/openapi_core/extensions/models/factories.py index 1e66c128..af6074f1 100644 --- a/openapi_core/extensions/models/factories.py +++ b/openapi_core/extensions/models/factories.py @@ -1,4 +1,9 @@ """OpenAPI X-Model extension factories module""" +from typing import Any +from typing import Dict +from typing import Optional +from typing import Type + from openapi_core.extensions.models.models import Model @@ -6,19 +11,23 @@ class ModelClassFactory: base_class = Model - def create(self, name): + def create(self, name: str) -> Type[Model]: return type(name, (self.base_class,), {}) class ModelFactory: - def __init__(self, model_class_factory=None): + def __init__( + self, model_class_factory: Optional[ModelClassFactory] = None + ): self.model_class_factory = model_class_factory or ModelClassFactory() - def create(self, properties, name=None): + def create( + self, properties: Optional[Dict[str, Any]], name: Optional[str] = None + ) -> Model: name = name or "Model" model_class = self._create_class(name) return model_class(properties) - def _create_class(self, name): + def _create_class(self, name: str) -> Type[Model]: return self.model_class_factory.create(name) diff --git a/openapi_core/extensions/models/models.py b/openapi_core/extensions/models/models.py index a1080dd7..c27abf15 100644 --- a/openapi_core/extensions/models/models.py +++ b/openapi_core/extensions/models/models.py @@ -1,25 +1,28 @@ """OpenAPI X-Model extension models module""" +from typing import Any +from typing import Dict +from typing import Optional class BaseModel: """Base class for OpenAPI X-Model.""" @property - def __dict__(self): + def __dict__(self) -> Dict[Any, Any]: # type: ignore raise NotImplementedError class Model(BaseModel): """Model class for OpenAPI X-Model.""" - def __init__(self, properties=None): + def __init__(self, properties: Optional[Dict[str, Any]] = None): self.__properties = properties or {} @property - def __dict__(self): + def __dict__(self) -> Dict[Any, Any]: # type: ignore return self.__properties - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name not in self.__properties: raise AttributeError diff --git a/openapi_core/schema/parameters.py b/openapi_core/schema/parameters.py index c44dc2e3..30195c67 100644 --- a/openapi_core/schema/parameters.py +++ b/openapi_core/schema/parameters.py @@ -1,7 +1,16 @@ -from itertools import chain +from typing import Any +from typing import Dict +from typing import Optional +from typing import Union +from werkzeug.datastructures import Headers -def get_aslist(param_or_header): +from openapi_core.schema.protocols import SuportsGetAll +from openapi_core.schema.protocols import SuportsGetList +from openapi_core.spec import Spec + + +def get_aslist(param_or_header: Spec) -> bool: """Checks if parameter/header is described as list for simpler scenarios""" # if schema is not defined it's a complex scenario if "schema" not in param_or_header: @@ -13,9 +22,10 @@ def get_aslist(param_or_header): return schema_type in ["array", "object"] -def get_style(param_or_header): +def get_style(param_or_header: Spec) -> str: """Checks parameter/header style for simpler scenarios""" if "style" in param_or_header: + assert isinstance(param_or_header["style"], str) return param_or_header["style"] # if "in" not defined then it's a Header @@ -25,9 +35,10 @@ def get_style(param_or_header): return "simple" if location in ["path", "header"] else "form" -def get_explode(param_or_header): +def get_explode(param_or_header: Spec) -> bool: """Checks parameter/header explode for simpler scenarios""" if "explode" in param_or_header: + assert isinstance(param_or_header["explode"], bool) return param_or_header["explode"] # determine default @@ -35,7 +46,11 @@ def get_explode(param_or_header): return style == "form" -def get_value(param_or_header, location, name=None): +def get_value( + param_or_header: Spec, + location: Union[Headers, Dict[str, Any]], + name: Optional[str] = None, +) -> Any: """Returns parameter/header value from specific location""" name = name or param_or_header["name"] @@ -45,13 +60,9 @@ def get_value(param_or_header, location, name=None): aslist = get_aslist(param_or_header) explode = get_explode(param_or_header) if aslist and explode: - if hasattr(location, "getall"): + if isinstance(location, SuportsGetAll): return location.getall(name) - return location.getlist(name) + if isinstance(location, SuportsGetList): + return location.getlist(name) return location[name] - - -def iter_params(*lists): - iters = map(lambda l: l and iter(l) or [], lists) - return chain(*iters) diff --git a/openapi_core/schema/protocols.py b/openapi_core/schema/protocols.py new file mode 100644 index 00000000..a675db5c --- /dev/null +++ b/openapi_core/schema/protocols.py @@ -0,0 +1,26 @@ +from typing import TYPE_CHECKING +from typing import Any +from typing import List + +if TYPE_CHECKING: + from typing_extensions import Protocol + from typing_extensions import runtime_checkable +else: + try: + from typing import Protocol + from typing import runtime_checkable + except ImportError: + from typing_extensions import Protocol + from typing_extensions import runtime_checkable + + +@runtime_checkable +class SuportsGetAll(Protocol): + def getall(self, name: str) -> List[Any]: + ... + + +@runtime_checkable +class SuportsGetList(Protocol): + def getlist(self, name: str) -> List[Any]: + ... diff --git a/openapi_core/schema/schemas.py b/openapi_core/schema/schemas.py index a4f1bf1b..b7737374 100644 --- a/openapi_core/schema/schemas.py +++ b/openapi_core/schema/schemas.py @@ -1,4 +1,11 @@ -def get_all_properties(schema): +from typing import Any +from typing import Dict +from typing import Set + +from openapi_core.spec import Spec + + +def get_all_properties(schema: Spec) -> Dict[str, Any]: properties = schema.get("properties", {}) properties_dict = dict(list(properties.items())) @@ -12,6 +19,6 @@ def get_all_properties(schema): return properties_dict -def get_all_properties_names(schema): +def get_all_properties_names(schema: Spec) -> Set[str]: all_properties = get_all_properties(schema) return set(all_properties.keys()) diff --git a/openapi_core/schema/servers.py b/openapi_core/schema/servers.py index cabeabf4..e483f517 100644 --- a/openapi_core/schema/servers.py +++ b/openapi_core/schema/servers.py @@ -1,8 +1,14 @@ -def is_absolute(url): +from typing import Any +from typing import Dict + +from openapi_core.spec import Spec + + +def is_absolute(url: str) -> bool: return url.startswith("//") or "://" in url -def get_server_default_variables(server): +def get_server_default_variables(server: Spec) -> Dict[str, Any]: if "variables" not in server: return {} @@ -13,7 +19,8 @@ def get_server_default_variables(server): return defaults -def get_server_url(server, **variables): +def get_server_url(server: Spec, **variables: Any) -> str: if not variables: variables = get_server_default_variables(server) + assert isinstance(server["url"], str) return server["url"].format(**variables) diff --git a/openapi_core/schema/specs.py b/openapi_core/schema/specs.py index ab275734..5056a30d 100644 --- a/openapi_core/schema/specs.py +++ b/openapi_core/schema/specs.py @@ -1,6 +1,7 @@ from openapi_core.schema.servers import get_server_url +from openapi_core.spec import Spec -def get_spec_url(spec, index=0): +def get_spec_url(spec: Spec, index: int = 0) -> str: servers = spec / "servers" return get_server_url(servers / 0) diff --git a/openapi_core/security/factories.py b/openapi_core/security/factories.py index 65c1d91d..562f0c76 100644 --- a/openapi_core/security/factories.py +++ b/openapi_core/security/factories.py @@ -1,18 +1,24 @@ +from typing import Any +from typing import Dict +from typing import Type + from openapi_core.security.providers import ApiKeyProvider +from openapi_core.security.providers import BaseProvider from openapi_core.security.providers import HttpProvider from openapi_core.security.providers import UnsupportedProvider +from openapi_core.spec import Spec class SecurityProviderFactory: - PROVIDERS = { + PROVIDERS: Dict[str, Type[BaseProvider]] = { "apiKey": ApiKeyProvider, "http": HttpProvider, "oauth2": UnsupportedProvider, "openIdConnect": UnsupportedProvider, } - def create(self, scheme): + def create(self, scheme: Spec) -> Any: scheme_type = scheme["type"] provider_class = self.PROVIDERS[scheme_type] return provider_class(scheme) diff --git a/openapi_core/security/providers.py b/openapi_core/security/providers.py index 39403578..8ce79f7a 100644 --- a/openapi_core/security/providers.py +++ b/openapi_core/security/providers.py @@ -1,20 +1,26 @@ import warnings +from typing import Any from openapi_core.security.exceptions import SecurityError +from openapi_core.spec import Spec +from openapi_core.validation.request.protocols import Request class BaseProvider: - def __init__(self, scheme): + def __init__(self, scheme: Spec): self.scheme = scheme + def __call__(self, request: Request) -> Any: + raise NotImplementedError + class UnsupportedProvider(BaseProvider): - def __call__(self, request): + def __call__(self, request: Request) -> Any: warnings.warn("Unsupported scheme type") class ApiKeyProvider(BaseProvider): - def __call__(self, request): + def __call__(self, request: Request) -> Any: name = self.scheme["name"] location = self.scheme["in"] source = getattr(request.parameters, location) @@ -24,7 +30,7 @@ def __call__(self, request): class HttpProvider(BaseProvider): - def __call__(self, request): + def __call__(self, request: Request) -> Any: if "Authorization" not in request.parameters.header: raise SecurityError("Missing authorization header.") auth_header = request.parameters.header["Authorization"] diff --git a/openapi_core/spec/accessors.py b/openapi_core/spec/accessors.py index 034cf18a..9c8b7012 100644 --- a/openapi_core/spec/accessors.py +++ b/openapi_core/spec/accessors.py @@ -1,15 +1,26 @@ from contextlib import contextmanager +from typing import Any +from typing import Hashable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Union +from openapi_spec_validator.validators import Dereferencer from pathable.accessors import LookupAccessor class SpecAccessor(LookupAccessor): - def __init__(self, lookup, dereferencer): + def __init__( + self, lookup: Mapping[Hashable, Any], dereferencer: Dereferencer + ): super().__init__(lookup) self.dereferencer = dereferencer @contextmanager - def open(self, parts): + def open( + self, parts: List[Hashable] + ) -> Iterator[Union[Mapping[Hashable, Any], Any]]: content = self.lookup for part in parts: content = content[part] diff --git a/openapi_core/spec/paths.py b/openapi_core/spec/paths.py index 36b41f85..ea5ce28b 100644 --- a/openapi_core/spec/paths.py +++ b/openapi_core/spec/paths.py @@ -1,3 +1,9 @@ +from typing import Any +from typing import Dict +from typing import Hashable +from typing import Mapping + +from jsonschema.protocols import Validator from jsonschema.validators import RefResolver from openapi_spec_validator import default_handlers from openapi_spec_validator import openapi_v3_spec_validator @@ -13,12 +19,12 @@ class Spec(AccessorPath): @classmethod def from_dict( cls, - data, - *args, - url="", - ref_resolver_handlers=default_handlers, - separator=SPEC_SEPARATOR, - ): + data: Mapping[Hashable, Any], + *args: Any, + url: str = "", + ref_resolver_handlers: Dict[str, Any] = default_handlers, + separator: str = SPEC_SEPARATOR, + ) -> "Spec": ref_resolver = RefResolver(url, data, handlers=ref_resolver_handlers) dereferencer = Dereferencer(ref_resolver) accessor = SpecAccessor(data, dereferencer) @@ -27,13 +33,13 @@ def from_dict( @classmethod def create( cls, - data, - *args, - url="", - ref_resolver_handlers=default_handlers, - separator=SPEC_SEPARATOR, - validator=openapi_v3_spec_validator, - ): + data: Mapping[Hashable, Any], + *args: Any, + url: str = "", + ref_resolver_handlers: Dict[str, Any] = default_handlers, + separator: str = SPEC_SEPARATOR, + validator: Validator = openapi_v3_spec_validator, + ) -> "Spec": if validator is not None: validator.validate(data, spec_url=url) diff --git a/openapi_core/spec/shortcuts.py b/openapi_core/spec/shortcuts.py index 093c5ab3..aad0511e 100644 --- a/openapi_core/spec/shortcuts.py +++ b/openapi_core/spec/shortcuts.py @@ -1,21 +1,30 @@ """OpenAPI core spec shortcuts module""" +from typing import Any +from typing import Dict +from typing import Hashable +from typing import Mapping + from jsonschema.validators import RefResolver from openapi_spec_validator import default_handlers from openapi_spec_validator import openapi_v3_spec_validator from openapi_spec_validator.validators import Dereferencer -from openapi_core.spec.paths import SpecPath +from openapi_core.spec.paths import Spec def create_spec( - spec_dict, - spec_url="", - handlers=default_handlers, - validate_spec=True, -): + spec_dict: Mapping[Hashable, Any], + spec_url: str = "", + handlers: Dict[str, Any] = default_handlers, + validate_spec: bool = True, +) -> Spec: + validator = None if validate_spec: - openapi_v3_spec_validator.validate(spec_dict, spec_url=spec_url) + validator = openapi_v3_spec_validator - spec_resolver = RefResolver(spec_url, spec_dict, handlers=handlers) - dereferencer = Dereferencer(spec_resolver) - return SpecPath.from_spec(spec_dict, dereferencer) + return Spec.create( + spec_dict, + url=spec_url, + ref_resolver_handlers=handlers, + validator=validator, + ) diff --git a/openapi_core/templating/datatypes.py b/openapi_core/templating/datatypes.py index 02d4424b..68aa8a58 100644 --- a/openapi_core/templating/datatypes.py +++ b/openapi_core/templating/datatypes.py @@ -5,11 +5,11 @@ @dataclass class TemplateResult: - pattern: Optional[str] = None - variables: Optional[Dict] = None + pattern: str + variables: Optional[Dict[str, str]] = None @property - def resolved(self): + def resolved(self) -> str: if not self.variables: return self.pattern return self.pattern.format(**self.variables) diff --git a/openapi_core/templating/media_types/datatypes.py b/openapi_core/templating/media_types/datatypes.py new file mode 100644 index 00000000..d76fe9d2 --- /dev/null +++ b/openapi_core/templating/media_types/datatypes.py @@ -0,0 +1,3 @@ +from collections import namedtuple + +MediaType = namedtuple("MediaType", ["value", "key"]) diff --git a/openapi_core/templating/media_types/exceptions.py b/openapi_core/templating/media_types/exceptions.py index 26c46596..190d349e 100644 --- a/openapi_core/templating/media_types/exceptions.py +++ b/openapi_core/templating/media_types/exceptions.py @@ -13,7 +13,7 @@ class MediaTypeNotFound(MediaTypeFinderError): mimetype: str availableMimetypes: List[str] - def __str__(self): + def __str__(self) -> str: return ( f"Content for the following mimetype not found: {self.mimetype}. " f"Valid mimetypes: {self.availableMimetypes}" diff --git a/openapi_core/templating/media_types/finders.py b/openapi_core/templating/media_types/finders.py index 89a379ba..b7be6a4d 100644 --- a/openapi_core/templating/media_types/finders.py +++ b/openapi_core/templating/media_types/finders.py @@ -1,20 +1,22 @@ """OpenAPI core templating media types finders module""" import fnmatch +from openapi_core.spec import Spec +from openapi_core.templating.media_types.datatypes import MediaType from openapi_core.templating.media_types.exceptions import MediaTypeNotFound class MediaTypeFinder: - def __init__(self, content): + def __init__(self, content: Spec): self.content = content - def find(self, mimetype): + def find(self, mimetype: str) -> MediaType: if mimetype in self.content: - return self.content / mimetype, mimetype + return MediaType(self.content / mimetype, mimetype) if mimetype: for key, value in self.content.items(): if fnmatch.fnmatch(mimetype, key): - return value, key + return MediaType(value, key) raise MediaTypeNotFound(mimetype, list(self.content.keys())) diff --git a/openapi_core/templating/paths/datatypes.py b/openapi_core/templating/paths/datatypes.py new file mode 100644 index 00000000..31d4a4e4 --- /dev/null +++ b/openapi_core/templating/paths/datatypes.py @@ -0,0 +1,11 @@ +"""OpenAPI core templating paths datatypes module""" +from collections import namedtuple + +Path = namedtuple("Path", ["path", "path_result"]) +OperationPath = namedtuple( + "OperationPath", ["path", "operation", "path_result"] +) +ServerOperationPath = namedtuple( + "ServerOperationPath", + ["path", "operation", "server", "path_result", "server_result"], +) diff --git a/openapi_core/templating/paths/exceptions.py b/openapi_core/templating/paths/exceptions.py index ec9fe4b3..4e38c480 100644 --- a/openapi_core/templating/paths/exceptions.py +++ b/openapi_core/templating/paths/exceptions.py @@ -13,7 +13,7 @@ class PathNotFound(PathError): url: str - def __str__(self): + def __str__(self) -> str: return f"Path not found for {self.url}" @@ -24,7 +24,7 @@ class OperationNotFound(PathError): url: str method: str - def __str__(self): + def __str__(self) -> str: return f"Operation {self.method} not found for {self.url}" @@ -34,5 +34,5 @@ class ServerNotFound(PathError): url: str - def __str__(self): + def __str__(self) -> str: return f"Server not found for {self.url}" diff --git a/openapi_core/templating/paths/finders.py b/openapi_core/templating/paths/finders.py index b95f27d7..377ff68d 100644 --- a/openapi_core/templating/paths/finders.py +++ b/openapi_core/templating/paths/finders.py @@ -1,11 +1,18 @@ """OpenAPI core templating paths finders module""" +from typing import Iterator +from typing import List +from typing import Optional from urllib.parse import urljoin from urllib.parse import urlparse from more_itertools import peekable from openapi_core.schema.servers import is_absolute +from openapi_core.spec import Spec from openapi_core.templating.datatypes import TemplateResult +from openapi_core.templating.paths.datatypes import OperationPath +from openapi_core.templating.paths.datatypes import Path +from openapi_core.templating.paths.datatypes import ServerOperationPath from openapi_core.templating.paths.exceptions import OperationNotFound from openapi_core.templating.paths.exceptions import PathNotFound from openapi_core.templating.paths.exceptions import ServerNotFound @@ -15,11 +22,17 @@ class PathFinder: - def __init__(self, spec, base_url=None): + def __init__(self, spec: Spec, base_url: Optional[str] = None): self.spec = spec self.base_url = base_url - def find(self, method, host_url, path, path_pattern=None): + def find( + self, + method: str, + host_url: str, + path: str, + path_pattern: Optional[str] = None, + ) -> ServerOperationPath: if path_pattern is not None: full_url = urljoin(host_url, path_pattern) else: @@ -47,34 +60,37 @@ def find(self, method, host_url, path, path_pattern=None): except StopIteration: raise ServerNotFound(full_url) - def _get_paths_iter(self, full_url): - template_paths = [] + def _get_paths_iter(self, full_url: str) -> Iterator[Path]: + template_paths: List[Path] = [] paths = self.spec / "paths" for path_pattern, path in list(paths.items()): # simple path. # Return right away since it is always the most concrete if full_url.endswith(path_pattern): path_result = TemplateResult(path_pattern, {}) - yield (path, path_result) + yield Path(path, path_result) # template path else: result = search(path_pattern, full_url) if result: path_result = TemplateResult(path_pattern, result.named) - template_paths.append((path, path_result)) + template_paths.append(Path(path, path_result)) # Fewer variables -> more concrete path - for path in sorted(template_paths, key=template_path_len): - yield path + yield from sorted(template_paths, key=template_path_len) - def _get_operations_iter(self, paths_iter, request_method): + def _get_operations_iter( + self, paths_iter: Iterator[Path], request_method: str + ) -> Iterator[OperationPath]: for path, path_result in paths_iter: if request_method not in path: continue operation = path / request_method - yield (path, operation, path_result) + yield OperationPath(path, operation, path_result) - def _get_servers_iter(self, operations_iter, full_url): + def _get_servers_iter( + self, operations_iter: Iterator[OperationPath], full_url: str + ) -> Iterator[ServerOperationPath]: for path, operation, path_result in operations_iter: servers = ( path.get("servers", None) @@ -98,7 +114,7 @@ def _get_servers_iter(self, operations_iter, full_url): # simple path if server_url_pattern == server_url: server_result = TemplateResult(server["url"], {}) - yield ( + yield ServerOperationPath( path, operation, server, @@ -112,7 +128,7 @@ def _get_servers_iter(self, operations_iter, full_url): server_result = TemplateResult( server["url"], result.named ) - yield ( + yield ServerOperationPath( path, operation, server, diff --git a/openapi_core/templating/paths/util.py b/openapi_core/templating/paths/util.py index ba0f5799..a89c6d3b 100644 --- a/openapi_core/templating/paths/util.py +++ b/openapi_core/templating/paths/util.py @@ -1,8 +1,8 @@ from typing import Tuple from openapi_core.spec.paths import Spec -from openapi_core.templating.datatypes import TemplateResult +from openapi_core.templating.paths.datatypes import Path -def template_path_len(template_path: Tuple[Spec, TemplateResult]) -> int: +def template_path_len(template_path: Path) -> int: return len(template_path[1].variables) diff --git a/openapi_core/templating/responses/exceptions.py b/openapi_core/templating/responses/exceptions.py index 6ba282d0..39e1a012 100644 --- a/openapi_core/templating/responses/exceptions.py +++ b/openapi_core/templating/responses/exceptions.py @@ -12,8 +12,8 @@ class ResponseFinderError(OpenAPIError): class ResponseNotFound(ResponseFinderError): """Find response error""" - http_status: int + http_status: str availableresponses: List[str] - def __str__(self): + def __str__(self) -> str: return f"Unknown response http status: {str(self.http_status)}" diff --git a/openapi_core/templating/responses/finders.py b/openapi_core/templating/responses/finders.py index 87446748..c78f170a 100644 --- a/openapi_core/templating/responses/finders.py +++ b/openapi_core/templating/responses/finders.py @@ -1,11 +1,12 @@ +from openapi_core.spec import Spec from openapi_core.templating.responses.exceptions import ResponseNotFound class ResponseFinder: - def __init__(self, responses): + def __init__(self, responses: Spec): self.responses = responses - def find(self, http_status="default"): + def find(self, http_status: str = "default") -> Spec: if http_status in self.responses: return self.responses / http_status diff --git a/openapi_core/templating/util.py b/openapi_core/templating/util.py index d3d4fcc6..fa878ad8 100644 --- a/openapi_core/templating/util.py +++ b/openapi_core/templating/util.py @@ -1,8 +1,12 @@ +from typing import Any +from typing import Optional + +from parse import Match from parse import Parser -class ExtendedParser(Parser): - def _handle_field(self, field): +class ExtendedParser(Parser): # type: ignore + def _handle_field(self, field: str) -> Any: # handle as path parameter field field = field[1:-1] path_parameter_field = "{%s:PathParameter}" % field @@ -14,21 +18,21 @@ class PathParameter: name = "PathParameter" pattern = r"[^\/]+" - def __call__(self, text): + def __call__(self, text: str) -> str: return text parse_path_parameter = PathParameter() -def search(path_pattern, full_url_pattern): +def search(path_pattern: str, full_url_pattern: str) -> Optional[Match]: extra_types = {parse_path_parameter.name: parse_path_parameter} p = ExtendedParser(path_pattern, extra_types) p._expression = p._expression + "$" return p.search(full_url_pattern) -def parse(server_url, server_url_pattern): +def parse(server_url: str, server_url_pattern: str) -> Match: extra_types = {parse_path_parameter.name: parse_path_parameter} p = ExtendedParser(server_url, extra_types) p._expression = "^" + p._expression diff --git a/openapi_core/testing/datatypes.py b/openapi_core/testing/datatypes.py index 7bf38e8d..7bdc3a0e 100644 --- a/openapi_core/testing/datatypes.py +++ b/openapi_core/testing/datatypes.py @@ -1,18 +1,21 @@ +from typing import Optional + +from openapi_core.validation.request.datatypes import Parameters + + class ResultMock: def __init__( - self, body=None, parameters=None, data=None, error_to_raise=None + self, + body: Optional[str] = None, + parameters: Optional[Parameters] = None, + data: Optional[str] = None, + error_to_raise: Optional[Exception] = None, ): self.body = body self.parameters = parameters self.data = data self.error_to_raise = error_to_raise - def raise_for_errors(self): + def raise_for_errors(self) -> None: if self.error_to_raise is not None: raise self.error_to_raise - - if self.parameters is not None: - return self.parameters - - if self.data is not None: - return self.data diff --git a/openapi_core/testing/requests.py b/openapi_core/testing/requests.py index e1041cc4..9df4827c 100644 --- a/openapi_core/testing/requests.py +++ b/openapi_core/testing/requests.py @@ -1,4 +1,8 @@ """OpenAPI core testing requests module""" +from typing import Any +from typing import Dict +from typing import Optional + from werkzeug.datastructures import Headers from werkzeug.datastructures import ImmutableMultiDict @@ -8,16 +12,16 @@ class MockRequest: def __init__( self, - host_url, - method, - path, - path_pattern=None, - args=None, - view_args=None, - headers=None, - cookies=None, - data=None, - mimetype="application/json", + host_url: str, + method: str, + path: str, + path_pattern: Optional[str] = None, + args: Optional[Dict[str, Any]] = None, + view_args: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + cookies: Optional[Dict[str, Any]] = None, + data: Optional[str] = None, + mimetype: str = "application/json", ): self.host_url = host_url self.method = method.lower() diff --git a/openapi_core/testing/responses.py b/openapi_core/testing/responses.py index d414a28e..de352507 100644 --- a/openapi_core/testing/responses.py +++ b/openapi_core/testing/responses.py @@ -1,10 +1,18 @@ """OpenAPI core testing responses module""" +from typing import Any +from typing import Dict +from typing import Optional + from werkzeug.datastructures import Headers class MockResponse: def __init__( - self, data, status_code=200, headers=None, mimetype="application/json" + self, + data: str, + status_code: int = 200, + headers: Optional[Dict[str, Any]] = None, + mimetype: str = "application/json", ): self.data = data self.status_code = status_code diff --git a/openapi_core/unmarshalling/schemas/datatypes.py b/openapi_core/unmarshalling/schemas/datatypes.py new file mode 100644 index 00000000..96008373 --- /dev/null +++ b/openapi_core/unmarshalling/schemas/datatypes.py @@ -0,0 +1,7 @@ +from typing import Dict +from typing import Optional + +from openapi_core.unmarshalling.schemas.formatters import Formatter + +CustomFormattersDict = Dict[str, Formatter] +FormattersDict = Dict[Optional[str], Formatter] diff --git a/openapi_core/unmarshalling/schemas/exceptions.py b/openapi_core/unmarshalling/schemas/exceptions.py index 8df84c12..2d6fafad 100644 --- a/openapi_core/unmarshalling/schemas/exceptions.py +++ b/openapi_core/unmarshalling/schemas/exceptions.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from dataclasses import field -from typing import List +from typing import Iterable from openapi_core.exceptions import OpenAPIError @@ -21,9 +21,9 @@ class UnmarshallerError(UnmarshalError): class InvalidSchemaValue(ValidateError): value: str type: str - schema_errors: List[Exception] = field(default_factory=list) + schema_errors: Iterable[Exception] = field(default_factory=list) - def __str__(self): + def __str__(self) -> str: return ( "Value {value} not valid for schema of type {type}: {errors}" ).format(value=self.value, type=self.type, errors=self.schema_errors) @@ -37,7 +37,7 @@ class InvalidSchemaFormatValue(UnmarshallerError): type: str original_exception: Exception - def __str__(self): + def __str__(self) -> str: return ( "Failed to format value {value} to format {type}: {exception}" ).format( @@ -53,5 +53,5 @@ class FormatterNotFoundError(UnmarshallerError): type_format: str - def __str__(self): + def __str__(self) -> str: return f"Formatter not found for {self.type_format} format" diff --git a/openapi_core/unmarshalling/schemas/factories.py b/openapi_core/unmarshalling/schemas/factories.py index ad7985d6..e8ed5203 100644 --- a/openapi_core/unmarshalling/schemas/factories.py +++ b/openapi_core/unmarshalling/schemas/factories.py @@ -1,16 +1,32 @@ import warnings +from typing import Any +from typing import Dict +from typing import Optional +from typing import Type +from typing import Union +from jsonschema.protocols import Validator from openapi_schema_validator import OAS30Validator +from openapi_core.spec import Spec +from openapi_core.unmarshalling.schemas.datatypes import CustomFormattersDict +from openapi_core.unmarshalling.schemas.datatypes import FormattersDict from openapi_core.unmarshalling.schemas.enums import UnmarshalContext from openapi_core.unmarshalling.schemas.exceptions import ( FormatterNotFoundError, ) +from openapi_core.unmarshalling.schemas.formatters import Formatter from openapi_core.unmarshalling.schemas.unmarshallers import AnyUnmarshaller from openapi_core.unmarshalling.schemas.unmarshallers import ArrayUnmarshaller +from openapi_core.unmarshalling.schemas.unmarshallers import ( + BaseSchemaUnmarshaller, +) from openapi_core.unmarshalling.schemas.unmarshallers import ( BooleanUnmarshaller, ) +from openapi_core.unmarshalling.schemas.unmarshallers import ( + ComplexUnmarshaller, +) from openapi_core.unmarshalling.schemas.unmarshallers import ( IntegerUnmarshaller, ) @@ -22,7 +38,7 @@ class SchemaUnmarshallersFactory: - UNMARSHALLERS = { + UNMARSHALLERS: Dict[str, Type[BaseSchemaUnmarshaller]] = { "string": StringUnmarshaller, "integer": IntegerUnmarshaller, "number": NumberUnmarshaller, @@ -32,7 +48,11 @@ class SchemaUnmarshallersFactory: "any": AnyUnmarshaller, } - COMPLEX_UNMARSHALLERS = ["array", "object", "any"] + COMPLEX_UNMARSHALLERS: Dict[str, Type[ComplexUnmarshaller]] = { + "array": ArrayUnmarshaller, + "object": ObjectUnmarshaller, + "any": AnyUnmarshaller, + } CONTEXT_VALIDATION = { UnmarshalContext.REQUEST: "write", @@ -41,9 +61,9 @@ class SchemaUnmarshallersFactory: def __init__( self, - schema_validator_class, - custom_formatters=None, - context=None, + schema_validator_class: Type[Validator], + custom_formatters: Optional[CustomFormattersDict] = None, + context: Optional[UnmarshalContext] = None, ): self.schema_validator_class = schema_validator_class if custom_formatters is None: @@ -51,7 +71,9 @@ def __init__( self.custom_formatters = custom_formatters self.context = context - def create(self, schema, type_override=None): + def create( + self, schema: Spec, type_override: Optional[str] = None + ) -> BaseSchemaUnmarshaller: """Create unmarshaller from the schema.""" if schema is None: raise TypeError("Invalid schema") @@ -59,34 +81,36 @@ def create(self, schema, type_override=None): if schema.getkey("deprecated", False): warnings.warn("The schema is deprecated", DeprecationWarning) - schema_type = type_override or schema.getkey("type", "any") - schema_format = schema.getkey("format") - - klass = self.UNMARSHALLERS[schema_type] - - formatter = self.get_formatter(schema_format, klass.FORMATTERS) - if formatter is None: - raise FormatterNotFoundError(schema_format) - validator = self.get_validator(schema) - kwargs = dict() + schema_format = schema.getkey("format") + formatter = self.custom_formatters.get(schema_format) + + schema_type = type_override or schema.getkey("type", "any") if schema_type in self.COMPLEX_UNMARSHALLERS: - kwargs.update( - unmarshallers_factory=self, - context=self.context, + complex_klass = self.COMPLEX_UNMARSHALLERS[schema_type] + return complex_klass( + schema, validator, formatter, self, context=self.context ) - return klass(schema, formatter, validator, **kwargs) - def get_formatter(self, type_format, default_formatters): + klass = self.UNMARSHALLERS[schema_type] + return klass(schema, validator, formatter) + + def get_formatter( + self, type_format: str, default_formatters: FormattersDict + ) -> Optional[Formatter]: try: return self.custom_formatters[type_format] except KeyError: return default_formatters.get(type_format) - def get_validator(self, schema): - resolver = schema.accessor.dereferencer.resolver_manager.resolver - format_checker = build_format_checker(**self.custom_formatters) + def get_validator(self, schema: Spec) -> Validator: + resolver = schema.accessor.dereferencer.resolver_manager.resolver # type: ignore + custom_format_checks = { + name: formatter.validate + for name, formatter in self.custom_formatters.items() + } + format_checker = build_format_checker(**custom_format_checks) kwargs = { "resolver": resolver, "format_checker": format_checker, diff --git a/openapi_core/unmarshalling/schemas/formatters.py b/openapi_core/unmarshalling/schemas/formatters.py index cbb8776b..47dd52b8 100644 --- a/openapi_core/unmarshalling/schemas/formatters.py +++ b/openapi_core/unmarshalling/schemas/formatters.py @@ -1,17 +1,27 @@ +from typing import Any +from typing import Callable +from typing import Optional +from typing import Type + + class Formatter: - def validate(self, value): + def validate(self, value: Any) -> bool: return True - def unmarshal(self, value): + def unmarshal(self, value: Any) -> Any: return value @classmethod - def from_callables(cls, validate=None, unmarshal=None): + def from_callables( + cls, + validate: Optional[Callable[[Any], Any]] = None, + unmarshal: Optional[Callable[[Any], Any]] = None, + ) -> "Formatter": attrs = {} if validate is not None: attrs["validate"] = staticmethod(validate) if unmarshal is not None: attrs["unmarshal"] = staticmethod(unmarshal) - klass = type("Formatter", (cls,), attrs) + klass: Type[Formatter] = type("Formatter", (cls,), attrs) return klass() diff --git a/openapi_core/unmarshalling/schemas/unmarshallers.py b/openapi_core/unmarshalling/schemas/unmarshallers.py index bec882a4..205e957a 100644 --- a/openapi_core/unmarshalling/schemas/unmarshallers.py +++ b/openapi_core/unmarshalling/schemas/unmarshallers.py @@ -1,7 +1,13 @@ import logging from functools import partial +from typing import TYPE_CHECKING +from typing import Any +from typing import Dict +from typing import List +from typing import Optional from isodate.isodatetime import parse_datetime +from jsonschema.protocols import Validator from openapi_schema_validator._format import oas30_format_checker from openapi_schema_validator._types import is_array from openapi_schema_validator._types import is_bool @@ -13,7 +19,12 @@ from openapi_core.extensions.models.factories import ModelFactory from openapi_core.schema.schemas import get_all_properties from openapi_core.schema.schemas import get_all_properties_names +from openapi_core.spec import Spec +from openapi_core.unmarshalling.schemas.datatypes import FormattersDict from openapi_core.unmarshalling.schemas.enums import UnmarshalContext +from openapi_core.unmarshalling.schemas.exceptions import ( + FormatterNotFoundError, +) from openapi_core.unmarshalling.schemas.exceptions import ( InvalidSchemaFormatValue, ) @@ -27,19 +38,38 @@ from openapi_core.unmarshalling.schemas.util import format_uuid from openapi_core.util import forcebool +if TYPE_CHECKING: + from openapi_core.unmarshalling.schemas.factories import ( + SchemaUnmarshallersFactory, + ) + log = logging.getLogger(__name__) class BaseSchemaUnmarshaller: - FORMATTERS = { + FORMATTERS: FormattersDict = { None: Formatter(), } - def __init__(self, schema): + def __init__( + self, + schema: Spec, + validator: Validator, + formatter: Optional[Formatter], + ): self.schema = schema + self.validator = validator + self.format = schema.getkey("format") - def __call__(self, value): + if formatter is None: + if self.format not in self.FORMATTERS: + raise FormatterNotFoundError(self.format) + self.formatter = self.FORMATTERS[self.format] + else: + self.formatter = formatter + + def __call__(self, value: Any) -> Any: if value is None: return @@ -47,43 +77,29 @@ def __call__(self, value): return self.unmarshal(value) - def validate(self, value): - raise NotImplementedError - - def unmarshal(self, value): - raise NotImplementedError - - -class PrimitiveTypeUnmarshaller(BaseSchemaUnmarshaller): - def __init__(self, schema, formatter, validator): - super().__init__(schema) - self.formatter = formatter - self.validator = validator - - def _formatter_validate(self, value): + def _formatter_validate(self, value: Any) -> None: result = self.formatter.validate(value) if not result: schema_type = self.schema.getkey("type", "any") raise InvalidSchemaValue(value, schema_type) - def validate(self, value): + def validate(self, value: Any) -> None: errors_iter = self.validator.iter_errors(value) errors = tuple(errors_iter) if errors: schema_type = self.schema.getkey("type", "any") raise InvalidSchemaValue(value, schema_type, schema_errors=errors) - def unmarshal(self, value): + def unmarshal(self, value: Any) -> Any: try: return self.formatter.unmarshal(value) except ValueError as exc: - schema_format = self.schema.getkey("format") - raise InvalidSchemaFormatValue(value, schema_format, exc) + raise InvalidSchemaFormatValue(value, self.format, exc) -class StringUnmarshaller(PrimitiveTypeUnmarshaller): +class StringUnmarshaller(BaseSchemaUnmarshaller): - FORMATTERS = { + FORMATTERS: FormattersDict = { None: Formatter.from_callables(partial(is_string, None), str), "password": Formatter.from_callables( partial(oas30_format_checker.check, format="password"), str @@ -107,9 +123,9 @@ class StringUnmarshaller(PrimitiveTypeUnmarshaller): } -class IntegerUnmarshaller(PrimitiveTypeUnmarshaller): +class IntegerUnmarshaller(BaseSchemaUnmarshaller): - FORMATTERS = { + FORMATTERS: FormattersDict = { None: Formatter.from_callables(partial(is_integer, None), int), "int32": Formatter.from_callables( partial(oas30_format_checker.check, format="int32"), int @@ -120,9 +136,9 @@ class IntegerUnmarshaller(PrimitiveTypeUnmarshaller): } -class NumberUnmarshaller(PrimitiveTypeUnmarshaller): +class NumberUnmarshaller(BaseSchemaUnmarshaller): - FORMATTERS = { + FORMATTERS: FormattersDict = { None: Formatter.from_callables( partial(is_number, None), format_number ), @@ -135,33 +151,38 @@ class NumberUnmarshaller(PrimitiveTypeUnmarshaller): } -class BooleanUnmarshaller(PrimitiveTypeUnmarshaller): +class BooleanUnmarshaller(BaseSchemaUnmarshaller): - FORMATTERS = { + FORMATTERS: FormattersDict = { None: Formatter.from_callables(partial(is_bool, None), forcebool), } -class ComplexUnmarshaller(PrimitiveTypeUnmarshaller): +class ComplexUnmarshaller(BaseSchemaUnmarshaller): def __init__( - self, schema, formatter, validator, unmarshallers_factory, context=None + self, + schema: Spec, + validator: Validator, + formatter: Optional[Formatter], + unmarshallers_factory: "SchemaUnmarshallersFactory", + context: Optional[UnmarshalContext] = None, ): - super().__init__(schema, formatter, validator) + super().__init__(schema, validator, formatter) self.unmarshallers_factory = unmarshallers_factory self.context = context class ArrayUnmarshaller(ComplexUnmarshaller): - FORMATTERS = { + FORMATTERS: FormattersDict = { None: Formatter.from_callables(partial(is_array, None), list), } @property - def items_unmarshaller(self): + def items_unmarshaller(self) -> "BaseSchemaUnmarshaller": return self.unmarshallers_factory.create(self.schema / "items") - def __call__(self, value): + def __call__(self, value: Any) -> Optional[List[Any]]: value = super().__call__(value) if value is None and self.schema.getkey("nullable", False): return None @@ -170,23 +191,24 @@ def __call__(self, value): class ObjectUnmarshaller(ComplexUnmarshaller): - FORMATTERS = { + FORMATTERS: FormattersDict = { None: Formatter.from_callables(partial(is_object, None), dict), } @property - def model_factory(self): + def model_factory(self) -> ModelFactory: return ModelFactory() - def unmarshal(self, value): + def unmarshal(self, value: Any) -> Any: try: value = self.formatter.unmarshal(value) except ValueError as exc: - raise InvalidSchemaFormatValue(value, self.schema.format, exc) + schema_format = self.schema.getkey("format") + raise InvalidSchemaFormatValue(value, schema_format, exc) else: return self._unmarshal_object(value) - def _unmarshal_object(self, value): + def _unmarshal_object(self, value: Any) -> Any: if "oneOf" in self.schema: properties = None for one_of_schema in self.schema / "oneOf": @@ -214,7 +236,9 @@ def _unmarshal_object(self, value): return properties - def _unmarshal_properties(self, value, one_of_schema=None): + def _unmarshal_properties( + self, value: Any, one_of_schema: Optional[Spec] = None + ) -> Dict[str, Any]: all_props = get_all_properties(self.schema) all_props_names = get_all_properties_names(self.schema) @@ -225,7 +249,7 @@ def _unmarshal_properties(self, value, one_of_schema=None): value_props_names = list(value.keys()) extra_props = set(value_props_names) - set(all_props_names) - properties = {} + properties: Dict[str, Any] = {} additional_properties = self.schema.getkey( "additionalProperties", True ) @@ -273,7 +297,7 @@ class AnyUnmarshaller(ComplexUnmarshaller): "string", ] - def unmarshal(self, value): + def unmarshal(self, value: Any) -> Any: one_of_schema = self._get_one_of_schema(value) if one_of_schema: return self.unmarshallers_factory.create(one_of_schema)(value) @@ -297,9 +321,9 @@ def unmarshal(self, value): log.warning("failed to unmarshal any type") return value - def _get_one_of_schema(self, value): + def _get_one_of_schema(self, value: Any) -> Optional[Spec]: if "oneOf" not in self.schema: - return + return None one_of_schemas = self.schema / "oneOf" for subschema in one_of_schemas: @@ -310,10 +334,11 @@ def _get_one_of_schema(self, value): continue else: return subschema + return None - def _get_all_of_schema(self, value): + def _get_all_of_schema(self, value: Any) -> Optional[Spec]: if "allOf" not in self.schema: - return + return None all_of_schemas = self.schema / "allOf" for subschema in all_of_schemas: @@ -326,3 +351,4 @@ def _get_all_of_schema(self, value): continue else: return subschema + return None diff --git a/openapi_core/unmarshalling/schemas/util.py b/openapi_core/unmarshalling/schemas/util.py index 74b61e38..ca240f48 100644 --- a/openapi_core/unmarshalling/schemas/util.py +++ b/openapi_core/unmarshalling/schemas/util.py @@ -1,28 +1,33 @@ """OpenAPI core schemas util module""" -import datetime from base64 import b64decode from copy import copy +from datetime import date +from datetime import datetime from functools import lru_cache +from typing import Any +from typing import Callable +from typing import Optional +from typing import Union from uuid import UUID from openapi_schema_validator import oas30_format_checker -def format_date(value): - return datetime.datetime.strptime(value, "%Y-%m-%d").date() +def format_date(value: str) -> date: + return datetime.strptime(value, "%Y-%m-%d").date() -def format_uuid(value): +def format_uuid(value: Any) -> UUID: if isinstance(value, UUID): return value return UUID(value) -def format_byte(value, encoding="utf8"): +def format_byte(value: str, encoding: str = "utf8") -> str: return str(b64decode(value), encoding) -def format_number(value): +def format_number(value: str) -> Union[int, float]: if isinstance(value, (int, float)): return value @@ -30,11 +35,11 @@ def format_number(value): @lru_cache() -def build_format_checker(**custom_formatters): - if not custom_formatters: +def build_format_checker(**custom_format_checks: Callable[[Any], Any]) -> Any: + if not custom_format_checks: return oas30_format_checker fc = copy(oas30_format_checker) - for name, formatter in list(custom_formatters.items()): - fc.checks(name)(formatter.validate) + for name, check in custom_format_checks.items(): + fc.checks(name)(check) return fc diff --git a/openapi_core/util.py b/openapi_core/util.py index 2a5ea1a5..cf551e24 100644 --- a/openapi_core/util.py +++ b/openapi_core/util.py @@ -1,5 +1,7 @@ """OpenAPI core util module""" +from itertools import chain from typing import Any +from typing import Iterable def forcebool(val: Any) -> bool: @@ -13,3 +15,8 @@ def forcebool(val: Any) -> bool: raise ValueError(f"invalid truth value {val!r}") return bool(val) + + +def chainiters(*lists: Iterable[Any]) -> Iterable[Any]: + iters = map(lambda l: l and iter(l) or [], lists) + return chain(*iters) diff --git a/openapi_core/validation/datatypes.py b/openapi_core/validation/datatypes.py index 1c34ef0c..5917bf43 100644 --- a/openapi_core/validation/datatypes.py +++ b/openapi_core/validation/datatypes.py @@ -1,12 +1,12 @@ """OpenAPI core validation datatypes module""" from dataclasses import dataclass -from typing import List +from typing import Iterable @dataclass class BaseValidationResult: - errors: List[Exception] + errors: Iterable[Exception] - def raise_for_errors(self): + def raise_for_errors(self) -> None: for error in self.errors: raise error diff --git a/openapi_core/validation/decorators.py b/openapi_core/validation/decorators.py deleted file mode 100644 index 9d8ce93c..00000000 --- a/openapi_core/validation/decorators.py +++ /dev/null @@ -1,62 +0,0 @@ -"""OpenAPI core validation decorators module""" -from functools import wraps - -from openapi_core.validation.processors import OpenAPIProcessor - - -class OpenAPIDecorator(OpenAPIProcessor): - def __init__( - self, - spec, - request_validator, - response_validator, - request_class, - response_class, - request_provider, - openapi_errors_handler, - ): - super().__init__(request_validator, response_validator) - self.spec = spec - self.request_class = request_class - self.response_class = response_class - self.request_provider = request_provider - self.openapi_errors_handler = openapi_errors_handler - - def __call__(self, view): - @wraps(view) - def decorated(*args, **kwargs): - request = self._get_request(*args, **kwargs) - openapi_request = self._get_openapi_request(request) - request_result = self.process_request(self.spec, openapi_request) - if request_result.errors: - return self._handle_request_errors(request_result) - response = self._handle_request_view( - request_result, view, *args, **kwargs - ) - openapi_response = self._get_openapi_response(response) - response_result = self.process_response( - self.spec, openapi_request, openapi_response - ) - if response_result.errors: - return self._handle_response_errors(response_result) - return response - - return decorated - - def _get_request(self, *args, **kwargs): - return self.request_provider.provide(*args, **kwargs) - - def _handle_request_view(self, request_result, view, *args, **kwargs): - return view(*args, **kwargs) - - def _handle_request_errors(self, request_result): - return self.openapi_errors_handler.handle(request_result.errors) - - def _handle_response_errors(self, response_result): - return self.openapi_errors_handler.handle(response_result.errors) - - def _get_openapi_request(self, request): - return self.request_class(request) - - def _get_openapi_response(self, response): - return self.response_class(response) diff --git a/openapi_core/validation/exceptions.py b/openapi_core/validation/exceptions.py index 2cc2b191..71b2bb87 100644 --- a/openapi_core/validation/exceptions.py +++ b/openapi_core/validation/exceptions.py @@ -10,7 +10,7 @@ class ValidationError(OpenAPIError): @dataclass class InvalidSecurity(ValidationError): - def __str__(self): + def __str__(self) -> str: return "Security not valid for any requirement" @@ -26,7 +26,7 @@ class MissingParameterError(OpenAPIParameterError): class MissingParameter(MissingParameterError): name: str - def __str__(self): + def __str__(self) -> str: return f"Missing parameter (without default value): {self.name}" @@ -34,7 +34,7 @@ def __str__(self): class MissingRequiredParameter(MissingParameterError): name: str - def __str__(self): + def __str__(self) -> str: return f"Missing required parameter: {self.name}" @@ -50,7 +50,7 @@ class MissingHeaderError(OpenAPIHeaderError): class MissingHeader(MissingHeaderError): name: str - def __str__(self): + def __str__(self) -> str: return f"Missing header (without default value): {self.name}" @@ -58,5 +58,5 @@ def __str__(self): class MissingRequiredHeader(MissingHeaderError): name: str - def __str__(self): + def __str__(self) -> str: return f"Missing required header: {self.name}" diff --git a/openapi_core/validation/processors.py b/openapi_core/validation/processors.py index abaf4974..13d393bc 100644 --- a/openapi_core/validation/processors.py +++ b/openapi_core/validation/processors.py @@ -1,13 +1,28 @@ """OpenAPI core validation processors module""" +from openapi_core.spec import Spec +from openapi_core.validation.request.datatypes import RequestValidationResult +from openapi_core.validation.request.protocols import Request +from openapi_core.validation.request.validators import RequestValidator +from openapi_core.validation.response.datatypes import ResponseValidationResult +from openapi_core.validation.response.protocols import Response +from openapi_core.validation.response.validators import ResponseValidator class OpenAPIProcessor: - def __init__(self, request_validator, response_validator): + def __init__( + self, + request_validator: RequestValidator, + response_validator: ResponseValidator, + ): self.request_validator = request_validator self.response_validator = response_validator - def process_request(self, spec, request): + def process_request( + self, spec: Spec, request: Request + ) -> RequestValidationResult: return self.request_validator.validate(spec, request) - def process_response(self, spec, request, response): + def process_response( + self, spec: Spec, request: Request, response: Response + ) -> ResponseValidationResult: return self.response_validator.validate(spec, request, response) diff --git a/openapi_core/validation/request/datatypes.py b/openapi_core/validation/request/datatypes.py index 067dc906..52fcbf67 100644 --- a/openapi_core/validation/request/datatypes.py +++ b/openapi_core/validation/request/datatypes.py @@ -1,6 +1,9 @@ """OpenAPI core validation request datatypes module""" +from __future__ import annotations + from dataclasses import dataclass from dataclasses import field +from typing import Any from typing import Dict from typing import Optional @@ -25,25 +28,29 @@ class RequestParameters: Path parameters as dict. Gets resolved against spec if empty. """ - query: ImmutableMultiDict = field(default_factory=ImmutableMultiDict) + query: ImmutableMultiDict[str, Any] = field( + default_factory=ImmutableMultiDict + ) header: Headers = field(default_factory=Headers) - cookie: ImmutableMultiDict = field(default_factory=ImmutableMultiDict) - path: Dict = field(default_factory=dict) + cookie: ImmutableMultiDict[str, Any] = field( + default_factory=ImmutableMultiDict + ) + path: dict[str, Any] = field(default_factory=dict) - def __getitem__(self, location): + def __getitem__(self, location: str) -> Any: return getattr(self, location) @dataclass class Parameters: - query: Dict = field(default_factory=dict) - header: Dict = field(default_factory=dict) - cookie: Dict = field(default_factory=dict) - path: Dict = field(default_factory=dict) + query: dict[str, Any] = field(default_factory=dict) + header: dict[str, Any] = field(default_factory=dict) + cookie: dict[str, Any] = field(default_factory=dict) + path: dict[str, Any] = field(default_factory=dict) @dataclass class RequestValidationResult(BaseValidationResult): - body: Optional[str] = None + body: str | None = None parameters: Parameters = field(default_factory=Parameters) - security: Optional[Dict[str, str]] = None + security: dict[str, str] | None = None diff --git a/openapi_core/validation/request/exceptions.py b/openapi_core/validation/request/exceptions.py index 18d9b37f..7485ae53 100644 --- a/openapi_core/validation/request/exceptions.py +++ b/openapi_core/validation/request/exceptions.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List +from typing import Iterable from openapi_core.exceptions import OpenAPIError from openapi_core.validation.request.datatypes import Parameters @@ -9,7 +9,7 @@ @dataclass class ParametersError(Exception): parameters: Parameters - context: List[Exception] + context: Iterable[Exception] class OpenAPIRequestBodyError(OpenAPIError): @@ -24,7 +24,7 @@ class MissingRequestBodyError(OpenAPIRequestBodyError): class MissingRequestBody(MissingRequestBodyError): request: Request - def __str__(self): + def __str__(self) -> str: return "Missing request body" @@ -32,5 +32,5 @@ def __str__(self): class MissingRequiredRequestBody(MissingRequestBodyError): request: Request - def __str__(self): + def __str__(self) -> str: return "Missing required request body" diff --git a/openapi_core/validation/request/protocols.py b/openapi_core/validation/request/protocols.py index e1cec219..1a880eb9 100644 --- a/openapi_core/validation/request/protocols.py +++ b/openapi_core/validation/request/protocols.py @@ -1,5 +1,6 @@ """OpenAPI core validation request protocols module""" from typing import TYPE_CHECKING +from typing import Optional if TYPE_CHECKING: from typing_extensions import Protocol @@ -45,12 +46,27 @@ class Request(Protocol): the mimetype would be "text/html". """ - host_url: str - path: str - method: str parameters: RequestParameters - body: str - mimetype: str + + @property + def host_url(self) -> str: + ... + + @property + def path(self) -> str: + ... + + @property + def method(self) -> str: + ... + + @property + def body(self) -> Optional[str]: + ... + + @property + def mimetype(self) -> str: + ... @runtime_checkable @@ -66,4 +82,6 @@ class SupportsPathPattern(Protocol): /api/v1/pets/{pet_id} """ - path_pattern: str + @property + def path_pattern(self) -> str: + ... diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index 0bdd125b..c0298fb2 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -1,18 +1,29 @@ """OpenAPI core validation request validators module""" import warnings +from typing import Any +from typing import Dict +from typing import Optional from openapi_core.casting.schemas import schema_casters_factory from openapi_core.casting.schemas.exceptions import CastError +from openapi_core.casting.schemas.factories import SchemaCastersFactory from openapi_core.deserializing.exceptions import DeserializeError from openapi_core.deserializing.media_types import ( media_type_deserializers_factory, ) +from openapi_core.deserializing.media_types.factories import ( + MediaTypeDeserializersFactory, +) from openapi_core.deserializing.parameters import ( parameter_deserializers_factory, ) -from openapi_core.schema.parameters import iter_params +from openapi_core.deserializing.parameters.factories import ( + ParameterDeserializersFactory, +) from openapi_core.security import security_provider_factory from openapi_core.security.exceptions import SecurityError +from openapi_core.security.factories import SecurityProviderFactory +from openapi_core.spec.paths import Spec from openapi_core.templating.media_types.exceptions import MediaTypeFinderError from openapi_core.templating.paths.exceptions import PathError from openapi_core.unmarshalling.schemas.enums import UnmarshalContext @@ -21,6 +32,7 @@ from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) +from openapi_core.util import chainiters from openapi_core.validation.exceptions import InvalidSecurity from openapi_core.validation.exceptions import MissingParameter from openapi_core.validation.exceptions import MissingRequiredParameter @@ -31,17 +43,18 @@ MissingRequiredRequestBody, ) from openapi_core.validation.request.exceptions import ParametersError +from openapi_core.validation.request.protocols import Request from openapi_core.validation.validators import BaseValidator class BaseRequestValidator(BaseValidator): def __init__( self, - schema_unmarshallers_factory, - schema_casters_factory=schema_casters_factory, - parameter_deserializers_factory=parameter_deserializers_factory, - media_type_deserializers_factory=media_type_deserializers_factory, - security_provider_factory=security_provider_factory, + schema_unmarshallers_factory: SchemaUnmarshallersFactory, + schema_casters_factory: SchemaCastersFactory = schema_casters_factory, + parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory, + media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, + security_provider_factory: SecurityProviderFactory = security_provider_factory, ): super().__init__( schema_unmarshallers_factory, @@ -53,20 +66,22 @@ def __init__( def validate( self, - spec, - request, - base_url=None, - ): + spec: Spec, + request: Request, + base_url: Optional[str] = None, + ) -> RequestValidationResult: raise NotImplementedError - def _get_parameters(self, request, path, operation): + def _get_parameters( + self, request: Request, path: Spec, operation: Spec + ) -> Parameters: operation_params = operation.get("parameters", []) path_params = path.get("parameters", []) errors = [] seen = set() parameters = Parameters() - params_iter = iter_params(operation_params, path_params) + params_iter = chainiters(operation_params, path_params) for param in params_iter: param_name = param["name"] param_location = param["in"] @@ -97,7 +112,7 @@ def _get_parameters(self, request, path, operation): return parameters - def _get_parameter(self, param, request): + def _get_parameter(self, param: Spec, request: Request) -> Any: name = param["name"] deprecated = param.getkey("deprecated", False) if deprecated: @@ -116,7 +131,9 @@ def _get_parameter(self, param, request): raise MissingRequiredParameter(name) raise MissingParameter(name) - def _get_security(self, spec, request, operation): + def _get_security( + self, spec: Spec, request: Request, operation: Spec + ) -> Optional[Dict[str, str]]: security = None if "security" in spec: security = spec / "security" @@ -139,7 +156,9 @@ def _get_security(self, spec, request, operation): raise InvalidSecurity - def _get_security_value(self, spec, scheme_name, request): + def _get_security_value( + self, spec: Spec, scheme_name: str, request: Request + ) -> Any: security_schemes = spec / "components#securitySchemes" if scheme_name not in security_schemes: return @@ -147,7 +166,7 @@ def _get_security_value(self, spec, scheme_name, request): security_provider = self.security_provider_factory.create(scheme) return security_provider(request) - def _get_body(self, request, operation): + def _get_body(self, request: Request, operation: Spec) -> Any: if "requestBody" not in operation: return None @@ -168,7 +187,7 @@ def _get_body(self, request, operation): return body - def _get_body_value(self, request_body, request): + def _get_body_value(self, request_body: Spec, request: Request) -> Any: if not request.body: if request_body.getkey("required", False): raise MissingRequiredRequestBody(request) @@ -179,10 +198,10 @@ def _get_body_value(self, request_body, request): class RequestParametersValidator(BaseRequestValidator): def validate( self, - spec, - request, - base_url=None, - ): + spec: Spec, + request: Request, + base_url: Optional[str] = None, + ) -> RequestValidationResult: try: path, operation, _, path_result, _ = self._find_path( spec, request, base_url=base_url @@ -211,10 +230,10 @@ def validate( class RequestBodyValidator(BaseRequestValidator): def validate( self, - spec, - request, - base_url=None, - ): + spec: Spec, + request: Request, + base_url: Optional[str] = None, + ) -> RequestValidationResult: try: _, operation, _, _, _ = self._find_path( spec, request, base_url=base_url @@ -249,10 +268,10 @@ def validate( class RequestSecurityValidator(BaseRequestValidator): def validate( self, - spec, - request, - base_url=None, - ): + spec: Spec, + request: Request, + base_url: Optional[str] = None, + ) -> RequestValidationResult: try: _, operation, _, _, _ = self._find_path( spec, request, base_url=base_url @@ -274,10 +293,10 @@ def validate( class RequestValidator(BaseRequestValidator): def validate( self, - spec, - request, - base_url=None, - ): + spec: Spec, + request: Request, + base_url: Optional[str] = None, + ) -> RequestValidationResult: try: path, operation, _, path_result, _ = self._find_path( spec, request, base_url=base_url @@ -321,7 +340,7 @@ def validate( else: body_errors = [] - errors = params_errors + body_errors + errors = list(chainiters(params_errors, body_errors)) return RequestValidationResult( errors=errors, body=body, diff --git a/openapi_core/validation/response/datatypes.py b/openapi_core/validation/response/datatypes.py index abcd4d5a..f820936b 100644 --- a/openapi_core/validation/response/datatypes.py +++ b/openapi_core/validation/response/datatypes.py @@ -1,6 +1,7 @@ """OpenAPI core validation response datatypes module""" from dataclasses import dataclass from dataclasses import field +from typing import Any from typing import Dict from typing import Optional @@ -10,4 +11,4 @@ @dataclass class ResponseValidationResult(BaseValidationResult): data: Optional[str] = None - headers: Dict = field(default_factory=dict) + headers: Dict[str, Any] = field(default_factory=dict) diff --git a/openapi_core/validation/response/exceptions.py b/openapi_core/validation/response/exceptions.py index 5808f23b..277556c6 100644 --- a/openapi_core/validation/response/exceptions.py +++ b/openapi_core/validation/response/exceptions.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import Any from typing import Dict -from typing import List +from typing import Iterable from openapi_core.exceptions import OpenAPIError from openapi_core.validation.response.protocols import Response @@ -10,7 +10,7 @@ @dataclass class HeadersError(Exception): headers: Dict[str, Any] - context: List[Exception] + context: Iterable[OpenAPIError] class OpenAPIResponseError(OpenAPIError): @@ -21,5 +21,5 @@ class OpenAPIResponseError(OpenAPIError): class MissingResponseContent(OpenAPIResponseError): response: Response - def __str__(self): + def __str__(self) -> str: return "Missing response content" diff --git a/openapi_core/validation/response/protocols.py b/openapi_core/validation/response/protocols.py index 1a9841ac..2e67ecdb 100644 --- a/openapi_core/validation/response/protocols.py +++ b/openapi_core/validation/response/protocols.py @@ -30,7 +30,18 @@ class Response(Protocol): Lowercase content type without charset. """ - data: str - status_code: int - mimetype: str - headers: Headers + @property + def data(self) -> str: + ... + + @property + def status_code(self) -> int: + ... + + @property + def mimetype(self) -> str: + ... + + @property + def headers(self) -> Headers: + ... diff --git a/openapi_core/validation/response/validators.py b/openapi_core/validation/response/validators.py index 77c99ce9..0e735c82 100644 --- a/openapi_core/validation/response/validators.py +++ b/openapi_core/validation/response/validators.py @@ -1,8 +1,14 @@ """OpenAPI core validation response validators module""" import warnings +from typing import Any +from typing import Dict +from typing import List +from typing import Optional from openapi_core.casting.schemas.exceptions import CastError from openapi_core.deserializing.exceptions import DeserializeError +from openapi_core.exceptions import OpenAPIError +from openapi_core.spec import Spec from openapi_core.templating.media_types.exceptions import MediaTypeFinderError from openapi_core.templating.paths.exceptions import PathError from openapi_core.templating.responses.exceptions import ResponseFinderError @@ -12,37 +18,48 @@ from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) +from openapi_core.util import chainiters from openapi_core.validation.exceptions import MissingHeader from openapi_core.validation.exceptions import MissingRequiredHeader +from openapi_core.validation.request.protocols import Request from openapi_core.validation.response.datatypes import ResponseValidationResult from openapi_core.validation.response.exceptions import HeadersError from openapi_core.validation.response.exceptions import MissingResponseContent +from openapi_core.validation.response.protocols import Response from openapi_core.validation.validators import BaseValidator class BaseResponseValidator(BaseValidator): def validate( self, - spec, - request, - response, - base_url=None, - ): + spec: Spec, + request: Request, + response: Response, + base_url: Optional[str] = None, + ) -> ResponseValidationResult: raise NotImplementedError - def _find_operation_response(self, spec, request, response, base_url=None): + def _find_operation_response( + self, + spec: Spec, + request: Request, + response: Response, + base_url: Optional[str] = None, + ) -> Spec: _, operation, _, _, _ = self._find_path( spec, request, base_url=base_url ) return self._get_operation_response(operation, response) - def _get_operation_response(self, operation, response): + def _get_operation_response( + self, operation: Spec, response: Response + ) -> Spec: from openapi_core.templating.responses.finders import ResponseFinder finder = ResponseFinder(operation / "responses") return finder.find(str(response.status_code)) - def _get_data(self, response, operation_response): + def _get_data(self, response: Response, operation_response: Spec) -> Any: if "content" not in operation_response: return None @@ -61,20 +78,22 @@ def _get_data(self, response, operation_response): return data - def _get_data_value(self, response): + def _get_data_value(self, response: Response) -> Any: if not response.data: raise MissingResponseContent(response) return response.data - def _get_headers(self, response, operation_response): + def _get_headers( + self, response: Response, operation_response: Spec + ) -> Dict[str, Any]: if "headers" not in operation_response: return {} headers = operation_response / "headers" - errors = [] - validated = {} + errors: List[OpenAPIError] = [] + validated: Dict[str, Any] = {} for name, header in list(headers.items()): # ignore Content-Type header if name.lower() == "content-type": @@ -96,11 +115,11 @@ def _get_headers(self, response, operation_response): validated[name] = value if errors: - raise HeadersError(context=errors, headers=validated) + raise HeadersError(context=iter(errors), headers=validated) return validated - def _get_header(self, name, header, response): + def _get_header(self, name: str, header: Spec, response: Response) -> Any: deprecated = header.getkey("deprecated", False) if deprecated: warnings.warn( @@ -122,11 +141,11 @@ def _get_header(self, name, header, response): class ResponseDataValidator(BaseResponseValidator): def validate( self, - spec, - request, - response, - base_url=None, - ): + spec: Spec, + request: Request, + response: Response, + base_url: Optional[str] = None, + ) -> ResponseValidationResult: try: operation_response = self._find_operation_response( spec, @@ -162,11 +181,11 @@ def validate( class ResponseHeadersValidator(BaseResponseValidator): def validate( self, - spec, - request, - response, - base_url=None, - ): + spec: Spec, + request: Request, + response: Response, + base_url: Optional[str] = None, + ) -> ResponseValidationResult: try: operation_response = self._find_operation_response( spec, @@ -195,11 +214,11 @@ def validate( class ResponseValidator(BaseResponseValidator): def validate( self, - spec, - request, - response, - base_url=None, - ): + spec: Spec, + request: Request, + response: Response, + base_url: Optional[str] = None, + ) -> ResponseValidationResult: try: operation_response = self._find_operation_response( spec, @@ -234,7 +253,7 @@ def validate( else: headers_errors = [] - errors = data_errors + headers_errors + errors = list(chainiters(data_errors, headers_errors)) return ResponseValidationResult( errors=errors, data=data, diff --git a/openapi_core/validation/shortcuts.py b/openapi_core/validation/shortcuts.py index 5818d38f..7eaed534 100644 --- a/openapi_core/validation/shortcuts.py +++ b/openapi_core/validation/shortcuts.py @@ -1,23 +1,35 @@ """OpenAPI core validation shortcuts module""" +from typing import Optional + +from openapi_core.spec import Spec from openapi_core.validation.request import openapi_request_validator +from openapi_core.validation.request.datatypes import RequestValidationResult +from openapi_core.validation.request.protocols import Request +from openapi_core.validation.request.validators import RequestValidator from openapi_core.validation.response import openapi_response_validator +from openapi_core.validation.response.datatypes import ResponseValidationResult +from openapi_core.validation.response.protocols import Response +from openapi_core.validation.response.validators import ResponseValidator def validate_request( - spec, request, base_url=None, validator=openapi_request_validator -): + spec: Spec, + request: Request, + base_url: Optional[str] = None, + validator: RequestValidator = openapi_request_validator, +) -> RequestValidationResult: result = validator.validate(spec, request, base_url=base_url) result.raise_for_errors() return result def validate_response( - spec, - request, - response, - base_url=None, - validator=openapi_response_validator, -): + spec: Spec, + request: Request, + response: Response, + base_url: Optional[str] = None, + validator: ResponseValidator = openapi_response_validator, +) -> ResponseValidationResult: result = validator.validate(spec, request, response, base_url=base_url) result.raise_for_errors() return result diff --git a/openapi_core/validation/validators.py b/openapi_core/validation/validators.py index 69b34658..5a944e6b 100644 --- a/openapi_core/validation/validators.py +++ b/openapi_core/validation/validators.py @@ -1,26 +1,45 @@ """OpenAPI core validation validators module""" +from typing import Any +from typing import Dict +from typing import Optional +from typing import Union from urllib.parse import urljoin +from werkzeug.datastructures import Headers + from openapi_core.casting.schemas import schema_casters_factory +from openapi_core.casting.schemas.factories import SchemaCastersFactory from openapi_core.deserializing.media_types import ( media_type_deserializers_factory, ) +from openapi_core.deserializing.media_types.factories import ( + MediaTypeDeserializersFactory, +) from openapi_core.deserializing.parameters import ( parameter_deserializers_factory, ) +from openapi_core.deserializing.parameters.factories import ( + ParameterDeserializersFactory, +) from openapi_core.schema.parameters import get_value +from openapi_core.spec import Spec +from openapi_core.templating.media_types.datatypes import MediaType +from openapi_core.templating.paths.datatypes import ServerOperationPath from openapi_core.templating.paths.finders import PathFinder -from openapi_core.unmarshalling.schemas.util import build_format_checker +from openapi_core.unmarshalling.schemas.factories import ( + SchemaUnmarshallersFactory, +) +from openapi_core.validation.request.protocols import Request from openapi_core.validation.request.protocols import SupportsPathPattern class BaseValidator: def __init__( self, - schema_unmarshallers_factory, - schema_casters_factory=schema_casters_factory, - parameter_deserializers_factory=parameter_deserializers_factory, - media_type_deserializers_factory=media_type_deserializers_factory, + schema_unmarshallers_factory: SchemaUnmarshallersFactory, + schema_casters_factory: SchemaCastersFactory = schema_casters_factory, + parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory, + media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, ): self.schema_unmarshallers_factory = schema_unmarshallers_factory self.schema_casters_factory = schema_casters_factory @@ -29,36 +48,43 @@ def __init__( media_type_deserializers_factory ) - def _find_path(self, spec, request, base_url=None): + def _find_path( + self, spec: Spec, request: Request, base_url: Optional[str] = None + ) -> ServerOperationPath: path_finder = PathFinder(spec, base_url=base_url) path_pattern = getattr(request, "path_pattern", None) return path_finder.find( request.method, request.host_url, request.path, path_pattern ) - def _get_media_type(self, content, mimetype): + def _get_media_type(self, content: Spec, mimetype: str) -> MediaType: from openapi_core.templating.media_types.finders import MediaTypeFinder finder = MediaTypeFinder(content) return finder.find(mimetype) - def _deserialise_data(self, mimetype, value): + def _deserialise_data(self, mimetype: str, value: Any) -> Any: deserializer = self.media_type_deserializers_factory.create(mimetype) return deserializer(value) - def _deserialise_parameter(self, param, value): + def _deserialise_parameter(self, param: Spec, value: Any) -> Any: deserializer = self.parameter_deserializers_factory.create(param) return deserializer(value) - def _cast(self, schema, value): + def _cast(self, schema: Spec, value: Any) -> Any: caster = self.schema_casters_factory.create(schema) return caster(value) - def _unmarshal(self, schema, value): + def _unmarshal(self, schema: Spec, value: Any) -> Any: unmarshaller = self.schema_unmarshallers_factory.create(schema) return unmarshaller(value) - def _get_param_or_header_value(self, param_or_header, location, name=None): + def _get_param_or_header_value( + self, + param_or_header: Spec, + location: Union[Headers, Dict[str, Any]], + name: Optional[str] = None, + ) -> Any: try: raw_value = get_value(param_or_header, location, name=name) except KeyError: diff --git a/poetry.lock b/poetry.lock index b3f3f788..9cf95c09 100644 --- a/poetry.lock +++ b/poetry.lock @@ -393,6 +393,25 @@ category = "main" optional = false python-versions = ">=3.5" +[[package]] +name = "mypy" +version = "0.971" +description = "Optional static typing for Python" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +mypy-extensions = ">=0.4.3" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typed-ast = {version = ">=1.4.0,<2", markers = "python_version < \"3.8\""} +typing-extensions = ">=3.10" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +python2 = ["typed-ast (>=1.4.0,<2)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "0.4.3" @@ -855,6 +874,25 @@ category = "dev" optional = false python-versions = ">=3.6" +[[package]] +name = "types-requests" +version = "2.28.9" +description = "Typing stubs for requests" +category = "dev" +optional = false +python-versions = "*" + +[package.dependencies] +types-urllib3 = "<1.27" + +[[package]] +name = "types-urllib3" +version = "1.26.23" +description = "Typing stubs for urllib3" +category = "dev" +optional = false +python-versions = "*" + [[package]] name = "typing-extensions" version = "4.3.0" @@ -941,7 +979,7 @@ requests = ["requests"] [metadata] lock-version = "1.1" python-versions = "^3.7.0" -content-hash = "4c9aa4db8e6d6ee76a8dabcb82b1d1c6f786c6b5c36023fdb66707add4706cd5" +content-hash = "ffa07e7b70aec4ff76eba4855fbeb2e01b1eabe24f1967fefa25dbc184f0d9e4" [metadata.files] alabaster = [] @@ -978,6 +1016,7 @@ jsonschema = [] markupsafe = [] mccabe = [] more-itertools = [] +mypy = [] mypy-extensions = [] nodeenv = [] openapi-schema-validator = [] @@ -1018,6 +1057,8 @@ strict-rfc3339 = [] toml = [] tomli = [] typed-ast = [] +types-requests = [] +types-urllib3 = [] typing-extensions = [] urllib3 = [] virtualenv = [] diff --git a/pyproject.toml b/pyproject.toml index e471bb04..4e352c98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,25 @@ source =["openapi_core"] [tool.coverage.xml] output = "reports/coverage.xml" +[tool.mypy] +files = "openapi_core" +strict = true + +[[tool.mypy.overrides]] +module = [ + "django.*", + "falcon.*", + "isodate.*", + "jsonschema.*", + "more_itertools.*", + "openapi_spec_validator.*", + "openapi_schema_validator.*", + "parse.*", + "requests.*", + "werkzeug.*", +] +ignore_missing_imports = true + [tool.poetry] name = "openapi-core" version = "0.15.0a2" @@ -69,6 +88,7 @@ sphinx = "^4.0.2" sphinx-rtd-theme = "^0.5.2" strict-rfc3339 = "^0.7" webob = "*" +mypy = "^0.971" [tool.pytest.ini_options] addopts = """ diff --git a/tests/unit/contrib/django/test_django.py b/tests/unit/contrib/django/test_django.py index 3c33985f..8fc5ca02 100644 --- a/tests/unit/contrib/django/test_django.py +++ b/tests/unit/contrib/django/test_django.py @@ -1,5 +1,6 @@ import pytest from werkzeug.datastructures import Headers +from werkzeug.datastructures import ImmutableMultiDict from openapi_core.contrib.django import DjangoOpenAPIRequest from openapi_core.contrib.django import DjangoOpenAPIResponse @@ -62,12 +63,17 @@ def create(content=b"", status_code=None): class TestDjangoOpenAPIRequest(BaseTestDjango): def test_no_resolver(self, request_factory): - request = request_factory.get("/admin/") + data = {"test1": "test2"} + request = request_factory.get("/admin/", data) openapi_request = DjangoOpenAPIRequest(request) path = {} - query = {} + query = ImmutableMultiDict( + [ + ("test1", "test2"), + ] + ) headers = Headers( { "Cookie": "", @@ -83,7 +89,7 @@ def test_no_resolver(self, request_factory): assert openapi_request.method == request.method.lower() assert openapi_request.host_url == request._current_scheme_host assert openapi_request.path == request.path - assert openapi_request.body == request.body + assert openapi_request.body == "" assert openapi_request.mimetype == request.content_type def test_simple(self, request_factory): @@ -111,7 +117,7 @@ def test_simple(self, request_factory): assert openapi_request.method == request.method.lower() assert openapi_request.host_url == request._current_scheme_host assert openapi_request.path == request.path - assert openapi_request.body == request.body + assert openapi_request.body == "" assert openapi_request.mimetype == request.content_type def test_url_rule(self, request_factory): @@ -142,7 +148,7 @@ def test_url_rule(self, request_factory): assert openapi_request.host_url == request._current_scheme_host assert openapi_request.path == request.path assert openapi_request.path_pattern == "/admin/auth/group/{object_id}/" - assert openapi_request.body == request.body + assert openapi_request.body == "" assert openapi_request.mimetype == request.content_type def test_url_regexp_pattern(self, request_factory): @@ -170,7 +176,7 @@ def test_url_regexp_pattern(self, request_factory): assert openapi_request.method == request.method.lower() assert openapi_request.host_url == request._current_scheme_host assert openapi_request.path == "/test/test-regexp/" - assert openapi_request.body == request.body + assert openapi_request.body == "" assert openapi_request.mimetype == request.content_type @@ -181,15 +187,16 @@ def test_stream_response(self, response_factory): openapi_response = DjangoOpenAPIResponse(response) - assert openapi_response.data == b"foo\nbar\nbaz\n" + assert openapi_response.data == "foo\nbar\nbaz\n" assert openapi_response.status_code == response.status_code assert openapi_response.mimetype == response["Content-Type"] def test_redirect_response(self, response_factory): - response = response_factory("/redirected/", status_code=302) + data = "/redirected/" + response = response_factory(data, status_code=302) openapi_response = DjangoOpenAPIResponse(response) - assert openapi_response.data == response.content + assert openapi_response.data == data assert openapi_response.status_code == response.status_code assert openapi_response.mimetype == response["Content-Type"] diff --git a/tests/unit/contrib/flask/test_flask_requests.py b/tests/unit/contrib/flask/test_flask_requests.py index a3744c80..08d7828a 100644 --- a/tests/unit/contrib/flask/test_flask_requests.py +++ b/tests/unit/contrib/flask/test_flask_requests.py @@ -23,10 +23,10 @@ def test_simple(self, request_factory, request): header=headers, cookie=cookies, ) - assert openapi_request.method == request.method.lower() + assert openapi_request.method == "get" assert openapi_request.host_url == request.host_url assert openapi_request.path == request.path - assert openapi_request.body == request.data + assert openapi_request.body == "" assert openapi_request.mimetype == request.mimetype def test_multiple_values(self, request_factory, request): @@ -51,10 +51,10 @@ def test_multiple_values(self, request_factory, request): header=headers, cookie=cookies, ) - assert openapi_request.method == request.method.lower() + assert openapi_request.method == "get" assert openapi_request.host_url == request.host_url assert openapi_request.path == request.path - assert openapi_request.body == request.data + assert openapi_request.body == "" assert openapi_request.mimetype == request.mimetype def test_url_rule(self, request_factory, request): @@ -72,9 +72,9 @@ def test_url_rule(self, request_factory, request): header=headers, cookie=cookies, ) - assert openapi_request.method == request.method.lower() + assert openapi_request.method == "get" assert openapi_request.host_url == request.host_url assert openapi_request.path == request.path assert openapi_request.path_pattern == "/browse/{id}/" - assert openapi_request.body == request.data + assert openapi_request.body == "" assert openapi_request.mimetype == request.mimetype diff --git a/tests/unit/contrib/flask/test_flask_responses.py b/tests/unit/contrib/flask/test_flask_responses.py index 5b2fd1a7..6b9c30f6 100644 --- a/tests/unit/contrib/flask/test_flask_responses.py +++ b/tests/unit/contrib/flask/test_flask_responses.py @@ -3,10 +3,12 @@ class TestFlaskOpenAPIResponse: def test_invalid_server(self, response_factory): - response = response_factory("Not Found", status_code=404) + data = "Not Found" + status_code = 404 + response = response_factory(data, status_code=status_code) openapi_response = FlaskOpenAPIResponse(response) - assert openapi_response.data == response.data - assert openapi_response.status_code == response._status_code + assert openapi_response.data == data + assert openapi_response.status_code == status_code assert openapi_response.mimetype == response.mimetype diff --git a/tests/unit/contrib/requests/test_requests_responses.py b/tests/unit/contrib/requests/test_requests_responses.py index 7fa17991..62da483f 100644 --- a/tests/unit/contrib/requests/test_requests_responses.py +++ b/tests/unit/contrib/requests/test_requests_responses.py @@ -3,11 +3,13 @@ class TestRequestsOpenAPIResponse: def test_invalid_server(self, response_factory): - response = response_factory("Not Found", status_code=404) + data = "Not Found" + status_code = 404 + response = response_factory(data, status_code=status_code) openapi_response = RequestsOpenAPIResponse(response) - assert openapi_response.data == response.content - assert openapi_response.status_code == response.status_code + assert openapi_response.data == data + assert openapi_response.status_code == status_code mimetype = response.headers.get("Content-Type") assert openapi_response.mimetype == mimetype