Skip to content

Commit 0d00322

Browse files
authored
Merge pull request #431 from p1c2u/fix/make-x-model-extention-optional
Make x-model extension optional
2 parents ac64879 + 417f76e commit 0d00322

File tree

11 files changed

+77
-56
lines changed

11 files changed

+77
-56
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ repos:
1010
- id: check-hooks-apply
1111

1212
- repo: https://github.com/asottile/pyupgrade
13-
rev: v2.19.0
13+
rev: v2.38.4
1414
hooks:
1515
- id: pyupgrade
1616
args: ["--py36-plus"]

docs/extensions.rst

+22-2
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,35 @@ Extensions
44
x-model
55
-------
66

7-
By default, objects are unmarshalled to dynamically created dataclasses. You can use your own dataclasses, pydantic models or models generated by third party generators (i.e. `datamodel-code-generator <https://github.com/koxudaxi/datamodel-code-generator>`__) by providing ``x-model`` property inside schema definition with location of your class.
7+
By default, objects are unmarshalled to dictionaries. You can use dynamically created dataclasses.
88

99
.. code-block:: yaml
1010
1111
...
1212
components:
1313
schemas:
1414
Coordinates:
15-
x-model: foo.bar.Coordinates
15+
x-model: Coordinates
16+
type: object
17+
required:
18+
- lat
19+
- lon
20+
properties:
21+
lat:
22+
type: number
23+
lon:
24+
type: number
25+
26+
27+
You can use your own dataclasses, pydantic models or models generated by third party generators (i.e. `datamodel-code-generator <https://github.com/koxudaxi/datamodel-code-generator>`__) by providing ``x-model-path`` property inside schema definition with location of your class.
28+
29+
.. code-block:: yaml
30+
31+
...
32+
components:
33+
schemas:
34+
Coordinates:
35+
x-model-path: foo.bar.Coordinates
1636
type: object
1737
required:
1838
- lat

openapi_core/extensions/models/factories.py

+16-19
Original file line numberDiff line numberDiff line change
@@ -9,43 +9,40 @@
99
from typing import Type
1010

1111
from openapi_core.extensions.models.types import Field
12+
from openapi_core.spec import Spec
1213

1314

1415
class DictFactory:
1516

1617
base_class = dict
1718

18-
def create(self, fields: Iterable[Field]) -> Type[Dict[Any, Any]]:
19+
def create(
20+
self, schema: Spec, fields: Iterable[Field]
21+
) -> Type[Dict[Any, Any]]:
1922
return self.base_class
2023

2124

22-
class DataClassFactory(DictFactory):
25+
class ModelFactory(DictFactory):
2326
def create(
2427
self,
28+
schema: Spec,
2529
fields: Iterable[Field],
26-
name: str = "Model",
2730
) -> Type[Any]:
31+
name = schema.getkey("x-model")
32+
if name is None:
33+
return super().create(schema, fields)
34+
2835
return make_dataclass(name, fields, frozen=True)
2936

3037

31-
class ModelClassImporter(DataClassFactory):
38+
class ModelPathFactory(ModelFactory):
3239
def create(
3340
self,
41+
schema: Spec,
3442
fields: Iterable[Field],
35-
name: str = "Model",
36-
model: Optional[str] = None,
3743
) -> Any:
38-
if model is None:
39-
return super().create(fields, name=name)
40-
41-
model_class = self._get_class(model)
42-
if model_class is not None:
43-
return model_class
44-
45-
return super().create(fields, name=model)
44+
model_class_path = schema.getkey("x-model-path")
45+
if model_class_path is None:
46+
return super().create(schema, fields)
4647

47-
def _get_class(self, model_class_path: str) -> Optional[object]:
48-
try:
49-
return locate(model_class_path)
50-
except ErrorDuringImport:
51-
return None
48+
return locate(model_class_path)

openapi_core/unmarshalling/schemas/unmarshallers.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from openapi_schema_validator._format import oas30_format_checker
1919
from openapi_schema_validator._types import is_string
2020

