Skip to content

Static types with mypy #414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ jobs:
PYTEST_ADDOPTS: "--color=yes"
run: poetry run pytest

- name: Static type check
run: poetry run mypy

- name: Upload coverage
uses: codecov/codecov-action@v1

Expand Down
28 changes: 19 additions & 9 deletions openapi_core/casting/schemas/casters.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,59 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import List

from openapi_core.casting.schemas.datatypes import CasterCallable
from openapi_core.casting.schemas.exceptions import CastError
from openapi_core.spec import Spec

if TYPE_CHECKING:
from openapi_core.casting.schemas.factories import SchemaCastersFactory


class BaseSchemaCaster:
def __init__(self, schema):
def __init__(self, schema: Spec):
self.schema = schema

def __call__(self, value):
def __call__(self, value: Any) -> Any:
if value is None:
return value

return self.cast(value)

def cast(self, value):
def cast(self, value: Any) -> Any:
raise NotImplementedError


class CallableSchemaCaster(BaseSchemaCaster):
def __init__(self, schema, caster_callable):
def __init__(self, schema: Spec, caster_callable: CasterCallable):
super().__init__(schema)
self.caster_callable = caster_callable

def cast(self, value):
def cast(self, value: Any) -> Any:
try:
return self.caster_callable(value)
except (ValueError, TypeError):
raise CastError(value, self.schema["type"])


class DummyCaster(BaseSchemaCaster):
def cast(self, value):
def cast(self, value: Any) -> Any:
return value


class ComplexCaster(BaseSchemaCaster):
def __init__(self, schema, casters_factory):
def __init__(self, schema: Spec, casters_factory: "SchemaCastersFactory"):
super().__init__(schema)
self.casters_factory = casters_factory


class ArrayCaster(ComplexCaster):
@property
def items_caster(self):
def items_caster(self) -> BaseSchemaCaster:
return self.casters_factory.create(self.schema / "items")

def cast(self, value):
def cast(self, value: Any) -> List[Any]:
try:
return list(map(self.items_caster, value))
except (ValueError, TypeError):
Expand Down
4 changes: 4 additions & 0 deletions openapi_core/casting/schemas/datatypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from typing import Any
from typing import Callable

CasterCallable = Callable[[Any], Any]
2 changes: 1 addition & 1 deletion openapi_core/casting/schemas/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ class CastError(OpenAPIError):
value: str
type: str

def __str__(self):
def __str__(self) -> str:
return f"Failed to cast value to {self.type} type: {self.value}"
9 changes: 7 additions & 2 deletions openapi_core/casting/schemas/factories.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Dict

from openapi_core.casting.schemas.casters import ArrayCaster
from openapi_core.casting.schemas.casters import BaseSchemaCaster
from openapi_core.casting.schemas.casters import CallableSchemaCaster
from openapi_core.casting.schemas.casters import DummyCaster
from openapi_core.casting.schemas.datatypes import CasterCallable
from openapi_core.spec import Spec
from openapi_core.util import forcebool


