Skip to content

Make x-model extension optional #431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ repos:
- id: check-hooks-apply

- repo: https://github.com/asottile/pyupgrade
rev: v2.19.0
rev: v2.38.4
hooks:
- id: pyupgrade
args: ["--py36-plus"]
Expand Down
24 changes: 22 additions & 2 deletions docs/extensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,35 @@ Extensions
x-model
-------

By default, objects are unmarshalled to dynamically created dataclasses. You can use your own dataclasses, pydantic models or models generated by third party generators (i.e. `datamodel-code-generator <https://github.com/koxudaxi/datamodel-code-generator>`__) by providing ``x-model`` property inside schema definition with location of your class.
By default, objects are unmarshalled to dictionaries. You can use dynamically created dataclasses.

.. code-block:: yaml

...
components:
schemas:
Coordinates:
x-model: foo.bar.Coordinates
x-model: Coordinates
type: object
required:
- lat
- lon
properties:
lat:
type: number
lon:
type: number


You can use your own dataclasses, pydantic models or models generated by third party generators (i.e. `datamodel-code-generator <https://github.com/koxudaxi/datamodel-code-generator>`__) by providing ``x-model-path`` property inside schema definition with location of your class.

.. code-block:: yaml

...
components:
schemas:
Coordinates:
x-model-path: foo.bar.Coordinates
type: object
required:
- lat
Expand Down
35 changes: 16 additions & 19 deletions openapi_core/extensions/models/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,43 +9,40 @@
from typing import Type

from openapi_core.extensions.models.types import Field
from openapi_core.spec import Spec


class DictFactory:

base_class = dict

def create(self, fields: Iterable[Field]) -> Type[Dict[Any, Any]]:
def create(
self, schema: Spec, fields: Iterable[Field]
) -> Type[Dict[Any, Any]]:
return self.base_class


class DataClassFactory(DictFactory):
class ModelFactory(DictFactory):
def create(
self,
schema: Spec,
fields: Iterable[Field],
name: str = "Model",
) -> Type[Any]:
name = schema.getkey("x-model")
if name is None:
return super().create(schema, fields)

return make_dataclass(name, fields, frozen=True)


class ModelClassImporter(DataClassFactory):
class ModelPathFactory(ModelFactory):
def create(
self,
schema: Spec,
fields: Iterable[Field],
name: str = "Model",
model: Optional[str] = None,
) -> Any:
if model is None:
return super().create(fields, name=name)

model_class = self._get_class(model)
if model_class is not None:
return model_class

return super().create(fields, name=model)
model_class_path = schema.getkey("x-model-path")
if model_class_path is None:
return super().create(schema, fields)

def _get_class(self, model_class_path: str) -> Optional[object]:
try:
return locate(model_class_path)
except ErrorDuringImport:
return None
return locate(model_class_path)
9 changes: 4 additions & 5 deletions openapi_core/unmarshalling/schemas/unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from openapi_schema_validator._format import oas30_format_checker
from openapi_schema_validator._types import is_string

from openapi_core.extensions.models.factories import ModelClassImporter
from openapi_core.extensions.models.factories import ModelPathFactory
from openapi_core.schema.schemas import get_all_properties
from openapi_core.spec import Spec
from openapi_core.unmarshalling.schemas.datatypes import FormattersDict
Expand Down Expand Up @@ -199,15 +199,14 @@ class ObjectUnmarshaller(ComplexUnmarshaller):
}

@property
def object_class_factory(self) -> ModelClassImporter:
return ModelClassImporter()
def object_class_factory(self) -> ModelPathFactory:
return ModelPathFactory()

def unmarshal(self, value: Any) -> Any:
properties = self.unmarshal_raw(value)

model = self.schema.getkey("x-model")
fields: Iterable[str] = properties and properties.keys() or []
object_class = self.object_class_factory.create(fields, model=model)
object_class = self.object_class_factory.create(self.schema, fields)

return object_class(**properties)

Expand Down
2 changes: 2 additions & 0 deletions tests/integration/data/v3.0/petstore.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ paths:
components:
schemas:
Coordinates:
x-model: Coordinates
type: object
required:
- lat
Expand All @@ -243,6 +244,7 @@ components:
lon:
type: number
Userdata:
x-model: Userdata
type: object
required:
- name
Expand Down
1 change: 1 addition & 0 deletions tests/integration/data/v3.0/read_only_write_only.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ paths:
components:
schemas:
User:
x-model: User
type: object
required:
- id
Expand Down
5 changes: 3 additions & 2 deletions tests/integration/validation/test_petstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ def test_get_pets_param_coordinates(self, spec):
assert is_dataclass(result.parameters.query["coordinates"])
assert (
result.parameters.query["coordinates"].__class__.__name__
== "Model"
== "Coordinates"
)
assert result.parameters.query["coordinates"].lat == coordinates["lat"]
assert result.parameters.query["coordinates"].lon == coordinates["lon"]
Expand Down Expand Up @@ -705,7 +705,8 @@ def test_post_birds(self, spec, spec_dict):

