|
1 | 1 | import warnings
|
2 | 2 | from typing import Any
|
| 3 | +from typing import Dict |
3 | 4 | from typing import Optional
|
| 5 | +from typing import cast |
4 | 6 | from xml.etree.ElementTree import ParseError
|
5 | 7 |
|
| 8 | +from jsonschema_path import SchemaPath |
| 9 | + |
6 | 10 | from openapi_core.deserializing.media_types.datatypes import (
|
7 | 11 | DeserializerCallable,
|
8 | 12 | )
|
| 13 | +from openapi_core.deserializing.media_types.datatypes import ( |
| 14 | + MediaTypeDeserializersDict, |
| 15 | +) |
9 | 16 | from openapi_core.deserializing.media_types.exceptions import (
|
10 | 17 | MediaTypeDeserializeError,
|
11 | 18 | )
|
| 19 | +from openapi_core.schema.encodings import get_encoding_default_content_type |
| 20 | + |
| 21 | + |
| 22 | +class ContentTypesDeserializer: |
| 23 | + def __init__( |
| 24 | + self, |
| 25 | + media_type_deserializers: Optional[MediaTypeDeserializersDict] = None, |
| 26 | + extra_media_type_deserializers: Optional[ |
| 27 | + MediaTypeDeserializersDict |
| 28 | + ] = None, |
| 29 | + ): |
| 30 | + if media_type_deserializers is None: |
| 31 | + media_type_deserializers = {} |
| 32 | + self.media_type_deserializers = media_type_deserializers |
| 33 | + if extra_media_type_deserializers is None: |
| 34 | + extra_media_type_deserializers = {} |
| 35 | + self.extra_media_type_deserializers = extra_media_type_deserializers |
| 36 | + |
| 37 | + def deserialize(self, mimetype: str, value: Any, **parameters: str) -> Any: |
| 38 | + deserializer_callable = self.get_deserializer_callable(mimetype) |
| 39 | + if deserializer_callable is None: |
| 40 | + warnings.warn(f"Unsupported {mimetype} mimetype") |
| 41 | + return value |
| 42 | + |
| 43 | + try: |
| 44 | + return deserializer_callable(value, **parameters) |
| 45 | + except (ParseError, ValueError, TypeError, AttributeError): |
| 46 | + raise MediaTypeDeserializeError(mimetype, value) |
| 47 | + |
| 48 | + def get_deserializer_callable( |
| 49 | + self, |
| 50 | + mimetype: str, |
| 51 | + ) -> Optional[DeserializerCallable]: |
| 52 | + if mimetype in self.extra_media_type_deserializers: |
| 53 | + return self.extra_media_type_deserializers[mimetype] |
| 54 | + return self.media_type_deserializers.get(mimetype) |
12 | 55 |
|
13 | 56 |
|
14 |
| -class CallableMediaTypeDeserializer: |
| 57 | +class MediaTypeDeserializer: |
15 | 58 | def __init__(
|
16 | 59 | self,
|
17 | 60 | mimetype: str,
|
18 |
| - deserializer_callable: Optional[DeserializerCallable] = None, |
| 61 | + content_types_deserializers: ContentTypesDeserializer, |
| 62 | + schema: Optional[SchemaPath] = None, |
| 63 | + encoding: Optional[SchemaPath] = None, |
19 | 64 | **parameters: str,
|
20 | 65 | ):
|
| 66 | + self.schema = schema |
21 | 67 | self.mimetype = mimetype
|
22 |
| - self.deserializer_callable = deserializer_callable |
| 68 | + self.content_types_deserializers = content_types_deserializers |
| 69 | + self.encoding = encoding |
23 | 70 | self.parameters = parameters
|
24 | 71 |
|
25 | 72 | def deserialize(self, value: Any) -> Any:
|
26 |
| - if self.deserializer_callable is None: |
27 |
| - warnings.warn(f"Unsupported {self.mimetype} mimetype") |
28 |
| - return value |
| 73 | + deserialized = self.content_types_deserializers.deserialize( |
| 74 | + self.mimetype, value, **self.parameters |
| 75 | + ) |
29 | 76 |
|
30 |
| - try: |
31 |
| - return self.deserializer_callable(value, **self.parameters) |
32 |
| - except (ParseError, ValueError, TypeError, AttributeError): |
33 |
| - raise MediaTypeDeserializeError(self.mimetype, value) |
| 77 | + if ( |
| 78 | + self.mimetype != "application/x-www-form-urlencoded" |
| 79 | + and not self.mimetype.startswith("multipart") |
| 80 | + ): |
| 81 | + return deserialized |
| 82 | + |
| 83 | + return self.decode(deserialized) |
| 84 | + |
| 85 | + def evolve( |
| 86 | + self, mimetype: str, schema: Optional[SchemaPath] |
| 87 | + ) -> "MediaTypeDeserializer": |
| 88 | + cls = self.__class__ |
| 89 | + |
| 90 | + return cls( |
| 91 | + mimetype, |
| 92 | + self.content_types_deserializers, |
| 93 | + schema=schema, |
| 94 | + ) |
| 95 | + |
| 96 | + def decode(self, value: Dict[str, Any]) -> Dict[str, Any]: |
| 97 | + return { |
| 98 | + prop_name: self.decode_property(prop_name, prop_value) |
| 99 | + for prop_name, prop_value in value.items() |
| 100 | + } |
| 101 | + |
| 102 | + def decode_property(self, prop_name: str, value: Any) -> Any: |
| 103 | + # schema is required for multipart |
| 104 | + assert self.schema |
| 105 | + schema_props = self.schema.get("properties") |
| 106 | + prop_schema = None |
| 107 | + if schema_props is not None and prop_name in schema_props: |
| 108 | + prop_schema = cast( |
| 109 | + Optional[SchemaPath], |
| 110 | + schema_props.get(prop_name), |
| 111 | + ) |
| 112 | + prop_content_type = self.get_property_content_type( |
| 113 | + prop_name, prop_schema |
| 114 | + ) |
| 115 | + prop_deserializer = self.evolve( |
| 116 | + prop_content_type, |
| 117 | + prop_schema, |
| 118 | + ) |
| 119 | + return prop_deserializer.deserialize(value) |
| 120 | + |
| 121 | + def get_property_content_type( |
| 122 | + self, prop_name: str, prop_schema: Optional[SchemaPath] = None |
| 123 | + ) -> str: |
| 124 | + if self.encoding is None: |
| 125 | + return get_encoding_default_content_type(prop_schema) |
| 126 | + |
| 127 | + if prop_name not in self.encoding: |
| 128 | + return get_encoding_default_content_type(prop_schema) |
| 129 | + |
| 130 | + prep_encoding = self.encoding.get(prop_name) |
| 131 | + prop_content_type = prep_encoding.getkey("contentType") |
| 132 | + if prop_content_type is None: |
| 133 | + return get_encoding_default_content_type(prop_schema) |
| 134 | + |
| 135 | + return cast(str, prop_content_type) |
0 commit comments