Expand All @@ -11,7 +16,7 @@ class SchemaCastersFactory:
"object",
"any",
]
PRIMITIVE_CASTERS = {
PRIMITIVE_CASTERS: Dict[str, CasterCallable] = {
"integer": int,
"number": float,
"boolean": forcebool,
Expand All @@ -20,7 +25,7 @@ class SchemaCastersFactory:
"array": ArrayCaster,
}

def create(self, schema):
def create(self, schema: Spec) -> BaseSchemaCaster:
schema_type = schema.getkey("type", "any")

if schema_type in self.DUMMY_CASTERS:
Expand Down
23 changes: 18 additions & 5 deletions openapi_core/contrib/django/handlers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
"""OpenAPI core contrib django handlers module"""
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Type

from django.http import JsonResponse
from django.http.request import HttpRequest
from django.http.response import HttpResponse

from openapi_core.templating.media_types.exceptions import MediaTypeNotFound
from openapi_core.templating.paths.exceptions import OperationNotFound
Expand All @@ -11,7 +19,7 @@

class DjangoOpenAPIErrorsHandler:

OPENAPI_ERROR_STATUS = {
OPENAPI_ERROR_STATUS: Dict[Type[Exception], int] = {
MissingRequiredParameter: 400,
ServerNotFound: 400,
InvalidSecurity: 403,
Expand All @@ -21,7 +29,12 @@ class DjangoOpenAPIErrorsHandler:
}

@classmethod
def handle(cls, errors, req, resp=None):
def handle(
cls,
errors: Iterable[Exception],
req: HttpRequest,
resp: Optional[HttpResponse] = None,
) -> JsonResponse:
data_errors = [cls.format_openapi_error(err) for err in errors]
data = {
"errors": data_errors,
Expand All @@ -30,13 +43,13 @@ def handle(cls, errors, req, resp=None):
return JsonResponse(data, status=data_error_max["status"])

@classmethod
def format_openapi_error(cls, error):
def format_openapi_error(cls, error: Exception) -> Dict[str, Any]:
return {
"title": str(error),
"status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400),
"class": str(type(error)),
}

@classmethod
def get_error_status(cls, error):
return error["status"]
def get_error_status(cls, error: Dict[str, Any]) -> str:
return str(error["status"])
32 changes: 26 additions & 6 deletions openapi_core/contrib/django/middlewares.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
"""OpenAPI core contrib django middlewares module"""
from typing import Callable

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.http import JsonResponse
from django.http.request import HttpRequest
from django.http.response import HttpResponse

from openapi_core.contrib.django.handlers import DjangoOpenAPIErrorsHandler
from openapi_core.contrib.django.requests import DjangoOpenAPIRequest
from openapi_core.contrib.django.responses import DjangoOpenAPIResponse
from openapi_core.validation.processors import OpenAPIProcessor
from openapi_core.validation.request import openapi_request_validator
from openapi_core.validation.request.datatypes import RequestValidationResult
from openapi_core.validation.request.protocols import Request
from openapi_core.validation.response import openapi_response_validator
from openapi_core.validation.response.datatypes import ResponseValidationResult
from openapi_core.validation.response.protocols import Response


class DjangoOpenAPIMiddleware:
Expand All @@ -16,7 +25,7 @@ class DjangoOpenAPIMiddleware:
response_class = DjangoOpenAPIResponse
errors_handler = DjangoOpenAPIErrorsHandler()

def __init__(self, get_response):
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
self.get_response = get_response

if not hasattr(settings, "OPENAPI_SPEC"):
Expand All @@ -26,7 +35,7 @@ def __init__(self, get_response):
openapi_request_validator, openapi_response_validator
)

def __call__(self, request):
def __call__(self, request: HttpRequest) -> HttpResponse:
openapi_request = self._get_openapi_request(request)
req_result = self.validation_processor.process_request(
settings.OPENAPI_SPEC, openapi_request
Expand All @@ -46,14 +55,25 @@ def __call__(self, request):

return response

def _handle_request_errors(self, request_result, req):
def _handle_request_errors(
self, request_result: RequestValidationResult, req: HttpRequest
) -> JsonResponse:
return self.errors_handler.handle(request_result.errors, req, None)

def _handle_response_errors(self, response_result, req, resp):
def _handle_response_errors(
self,
response_result: ResponseValidationResult,
req: HttpRequest,
resp: HttpResponse,
) -> JsonResponse:
return self.errors_handler.handle(response_result.errors, req, resp)

def _get_openapi_request(self, request):
def _get_openapi_request(
self, request: HttpRequest
) -> DjangoOpenAPIRequest:
return self.request_class(request)

def _get_openapi_response(self, response):
def _get_openapi_response(
self, response: HttpResponse
) -> DjangoOpenAPIResponse:
return self.response_class(response)
36 changes: 23 additions & 13 deletions openapi_core/contrib/django/requests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""OpenAPI core contrib django requests module"""
import re
from urllib.parse import urljoin
from typing import Optional

from django.http.request import HttpRequest
from werkzeug.datastructures import Headers
from werkzeug.datastructures import ImmutableMultiDict

Expand All @@ -24,28 +25,33 @@ class DjangoOpenAPIRequest:

path_regex = re.compile(PATH_PARAMETER_PATTERN)

def __init__(self, request):
def __init__(self, request: HttpRequest):
self.request = request

self.parameters = RequestParameters(
path=self.request.resolver_match
path = (
self.request.resolver_match
and self.request.resolver_match.kwargs
or {},
or {}
)
self.parameters = RequestParameters(
path=path,
query=ImmutableMultiDict(self.request.GET),
header=Headers(self.request.headers.items()),
cookie=ImmutableMultiDict(dict(self.request.COOKIES)),
)

@property
def host_url(self):
def host_url(self) -> str:
assert isinstance(self.request._current_scheme_host, str)
return self.request._current_scheme_host

@property
def path(self):
def path(self) -> str:
assert isinstance(self.request.path, str)
return self.request.path

@property
def path_pattern(self):
def path_pattern(self) -> Optional[str]:
if self.request.resolver_match is None:
return None

Expand All @@ -58,13 +64,17 @@ def path_pattern(self):
return "/" + route

@property
def method(self):
def method(self) -> str:
if self.request.method is None:
return ""
assert isinstance(self.request.method, str)
return self.request.method.lower()

@property
def body(self):
return self.request.body
def body(self) -> str:
assert isinstance(self.request.body, bytes)
return self.request.body.decode("utf-8")

@property
def mimetype(self):
return self.request.content_type
def mimetype(self) -> str:
return self.request.content_type or ""
19 changes: 12 additions & 7 deletions openapi_core/contrib/django/responses.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
"""OpenAPI core contrib django responses module"""
from django.http.response import HttpResponse
from werkzeug.datastructures import Headers


class DjangoOpenAPIResponse:
def __init__(self, response):
def __init__(self, response: HttpResponse):
self.response = response

@property
def data(self):
return self.response.content
def data(self) -> str:
assert isinstance(self.response.content, bytes)
return self.response.content.decode("utf-8")

@property
def status_code(self):
def status_code(self) -> int:
assert isinstance(self.response.status_code, int)
return self.response.status_code

@property
def headers(self):
def headers(self) -> Headers:
return Headers(self.response.headers.items())

@property
def mimetype(self):
return self.response["Content-Type"]
def mimetype(self) -> str:
content_type = self.response.get("Content-Type", "")
assert isinstance(content_type, str)
return content_type
Loading