6
6
from mypy .plugin import (
7
7
Plugin , FunctionContext , MethodContext , MethodSigContext , AttributeContext , ClassDefContext
8
8
)
9
- from mypy .plugins .common import try_getting_str_literal
9
+ from mypy .plugins .common import try_getting_str_literals
10
10
from mypy .types import (
11
11
Type , Instance , AnyType , TypeOfAny , CallableType , NoneTyp , UnionType , TypedDictType ,
12
12
TypeVarType
13
13
)
14
+ from mypy .subtypes import is_subtype
14
15
15
16
16
17
class DefaultPlugin (Plugin ):
@@ -171,26 +172,34 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type:
171
172
if (isinstance (ctx .type , TypedDictType )
172
173
and len (ctx .arg_types ) >= 1
173
174
and len (ctx .arg_types [0 ]) == 1 ):
174
- key = try_getting_str_literal (ctx .args [0 ][0 ], ctx .arg_types [0 ][0 ])
175
- if key is None :
175
+ keys = try_getting_str_literals (ctx .args [0 ][0 ], ctx .arg_types [0 ][0 ])
176
+ if keys is None :
176
177
return ctx .default_return_type
177
178
178
- value_type = ctx .type .items .get (key )
179
- if value_type :
179
+ output_types = []
180
+ for key in keys :
181
+ value_type = ctx .type .items .get (key )
182
+ if value_type is None :
183
+ ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
184
+ return AnyType (TypeOfAny .from_error )
185
+
180
186
if len (ctx .arg_types ) == 1 :
181
- return UnionType . make_simplified_union ([ value_type , NoneTyp ()] )
187
+ output_types . append ( value_type )
182
188
elif (len (ctx .arg_types ) == 2 and len (ctx .arg_types [1 ]) == 1
183
189
and len (ctx .args [1 ]) == 1 ):
184
190
default_arg = ctx .args [1 ][0 ]
185
191
if (isinstance (default_arg , DictExpr ) and len (default_arg .items ) == 0
186
192
and isinstance (value_type , TypedDictType )):
187
193
# Special case '{}' as the default for a typed dict type.
188
- return value_type .copy_modified (required_keys = set ())
194
+ output_types . append ( value_type .copy_modified (required_keys = set () ))
189
195
else :
190
- return UnionType .make_simplified_union ([value_type , ctx .arg_types [1 ][0 ]])
191
- else :
192
- ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
193
- return AnyType (TypeOfAny .from_error )
196
+ output_types .append (value_type )
197
+ output_types .append (ctx .arg_types [1 ][0 ])
198
+
199
+ if len (ctx .arg_types ) == 1 :
200
+ output_types .append (NoneTyp ())
201
+
202
+ return UnionType .make_simplified_union (output_types )
194
203
return ctx .default_return_type
195
204
196
205
@@ -228,23 +237,28 @@ def typed_dict_pop_callback(ctx: MethodContext) -> Type:
228
237
if (isinstance (ctx .type , TypedDictType )
229
238
and len (ctx .arg_types ) >= 1
230
239
and len (ctx .arg_types [0 ]) == 1 ):
231
- key = try_getting_str_literal (ctx .args [0 ][0 ], ctx .arg_types [0 ][0 ])
232
- if key is None :
240
+ keys = try_getting_str_literals (ctx .args [0 ][0 ], ctx .arg_types [0 ][0 ])
241
+ if keys is None :
233
242
ctx .api .fail (message_registry .TYPEDDICT_KEY_MUST_BE_STRING_LITERAL , ctx .context )
234
243
return AnyType (TypeOfAny .from_error )
235
244
236
- if key in ctx .type .required_keys :
237
- ctx .api .msg .typeddict_key_cannot_be_deleted (ctx .type , key , ctx .context )
238
- value_type = ctx .type .items .get (key )
239
- if value_type :
240
- if len (ctx .args [1 ]) == 0 :
241
- return value_type
242
- elif (len (ctx .arg_types ) == 2 and len (ctx .arg_types [1 ]) == 1
243
- and len (ctx .args [1 ]) == 1 ):
244
- return UnionType .make_simplified_union ([value_type , ctx .arg_types [1 ][0 ]])
245
- else :
246
- ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
247
- return AnyType (TypeOfAny .from_error )
245
+ value_types = []
246
+ for key in keys :
247
+ if key in ctx .type .required_keys :
248
+ ctx .api .msg .typeddict_key_cannot_be_deleted (ctx .type , key , ctx .context )
249
+
250
+ value_type = ctx .type .items .get (key )
251
+ if value_type :
252
+ value_types .append (value_type )
253
+ else :
254
+ ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
255
+ return AnyType (TypeOfAny .from_error )
256
+
257
+ if len (ctx .args [1 ]) == 0 :
258
+ return UnionType .make_simplified_union (value_types )
259
+ elif (len (ctx .arg_types ) == 2 and len (ctx .arg_types [1 ]) == 1
260
+ and len (ctx .args [1 ]) == 1 ):
261
+ return UnionType .make_simplified_union ([* value_types , ctx .arg_types [1 ][0 ]])
248
262
return ctx .default_return_type
249
263
250
264
@@ -273,18 +287,35 @@ def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
273
287
"""Type check TypedDict.setdefault and infer a precise return type."""
274
288
if (isinstance (ctx .type , TypedDictType )
275
289
and len (ctx .arg_types ) == 2
276
- and len (ctx .arg_types [0 ]) == 1 ):
277
- key = try_getting_str_literal (ctx .args [0 ][0 ], ctx .arg_types [0 ][0 ])
278
- if key is None :
290
+ and len (ctx .arg_types [0 ]) == 1
291
+ and len (ctx .arg_types [1 ]) == 1 ):
292
+ keys = try_getting_str_literals (ctx .args [0 ][0 ], ctx .arg_types [0 ][0 ])
293
+ if keys is None :
279
294
ctx .api .fail (message_registry .TYPEDDICT_KEY_MUST_BE_STRING_LITERAL , ctx .context )
280
295
return AnyType (TypeOfAny .from_error )
281
296
282
- value_type = ctx .type .items .get (key )
283
- if value_type :
284
- return value_type
285
- else :
286
- ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
287
- return AnyType (TypeOfAny .from_error )
297
+ default_type = ctx .arg_types [1 ][0 ]
298
+
299
+ value_types = []
300
+ for key in keys :
301
+ value_type = ctx .type .items .get (key )
302
+
303
+ if value_type is None :
304
+ ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
305
+ return AnyType (TypeOfAny .from_error )
306
+
307
+ # The signature_callback above can't always infer the right signature
308
+ # (e.g. when the expression is a variable that happens to be a Literal str)
309
+ # so we need to handle the check ourselves here and make sure the provided
310
+ # default can be assigned to all key-value pairs we're updating.
311
+ if not is_subtype (default_type , value_type ):
312
+ ctx .api .msg .typeddict_setdefault_arguments_inconsistent (
313
+ default_type , value_type , ctx .context )
314
+ return AnyType (TypeOfAny .from_error )
315
+
316
+ value_types .append (value_type )
317
+
318
+ return UnionType .make_simplified_union (value_types )
288
319
return ctx .default_return_type
289
320
290
321
@@ -299,15 +330,16 @@ def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
299
330
if (isinstance (ctx .type , TypedDictType )
300
331
and len (ctx .arg_types ) == 1
301
332
and len (ctx .arg_types [0 ]) == 1 ):
302
- key = try_getting_str_literal (ctx .args [0 ][0 ], ctx .arg_types [0 ][0 ])
303
- if key is None :
333
+ keys = try_getting_str_literals (ctx .args [0 ][0 ], ctx .arg_types [0 ][0 ])
334
+ if keys is None :
304
335
ctx .api .fail (message_registry .TYPEDDICT_KEY_MUST_BE_STRING_LITERAL , ctx .context )
305
336
return AnyType (TypeOfAny .from_error )
306
337
307
- if key in ctx .type .required_keys :
308
- ctx .api .msg .typeddict_key_cannot_be_deleted (ctx .type , key , ctx .context )
309
- elif key not in ctx .type .items :
310
- ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
338
+ for key in keys :
339
+ if key in ctx .type .required_keys :
340
+ ctx .api .msg .typeddict_key_cannot_be_deleted (ctx .type , key , ctx .context )
341
+ elif key not in ctx .type .items :
342
+ ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
311
343
return ctx .default_return_type
312
344
313
345
0 commit comments