21-
from openapi_core.extensions.models.factories import ModelClassImporter
21+
from openapi_core.extensions.models.factories import ModelPathFactory
2222
from openapi_core.schema.schemas import get_all_properties
2323
from openapi_core.spec import Spec
2424
from openapi_core.unmarshalling.schemas.datatypes import FormattersDict
@@ -199,15 +199,14 @@ class ObjectUnmarshaller(ComplexUnmarshaller):
199199
}
200200

201201
@property
202-
def object_class_factory(self) -> ModelClassImporter:
203-
return ModelClassImporter()
202+
def object_class_factory(self) -> ModelPathFactory:
203+
return ModelPathFactory()
204204

205205
def unmarshal(self, value: Any) -> Any:
206206
properties = self.unmarshal_raw(value)
207207

208-
model = self.schema.getkey("x-model")
209208
fields: Iterable[str] = properties and properties.keys() or []
210-
object_class = self.object_class_factory.create(fields, model=model)
209+
object_class = self.object_class_factory.create(self.schema, fields)
211210

212211
return object_class(**properties)
213212

tests/integration/data/v3.0/petstore.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ paths:
233233
components:
234234
schemas:
235235
Coordinates:
236+
x-model: Coordinates
236237
type: object
237238
required:
238239
- lat
@@ -243,6 +244,7 @@ components:
243244
lon:
244245
type: number
245246
Userdata:
247+
x-model: Userdata
246248
type: object
247249
required:
248250
- name

tests/integration/data/v3.0/read_only_write_only.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ paths:
2323
components:
2424
schemas:
2525
User:
26+
x-model: User
2627
type: object
2728
required:
2829
- id

tests/integration/validation/test_petstore.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ def test_get_pets_param_coordinates(self, spec):
644644
assert is_dataclass(result.parameters.query["coordinates"])
645645
assert (
646646
result.parameters.query["coordinates"].__class__.__name__
647-
== "Model"
647+
== "Coordinates"
648648
)
649649
assert result.parameters.query["coordinates"].lat == coordinates["lat"]
650650
assert result.parameters.query["coordinates"].lon == coordinates["lon"]
@@ -705,7 +705,8 @@ def test_post_birds(self, spec, spec_dict):
705705

706706
assert is_dataclass(result.parameters.cookie["userdata"])
707707
assert (
708-
result.parameters.cookie["userdata"].__class__.__name__ == "Model"
708+
result.parameters.cookie["userdata"].__class__.__name__
709+
== "Userdata"
709710
)
710711
assert result.parameters.cookie["userdata"].name == "user1"
711712

tests/integration/validation/test_read_only_write_only.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_read_only_property_response(self, spec):
5151

5252
assert not result.errors
5353
assert is_dataclass(result.data)
54-
assert result.data.__class__.__name__ == "Model"
54+
assert result.data.__class__.__name__ == "User"
5555
assert result.data.id == 10
5656
assert result.data.name == "Pedro"
5757

@@ -73,7 +73,7 @@ def test_write_only_property(self, spec):
7373

7474
assert not result.errors
7575
assert is_dataclass(result.body)
76-
assert result.body.__class__.__name__ == "Model"
76+
assert result.body.__class__.__name__ == "User"
7777
assert result.body.name == "Pedro"
7878
assert result.body.hidden == False
7979

tests/integration/validation/test_validators.py

