Skip to content

Unmarshaller format refactor #434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/customizations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Here's how you could add support for a ``usdate`` format that handles dates of t
def validate(self, value) -> bool:
return bool(re.match(r"^\d{1,2}/\d{1,2}/\d{4}$", value))

def unmarshal(self, value):
def format(self, value):
return datetime.strptime(value, "%m/%d/%y").date


Expand Down
8 changes: 0 additions & 8 deletions openapi_core/unmarshalling/schemas/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,6 @@ def create(
klass = self.UNMARSHALLERS[schema_type]
return klass(schema, validator, formatter)

def get_formatter(
self, type_format: str, default_formatters: FormattersDict
) -> Optional[Formatter]:
try:
return self.custom_formatters[type_format]
except KeyError:
return default_formatters.get(type_format)

def get_validator(self, schema: Spec) -> Validator:
resolver = schema.accessor.resolver # type: ignore
custom_format_checks = {
Expand Down
40 changes: 35 additions & 5 deletions openapi_core/unmarshalling/schemas/formatters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Any
from typing import Callable
from typing import Optional
Expand All @@ -8,20 +9,49 @@ class Formatter:
def validate(self, value: Any) -> bool:
return True

def unmarshal(self, value: Any) -> Any:
def format(self, value: Any) -> Any:
return value

def __getattribute__(self, name: str) -> Any:
if name == "unmarshal":
warnings.warn(
"Unmarshal method is deprecated. " "Use format instead.",
DeprecationWarning,
)
return super().__getattribute__("format")
if name == "format":
try:
attr = super().__getattribute__("unmarshal")
except AttributeError:
return super().__getattribute__("format")
else:
warnings.warn(
"Unmarshal method is deprecated. "
"Rename unmarshal method to format instead.",
DeprecationWarning,
)
return attr
return super().__getattribute__(name)

@classmethod
def from_callables(
cls,
validate: Optional[Callable[[Any], Any]] = None,
validate_callable: Optional[Callable[[Any], Any]] = None,
format_callable: Optional[Callable[[Any], Any]] = None,
unmarshal: Optional[Callable[[Any], Any]] = None,
) -> "Formatter":
attrs = {}
if validate is not None:
attrs["validate"] = staticmethod(validate)
if validate_callable is not None:
attrs["validate"] = staticmethod(validate_callable)
if format_callable is not None:
attrs["format"] = staticmethod(format_callable)
if unmarshal is not None:
attrs["unmarshal"] = staticmethod(unmarshal)
warnings.warn(
"Unmarshal parameter is deprecated. "
"Use format_callable instead.",
DeprecationWarning,
)
attrs["format"] = staticmethod(unmarshal)

klass: Type[Formatter] = type("Formatter", (cls,), attrs)
return klass()
107 changes: 54 additions & 53 deletions openapi_core/unmarshalling/schemas/unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import cast
Expand Down Expand Up @@ -31,6 +32,7 @@
)
from openapi_core.unmarshalling.schemas.exceptions import InvalidSchemaValue
from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError
from openapi_core.unmarshalling.schemas.exceptions import UnmarshallerError
from openapi_core.unmarshalling.schemas.exceptions import ValidateError
from openapi_core.unmarshalling.schemas.formatters import Formatter
from openapi_core.unmarshalling.schemas.util import format_byte
Expand Down Expand Up @@ -61,24 +63,25 @@ def __init__(
):
self.schema = schema
self.validator = validator
self.format = schema.getkey("format")
self.schema_format = schema.getkey("format")

if formatter is None:
if self.format not in self.FORMATTERS:
raise FormatterNotFoundError(self.format)
self.formatter = self.FORMATTERS[self.format]
if self.schema_format not in self.FORMATTERS:
raise FormatterNotFoundError(self.schema_format)
self.formatter = self.FORMATTERS[self.schema_format]
else:
self.formatter = formatter

def __call__(self, value: Any) -> Any:
if value is None:
return

self.validate(value)

# skip unmarshalling for nullable in OpenAPI 3.0
if value is None and self.schema.getkey("nullable", False):
return value

return self.unmarshal(value)

def _formatter_validate(self, value: Any) -> None:
def _validate_format(self, value: Any) -> None:
result = self.formatter.validate(value)
if not result:
schema_type = self.schema.getkey("type", "any")
Expand All @@ -91,11 +94,14 @@ def validate(self, value: Any) -> None:
schema_type = self.schema.getkey("type", "any")
raise InvalidSchemaValue(value, schema_type, schema_errors=errors)

def unmarshal(self, value: Any) -> Any:
def format(self, value: Any) -> Any:
try:
return self.formatter.unmarshal(value)
except ValueError as exc:
raise InvalidSchemaFormatValue(value, self.format, exc)
return self.formatter.format(value)
except (ValueError, TypeError) as exc:
raise InvalidSchemaFormatValue(value, self.schema_format, exc)

def unmarshal(self, value: Any) -> Any:
return self.format(value)


class StringUnmarshaller(BaseSchemaUnmarshaller):
Expand Down Expand Up @@ -192,10 +198,8 @@ def items_unmarshaller(self) -> "BaseSchemaUnmarshaller":
items_schema = self.schema.get("items", Spec.from_dict({}))
return self.unmarshallers_factory.create(items_schema)

def __call__(self, value: Any) -> Optional[List[Any]]:
value = super().__call__(value)
if value is None and self.schema.getkey("nullable", False):
return None
def unmarshal(self, value: Any) -> Optional[List[Any]]:
value = super().unmarshal(value)
return list(map(self.items_unmarshaller, value))


Expand All @@ -210,38 +214,31 @@ def object_class_factory(self) -> ModelPathFactory:
return ModelPathFactory()

def unmarshal(self, value: Any) -> Any:
properties = self.unmarshal_raw(value)
properties = self.format(value)

fields: Iterable[str] = properties and properties.keys() or []
object_class = self.object_class_factory.create(self.schema, fields)

return object_class(**properties)

def unmarshal_raw(self, value: Any) -> Any:
try:
value = self.formatter.unmarshal(value)
except ValueError as exc:
schema_format = self.schema.getkey("format")
raise InvalidSchemaFormatValue(value, schema_format, exc)
else:
return self._unmarshal_object(value)
def format(self, value: Any) -> Any:
formatted = super().format(value)
return self._unmarshal_properties(formatted)

def _clone(self, schema: Spec) -> "ObjectUnmarshaller":
return cast(
"ObjectUnmarshaller",
self.unmarshallers_factory.create(schema, "object"),
)

def _unmarshal_object(self, value: Any) -> Any:
def _unmarshal_properties(self, value: Any) -> Any:
properties = {}

if "oneOf" in self.schema:
one_of_properties = None
for one_of_schema in self.schema / "oneOf":
try:
unmarshalled = self._clone(one_of_schema).unmarshal_raw(
value
)
unmarshalled = self._clone(one_of_schema).format(value)
except (UnmarshalError, ValueError):
pass
else:
Expand All @@ -259,9 +256,7 @@ def _unmarshal_object(self, value: Any) -> Any:
any_of_properties = None
for any_of_schema in self.schema / "anyOf":
try:
unmarshalled = self._clone(any_of_schema).unmarshal_raw(
value
)
unmarshalled = self._clone(any_of_schema).format(value)
except (UnmarshalError, ValueError):
pass
else:
Expand Down Expand Up @@ -319,21 +314,36 @@ def types_unmarshallers(self) -> List["BaseSchemaUnmarshaller"]:
unmarshaller = partial(self.unmarshallers_factory.create, self.schema)
return list(map(unmarshaller, types))

def unmarshal(self, value: Any) -> Any:
for unmarshaller in self.types_unmarshallers:
@property
def type(self) -> List[str]:
types = self.schema.getkey("type", ["any"])
assert isinstance(types, list)
return types

def _get_unmarshallers_iter(self) -> Iterator["BaseSchemaUnmarshaller"]:
for schema_type in self.type:
yield self.unmarshallers_factory.create(
self.schema, type_override=schema_type
)

def _get_best_unmarshaller(self, value: Any) -> "BaseSchemaUnmarshaller":
for unmarshaller in self._get_unmarshallers_iter():
# validate with validator of formatter (usualy type validator)
try:
unmarshaller._formatter_validate(value)
unmarshaller._validate_format(value)
except ValidateError:
continue
else:
return unmarshaller(value)
return unmarshaller

log.warning("failed to unmarshal multi type")
return value
raise UnmarshallerError("Unmarshaller not found for type(s)")

def unmarshal(self, value: Any) -> Any:
unmarshaller = self._get_best_unmarshaller(value)
return unmarshaller(value)


class AnyUnmarshaller(ComplexUnmarshaller):
class AnyUnmarshaller(MultiTypeUnmarshaller):

SCHEMA_TYPES_ORDER = [
"object",
Expand All @@ -344,6 +354,10 @@ class AnyUnmarshaller(ComplexUnmarshaller):
"string",
]

@property
def type(self) -> List[str]:
return self.SCHEMA_TYPES_ORDER

def unmarshal(self, value: Any) -> Any:
one_of_schema = self._get_one_of_schema(value)
if one_of_schema:
Expand All @@ -357,20 +371,7 @@ def unmarshal(self, value: Any) -> Any:
if all_of_schema:
return self.unmarshallers_factory.create(all_of_schema)(value)

for schema_type in self.SCHEMA_TYPES_ORDER:
unmarshaller = self.unmarshallers_factory.create(
self.schema, type_override=schema_type
)
# validate with validator of formatter (usualy type validator)
try:
unmarshaller._formatter_validate(value)
except ValidateError:
continue
else:
return unmarshaller(value)

log.warning("failed to unmarshal any type")
return value
return super().unmarshal(value)

def _get_one_of_schema(self, value: Any) -> Optional[Spec]:
if "oneOf" not in self.schema:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/unmarshalling/test_unmarshal.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def test_array_null(self, unmarshaller_factory):
spec = Spec.from_dict(schema)
value = None

with pytest.raises(TypeError):
with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(spec)(value)

def test_array_nullable(self, unmarshaller_factory):
Expand Down