diff --git a/openapi_core/unmarshalling/unmarshallers.py b/openapi_core/unmarshalling/unmarshallers.py index ddc8b891..984b9ea1 100644 --- a/openapi_core/unmarshalling/unmarshallers.py +++ b/openapi_core/unmarshalling/unmarshallers.py @@ -60,7 +60,8 @@ def __init__( schema_validators_factory = ( schema_unmarshallers_factory.schema_validators_factory ) - super().__init__( + BaseValidator.__init__( + self, spec, base_url=base_url, style_deserializers_factory=style_deserializers_factory, diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index 34e23ecd..f2e1ae95 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -84,7 +84,9 @@ def __init__( ] = None, security_provider_factory: SecurityProviderFactory = security_provider_factory, ): - super().__init__( + + BaseValidator.__init__( + self, spec, base_url=base_url, style_deserializers_factory=style_deserializers_factory, diff --git a/tests/unit/unmarshalling/test_request_unmarshallers.py b/tests/unit/unmarshalling/test_request_unmarshallers.py new file mode 100644 index 00000000..a407d567 --- /dev/null +++ b/tests/unit/unmarshalling/test_request_unmarshallers.py @@ -0,0 +1,136 @@ +import enum + +import pytest +from jsonschema_path import SchemaPath + +from openapi_core import V30RequestUnmarshaller +from openapi_core import V31RequestUnmarshaller +from openapi_core.datatypes import Parameters +from openapi_core.testing import MockRequest + + +class Colors(enum.Enum): + + YELLOW = "yellow" + BLUE = "blue" + RED = "red" + + @classmethod + def of(cls, v: str): + for it in cls: + if it.value == v: + return it + raise ValueError(f"Invalid value: {v}") + + +class TestRequestUnmarshaller: + + @pytest.fixture(scope="session") + def spec_dict(self): + return { + "openapi": "3.1.0", + "info": { + "title": "Test request body unmarshaller", + "version": "0.1", + }, + "paths": { + "/resources": { + "post": { + "description": "POST resources test request", + "requestBody": { + "description": "", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/createResource" + } + } + }, + }, + "responses": { + "201": {"description": "Resource was created."} + }, + }, + "get": { + "description": "POST resources test request", + "parameters": [ + { + "name": "color", + "in": "query", + "required": False, + "schema": { + "$ref": "#/components/schemas/colors" + }, + }, + ], + "responses": { + "default": { + "description": "Returned resources matching request." + } + }, + }, + } + }, + "components": { + "schemas": { + "colors": { + "type": "string", + "enum": ["yellow", "blue", "red"], + "format": "enum_Colors", + }, + "createResource": { + "type": "object", + "properties": { + "resId": {"type": "integer"}, + "color": {"$ref": "#/components/schemas/colors"}, + }, + "required": ["resId", "color"], + }, + } + }, + } + + @pytest.fixture(scope="session") + def spec(self, spec_dict): + return SchemaPath.from_dict(spec_dict) + + @pytest.mark.parametrize( + "req_unmarshaller_cls", + [V30RequestUnmarshaller, V31RequestUnmarshaller], + ) + def test_request_body_extra_unmarshaller(self, spec, req_unmarshaller_cls): + ru = req_unmarshaller_cls( + spec=spec, extra_format_unmarshallers={"enum_Colors": Colors.of} + ) + request = MockRequest( + host_url="http://example.com", + method="post", + path="/resources", + data=b'{"resId": 23498572, "color": "blue"}', + ) + result = ru.unmarshal(request) + + assert not result.errors + assert result.body == {"resId": 23498572, "color": Colors.BLUE} + assert result.parameters == Parameters() + + @pytest.mark.parametrize( + "req_unmarshaller_cls", + [V30RequestUnmarshaller, V31RequestUnmarshaller], + ) + def test_request_param_extra_unmarshaller( + self, spec, req_unmarshaller_cls + ): + ru = req_unmarshaller_cls( + spec=spec, extra_format_unmarshallers={"enum_Colors": Colors.of} + ) + request = MockRequest( + host_url="http://example.com", + method="get", + path="/resources", + args={"color": "blue"}, + ) + result = ru.unmarshal(request) + + assert not result.errors + assert result.parameters == Parameters(query=dict(color=Colors.BLUE))