+1
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,7 @@ def test_request_object_deep_object_params(self, spec, spec_dict):
536536
"in": "query",
537537
"required": True,
538538
"schema": {
539+
"x-model": "paramObj",
539540
"type": "object",
540541
"properties": {
541542
"count": {"type": "integer"},

tests/unit/extensions/test_factories.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
import pytest
88

9-
from openapi_core.extensions.models.factories import ModelClassImporter
9+
from openapi_core.extensions.models.factories import ModelPathFactory
10+
from openapi_core.spec import Spec
1011

1112

1213
class TestImportModelCreate:
@@ -24,18 +25,20 @@ class BarModel:
2425
del modules["foo"]
2526

2627
def test_dynamic_model(self):
27-
factory = ModelClassImporter()
28+
factory = ModelPathFactory()
2829

29-
test_model_class = factory.create(["name"], model="TestModel")
30+
schema = Spec.from_dict({"x-model": "TestModel"})
31+
test_model_class = factory.create(schema, ["name"])
3032

3133
assert is_dataclass(test_model_class)
3234
assert test_model_class.__name__ == "TestModel"
3335
assert list(test_model_class.__dataclass_fields__.keys()) == ["name"]
3436
assert test_model_class.__dataclass_fields__["name"].type == str(Any)
3537

36-
def test_imported_model(self, loaded_model_class):
37-
factory = ModelClassImporter()
38+
def test_model_path(self, loaded_model_class):
39+
factory = ModelPathFactory()
3840

39-
test_model_class = factory.create(["a", "b"], model="foo.BarModel")
41+
schema = Spec.from_dict({"x-model-path": "foo.BarModel"})
42+
test_model_class = factory.create(schema, ["a", "b"])
4043

4144
assert test_model_class == loaded_model_class

tests/unit/unmarshalling/test_unmarshal.py

+16-19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import datetime
22
import uuid
3-
from dataclasses import is_dataclass
43

54
import pytest
65
from isodate.tzinfo import UTC
@@ -540,8 +539,9 @@ def test_object_nullable(self, unmarshaller_factory):
540539
value = {"foo": None}
541540
result = unmarshaller_factory(spec)(value)
542541

543-
assert is_dataclass(result)
544-
assert result.foo == None
542+
assert result == {
543+
"foo": None,
544+
}
545545

546546
def test_schema_any_one_of(self, unmarshaller_factory):
547547
schema = {
@@ -596,8 +596,9 @@ def test_schema_object_any_of(self, unmarshaller_factory):
596596
spec = Spec.from_dict(schema)
597597
result = unmarshaller_factory(spec)({"someint": 1})
598598

599-
assert is_dataclass(result)
600-
assert result.someint == 1
599+
assert result == {
600+
"someint": 1,
601+
}
601602

602603
def test_schema_object_any_of_invalid(self, unmarshaller_factory):
603604
schema = {
@@ -728,14 +729,7 @@ def test_schema_free_form_object(
728729

729730
result = unmarshaller_factory(spec)(value)
730731

731-
assert is_dataclass(result)
732-
for field, val in value.items():
733-
result_field = getattr(result, field)
734-
if isinstance(val, dict):
735-
for field2, val2 in val.items():
736-
assert getattr(result_field, field2) == val2
737-
else:
738-
assert result_field == val
732+
assert result == value
739733

740734
def test_read_only_properties(self, unmarshaller_factory):
741735
schema = {
@@ -755,8 +749,9 @@ def test_read_only_properties(self, unmarshaller_factory):
755749
{"id": 10}
756750
)
757751

758-
assert is_dataclass(result)
759-
assert result.id == 10
752+
assert result == {
753+
"id": 10,
754+
}
760755

761756
def test_read_only_properties_invalid(self, unmarshaller_factory):
762757
schema = {
@@ -795,8 +790,9 @@ def test_write_only_properties(self, unmarshaller_factory):
795790
{"id": 10}
796791
)
797792

798-
assert is_dataclass(result)
799-
assert result.id == 10
793+
assert result == {
794+
"id": 10,
795+
}
800796

801797
def test_write_only_properties_invalid(self, unmarshaller_factory):
802798
schema = {
@@ -825,5 +821,6 @@ def test_additional_properties_list(self, unmarshaller_factory):
825821
{"user_ids": [1, 2, 3, 4]}
826822
)
827823

828-
assert is_dataclass(result)
829-
assert result.user_ids == [1, 2, 3, 4]
824+
assert result == {
825+
"user_ids": [1, 2, 3, 4],
826+
}

0 commit comments

Comments
 (0)