assert is_dataclass(result.parameters.cookie["userdata"])
assert (
result.parameters.cookie["userdata"].__class__.__name__ == "Model"
result.parameters.cookie["userdata"].__class__.__name__
== "Userdata"
)
assert result.parameters.cookie["userdata"].name == "user1"

Expand Down
4 changes: 2 additions & 2 deletions tests/integration/validation/test_read_only_write_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_read_only_property_response(self, spec):

assert not result.errors
assert is_dataclass(result.data)
assert result.data.__class__.__name__ == "Model"
assert result.data.__class__.__name__ == "User"
assert result.data.id == 10
assert result.data.name == "Pedro"

Expand All @@ -73,7 +73,7 @@ def test_write_only_property(self, spec):

assert not result.errors
assert is_dataclass(result.body)
assert result.body.__class__.__name__ == "Model"
assert result.body.__class__.__name__ == "User"
assert result.body.name == "Pedro"
assert result.body.hidden == False

Expand Down
1 change: 1 addition & 0 deletions tests/integration/validation/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,7 @@ def test_request_object_deep_object_params(self, spec, spec_dict):
"in": "query",
"required": True,
"schema": {
"x-model": "paramObj",
"type": "object",
"properties": {
"count": {"type": "integer"},
Expand Down
15 changes: 9 additions & 6 deletions tests/unit/extensions/test_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import pytest

from openapi_core.extensions.models.factories import ModelClassImporter
from openapi_core.extensions.models.factories import ModelPathFactory
from openapi_core.spec import Spec


class TestImportModelCreate:
Expand All @@ -24,18 +25,20 @@ class BarModel:
del modules["foo"]

def test_dynamic_model(self):
factory = ModelClassImporter()
factory = ModelPathFactory()

test_model_class = factory.create(["name"], model="TestModel")
schema = Spec.from_dict({"x-model": "TestModel"})
test_model_class = factory.create(schema, ["name"])

assert is_dataclass(test_model_class)
assert test_model_class.__name__ == "TestModel"
assert list(test_model_class.__dataclass_fields__.keys()) == ["name"]
assert test_model_class.__dataclass_fields__["name"].type == str(Any)

def test_imported_model(self, loaded_model_class):
factory = ModelClassImporter()
def test_model_path(self, loaded_model_class):
factory = ModelPathFactory()

test_model_class = factory.create(["a", "b"], model="foo.BarModel")
schema = Spec.from_dict({"x-model-path": "foo.BarModel"})
test_model_class = factory.create(schema, ["a", "b"])

assert test_model_class == loaded_model_class
35 changes: 16 additions & 19 deletions tests/unit/unmarshalling/test_unmarshal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import datetime
import uuid
from dataclasses import is_dataclass

import pytest
from isodate.tzinfo import UTC
Expand Down Expand Up @@ -540,8 +539,9 @@ def test_object_nullable(self, unmarshaller_factory):
value = {"foo": None}
result = unmarshaller_factory(spec)(value)

assert is_dataclass(result)
assert result.foo == None
assert result == {
"foo": None,
}

def test_schema_any_one_of(self, unmarshaller_factory):
schema = {
Expand Down Expand Up @@ -596,8 +596,9 @@ def test_schema_object_any_of(self, unmarshaller_factory):
spec = Spec.from_dict(schema)
result = unmarshaller_factory(spec)({"someint": 1})

assert is_dataclass(result)
assert result.someint == 1
assert result == {
"someint": 1,
}

def test_schema_object_any_of_invalid(self, unmarshaller_factory):
schema = {
Expand Down Expand Up @@ -728,14 +729,7 @@ def test_schema_free_form_object(

result = unmarshaller_factory(spec)(value)

assert is_dataclass(result)
for field, val in value.items():
result_field = getattr(result, field)
if isinstance(val, dict):
for field2, val2 in val.items():
assert getattr(result_field, field2) == val2
else:
assert result_field == val
assert result == value

def test_read_only_properties(self, unmarshaller_factory):
schema = {
Expand All @@ -755,8 +749,9 @@ def test_read_only_properties(self, unmarshaller_factory):
{"id": 10}
)

assert is_dataclass(result)
assert result.id == 10
assert result == {
"id": 10,
}

def test_read_only_properties_invalid(self, unmarshaller_factory):
schema = {
Expand Down Expand Up @@ -795,8 +790,9 @@ def test_write_only_properties(self, unmarshaller_factory):
{"id": 10}
)

assert is_dataclass(result)
assert result.id == 10
assert result == {
"id": 10,
}

def test_write_only_properties_invalid(self, unmarshaller_factory):
schema = {
Expand Down Expand Up @@ -825,5 +821,6 @@ def test_additional_properties_list(self, unmarshaller_factory):
{"user_ids": [1, 2, 3, 4]}
)

assert is_dataclass(result)
assert result.user_ids == [1, 2, 3, 4]
assert result == {
"user_ids": [1, 2, 3, 4],
}