diff --git a/docs/customizations.rst b/docs/customizations.rst index 39276715..f3e259cd 100644 --- a/docs/customizations.rst +++ b/docs/customizations.rst @@ -61,7 +61,7 @@ Here's how you could add support for a ``usdate`` format that handles dates of t def validate(self, value) -> bool: return bool(re.match(r"^\d{1,2}/\d{1,2}/\d{4}$", value)) - def unmarshal(self, value): + def format(self, value): return datetime.strptime(value, "%m/%d/%y").date diff --git a/openapi_core/unmarshalling/schemas/factories.py b/openapi_core/unmarshalling/schemas/factories.py index 66184cba..41e3e3aa 100644 --- a/openapi_core/unmarshalling/schemas/factories.py +++ b/openapi_core/unmarshalling/schemas/factories.py @@ -108,14 +108,6 @@ def create( klass = self.UNMARSHALLERS[schema_type] return klass(schema, validator, formatter) - def get_formatter( - self, type_format: str, default_formatters: FormattersDict - ) -> Optional[Formatter]: - try: - return self.custom_formatters[type_format] - except KeyError: - return default_formatters.get(type_format) - def get_validator(self, schema: Spec) -> Validator: resolver = schema.accessor.resolver # type: ignore custom_format_checks = { diff --git a/openapi_core/unmarshalling/schemas/formatters.py b/openapi_core/unmarshalling/schemas/formatters.py index 47dd52b8..b0a398f8 100644 --- a/openapi_core/unmarshalling/schemas/formatters.py +++ b/openapi_core/unmarshalling/schemas/formatters.py @@ -1,3 +1,4 @@ +import warnings from typing import Any from typing import Callable from typing import Optional @@ -8,20 +9,49 @@ class Formatter: def validate(self, value: Any) -> bool: return True - def unmarshal(self, value: Any) -> Any: + def format(self, value: Any) -> Any: return value + def __getattribute__(self, name: str) -> Any: + if name == "unmarshal": + warnings.warn( + "Unmarshal method is deprecated. " "Use format instead.", + DeprecationWarning, + ) + return super().__getattribute__("format") + if name == "format": + try: + attr = super().__getattribute__("unmarshal") + except AttributeError: + return super().__getattribute__("format") + else: + warnings.warn( + "Unmarshal method is deprecated. " + "Rename unmarshal method to format instead.", + DeprecationWarning, + ) + return attr + return super().__getattribute__(name) + @classmethod def from_callables( cls, - validate: Optional[Callable[[Any], Any]] = None, + validate_callable: Optional[Callable[[Any], Any]] = None, + format_callable: Optional[Callable[[Any], Any]] = None, unmarshal: Optional[Callable[[Any], Any]] = None, ) -> "Formatter": attrs = {} - if validate is not None: - attrs["validate"] = staticmethod(validate) + if validate_callable is not None: + attrs["validate"] = staticmethod(validate_callable) + if format_callable is not None: + attrs["format"] = staticmethod(format_callable) if unmarshal is not None: - attrs["unmarshal"] = staticmethod(unmarshal) + warnings.warn( + "Unmarshal parameter is deprecated. " + "Use format_callable instead.", + DeprecationWarning, + ) + attrs["format"] = staticmethod(unmarshal) klass: Type[Formatter] = type("Formatter", (cls,), attrs) return klass() diff --git a/openapi_core/unmarshalling/schemas/unmarshallers.py b/openapi_core/unmarshalling/schemas/unmarshallers.py index c2704a5c..941e28cb 100644 --- a/openapi_core/unmarshalling/schemas/unmarshallers.py +++ b/openapi_core/unmarshalling/schemas/unmarshallers.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Iterable +from typing import Iterator from typing import List from typing import Optional from typing import cast @@ -31,6 +32,7 @@ ) from openapi_core.unmarshalling.schemas.exceptions import InvalidSchemaValue from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError +from openapi_core.unmarshalling.schemas.exceptions import UnmarshallerError from openapi_core.unmarshalling.schemas.exceptions import ValidateError from openapi_core.unmarshalling.schemas.formatters import Formatter from openapi_core.unmarshalling.schemas.util import format_byte @@ -61,24 +63,25 @@ def __init__( ): self.schema = schema self.validator = validator - self.format = schema.getkey("format") + self.schema_format = schema.getkey("format") if formatter is None: - if self.format not in self.FORMATTERS: - raise FormatterNotFoundError(self.format) - self.formatter = self.FORMATTERS[self.format] + if self.schema_format not in self.FORMATTERS: + raise FormatterNotFoundError(self.schema_format) + self.formatter = self.FORMATTERS[self.schema_format] else: self.formatter = formatter def __call__(self, value: Any) -> Any: - if value is None: - return - self.validate(value) + # skip unmarshalling for nullable in OpenAPI 3.0 + if value is None and self.schema.getkey("nullable", False): + return value + return self.unmarshal(value) - def _formatter_validate(self, value: Any) -> None: + def _validate_format(self, value: Any) -> None: result = self.formatter.validate(value) if not result: schema_type = self.schema.getkey("type", "any") @@ -91,11 +94,14 @@ def validate(self, value: Any) -> None: schema_type = self.schema.getkey("type", "any") raise InvalidSchemaValue(value, schema_type, schema_errors=errors) - def unmarshal(self, value: Any) -> Any: + def format(self, value: Any) -> Any: try: - return self.formatter.unmarshal(value) - except ValueError as exc: - raise InvalidSchemaFormatValue(value, self.format, exc) + return self.formatter.format(value) + except (ValueError, TypeError) as exc: + raise InvalidSchemaFormatValue(value, self.schema_format, exc) + + def unmarshal(self, value: Any) -> Any: + return self.format(value) class StringUnmarshaller(BaseSchemaUnmarshaller): @@ -192,10 +198,8 @@ def items_unmarshaller(self) -> "BaseSchemaUnmarshaller": items_schema = self.schema.get("items", Spec.from_dict({})) return self.unmarshallers_factory.create(items_schema) - def __call__(self, value: Any) -> Optional[List[Any]]: - value = super().__call__(value) - if value is None and self.schema.getkey("nullable", False): - return None + def unmarshal(self, value: Any) -> Optional[List[Any]]: + value = super().unmarshal(value) return list(map(self.items_unmarshaller, value)) @@ -210,21 +214,16 @@ def object_class_factory(self) -> ModelPathFactory: return ModelPathFactory() def unmarshal(self, value: Any) -> Any: - properties = self.unmarshal_raw(value) + properties = self.format(value) fields: Iterable[str] = properties and properties.keys() or [] object_class = self.object_class_factory.create(self.schema, fields) return object_class(**properties) - def unmarshal_raw(self, value: Any) -> Any: - try: - value = self.formatter.unmarshal(value) - except ValueError as exc: - schema_format = self.schema.getkey("format") - raise InvalidSchemaFormatValue(value, schema_format, exc) - else: - return self._unmarshal_object(value) + def format(self, value: Any) -> Any: + formatted = super().format(value) + return self._unmarshal_properties(formatted) def _clone(self, schema: Spec) -> "ObjectUnmarshaller": return cast( @@ -232,16 +231,14 @@ def _clone(self, schema: Spec) -> "ObjectUnmarshaller": self.unmarshallers_factory.create(schema, "object"), ) - def _unmarshal_object(self, value: Any) -> Any: + def _unmarshal_properties(self, value: Any) -> Any: properties = {} if "oneOf" in self.schema: one_of_properties = None for one_of_schema in self.schema / "oneOf": try: - unmarshalled = self._clone(one_of_schema).unmarshal_raw( - value - ) + unmarshalled = self._clone(one_of_schema).format(value) except (UnmarshalError, ValueError): pass else: @@ -259,9 +256,7 @@ def _unmarshal_object(self, value: Any) -> Any: any_of_properties = None for any_of_schema in self.schema / "anyOf": try: - unmarshalled = self._clone(any_of_schema).unmarshal_raw( - value - ) + unmarshalled = self._clone(any_of_schema).format(value) except (UnmarshalError, ValueError): pass else: @@ -319,21 +314,36 @@ def types_unmarshallers(self) -> List["BaseSchemaUnmarshaller"]: unmarshaller = partial(self.unmarshallers_factory.create, self.schema) return list(map(unmarshaller, types)) - def unmarshal(self, value: Any) -> Any: - for unmarshaller in self.types_unmarshallers: + @property + def type(self) -> List[str]: + types = self.schema.getkey("type", ["any"]) + assert isinstance(types, list) + return types + + def _get_unmarshallers_iter(self) -> Iterator["BaseSchemaUnmarshaller"]: + for schema_type in self.type: + yield self.unmarshallers_factory.create( + self.schema, type_override=schema_type + ) + + def _get_best_unmarshaller(self, value: Any) -> "BaseSchemaUnmarshaller": + for unmarshaller in self._get_unmarshallers_iter(): # validate with validator of formatter (usualy type validator) try: - unmarshaller._formatter_validate(value) + unmarshaller._validate_format(value) except ValidateError: continue else: - return unmarshaller(value) + return unmarshaller - log.warning("failed to unmarshal multi type") - return value + raise UnmarshallerError("Unmarshaller not found for type(s)") + + def unmarshal(self, value: Any) -> Any: + unmarshaller = self._get_best_unmarshaller(value) + return unmarshaller(value) -class AnyUnmarshaller(ComplexUnmarshaller): +class AnyUnmarshaller(MultiTypeUnmarshaller): SCHEMA_TYPES_ORDER = [ "object", @@ -344,6 +354,10 @@ class AnyUnmarshaller(ComplexUnmarshaller): "string", ] + @property + def type(self) -> List[str]: + return self.SCHEMA_TYPES_ORDER + def unmarshal(self, value: Any) -> Any: one_of_schema = self._get_one_of_schema(value) if one_of_schema: @@ -357,20 +371,7 @@ def unmarshal(self, value: Any) -> Any: if all_of_schema: return self.unmarshallers_factory.create(all_of_schema)(value) - for schema_type in self.SCHEMA_TYPES_ORDER: - unmarshaller = self.unmarshallers_factory.create( - self.schema, type_override=schema_type - ) - # validate with validator of formatter (usualy type validator) - try: - unmarshaller._formatter_validate(value) - except ValidateError: - continue - else: - return unmarshaller(value) - - log.warning("failed to unmarshal any type") - return value + return super().unmarshal(value) def _get_one_of_schema(self, value: Any) -> Optional[Spec]: if "oneOf" not in self.schema: diff --git a/tests/unit/unmarshalling/test_unmarshal.py b/tests/unit/unmarshalling/test_unmarshal.py index 3ce50db4..e0ef91c3 100644 --- a/tests/unit/unmarshalling/test_unmarshal.py +++ b/tests/unit/unmarshalling/test_unmarshal.py @@ -406,7 +406,7 @@ def test_array_null(self, unmarshaller_factory): spec = Spec.from_dict(schema) value = None - with pytest.raises(TypeError): + with pytest.raises(InvalidSchemaValue): unmarshaller_factory(spec)(value) def test_array_nullable(self, unmarshaller_factory):