Skip to content

Commit 7fd6eba

Browse files
davidfstrJukkaL
authored andcommitted
TypedDict: Recognize creation of TypedDict instance. Define TypedDictType. (#2342)
Notable visitor implementations added: * Subtype * Join * Meet * Constraint Solve Also: * Fix support for using dict(...) in TypedDict instance constructor. * Allow instantiation of empty TypedDict. * Disallow underscore prefix on TypedDict item names. * TypeAnalyser: Resolve an unbound reference to a typeddict as a named TypedDictType rather than as an Instance.
1 parent fe1d523 commit 7fd6eba

23 files changed

+903
-69
lines changed

mypy/checkexpr.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Expression type checker. This file is conceptually part of TypeChecker."""
22

3+
from collections import OrderedDict
34
from typing import cast, Dict, Set, List, Iterable, Tuple, Callable, Union, Optional
45

56
from mypy.types import (
67
Type, AnyType, CallableType, Overloaded, NoneTyp, Void, TypeVarDef,
7-
TupleType, Instance, TypeVarId, TypeVarType, ErasedType, UnionType,
8+
TupleType, TypedDictType, Instance, TypeVarId, TypeVarType, ErasedType, UnionType,
89
PartialType, DeletedType, UnboundType, UninhabitedType, TypeType,
910
true_only, false_only, is_named_instance, function_type,
1011
get_typ_args, set_typ_args,
@@ -169,6 +170,10 @@ def visit_call_expr(self, e: CallExpr) -> Type:
169170
if e.analyzed:
170171
# It's really a special form that only looks like a call.
171172
return self.accept(e.analyzed, self.chk.type_context[-1])
173+
if isinstance(e.callee, NameExpr) and isinstance(e.callee.node, TypeInfo) and \
174+
e.callee.node.typeddict_type is not None:
175+
return self.check_typeddict_call(e.callee.node.typeddict_type,
176+
e.arg_kinds, e.arg_names, e.args, e)
172177
self.try_infer_partial_type(e)
173178
callee_type = self.accept(e.callee)
174179
if (self.chk.options.disallow_untyped_calls and
@@ -178,6 +183,80 @@ def visit_call_expr(self, e: CallExpr) -> Type:
178183
return self.msg.untyped_function_call(callee_type, e)
179184
return self.check_call_expr_with_callee_type(callee_type, e)
180185

186+
def check_typeddict_call(self, callee: TypedDictType,
187+
arg_kinds: List[int],
188+
arg_names: List[str],
189+
args: List[Expression],
190+
context: Context) -> Type:
191+
if len(args) >= 1 and all([ak == ARG_NAMED for ak in arg_kinds]):
192+
# ex: Point(x=42, y=1337)
193+
item_names = arg_names
194+
item_args = args
195+
return self.check_typeddict_call_with_kwargs(
196+
callee, OrderedDict(zip(item_names, item_args)), context)
197+
198+
if len(args) == 1 and arg_kinds[0] == ARG_POS:
199+
unique_arg = args[0]
200+
if isinstance(unique_arg, DictExpr):
201+
# ex: Point({'x': 42, 'y': 1337})
202+
return self.check_typeddict_call_with_dict(callee, unique_arg, context)
203+
if isinstance(unique_arg, CallExpr) and isinstance(unique_arg.analyzed, DictExpr):
204+
# ex: Point(dict(x=42, y=1337))
205+
return self.check_typeddict_call_with_dict(callee, unique_arg.analyzed, context)
206+
207+
if len(args) == 0:
208+
# ex: EmptyDict()
209+
return self.check_typeddict_call_with_kwargs(
210+
callee, OrderedDict(), context)
211+
212+
self.chk.fail(messages.INVALID_TYPEDDICT_ARGS, context)
213+
return AnyType()
214+
215+
def check_typeddict_call_with_dict(self, callee: TypedDictType,
216+
kwargs: DictExpr,
217+
context: Context) -> Type:
218+
item_name_exprs = [item[0] for item in kwargs.items]
219+
item_args = [item[1] for item in kwargs.items]
220+
221+
item_names = [] # List[str]
222+
for item_name_expr in item_name_exprs:
223+
if not isinstance(item_name_expr, StrExpr):
224+
self.chk.fail(messages.TYPEDDICT_ITEM_NAME_MUST_BE_STRING_LITERAL, item_name_expr)
225+
return AnyType()
226+
item_names.append(item_name_expr.value)
227+
228+
return self.check_typeddict_call_with_kwargs(
229+
callee, OrderedDict(zip(item_names, item_args)), context)
230+
231+
def check_typeddict_call_with_kwargs(self, callee: TypedDictType,
232+
kwargs: 'OrderedDict[str, Expression]',
233+
context: Context) -> Type:
234+
if callee.items.keys() != kwargs.keys():
235+
callee_item_names = callee.items.keys()
236+
kwargs_item_names = kwargs.keys()
237+
238+
self.msg.typeddict_instantiated_with_unexpected_items(
239+
expected_item_names=list(callee_item_names),
240+
actual_item_names=list(kwargs_item_names),
241+
context=context)
242+
return AnyType()
243+
244+
items = OrderedDict() # type: OrderedDict[str, Type]
245+
for (item_name, item_expected_type) in callee.items.items():
246+
item_value = kwargs[item_name]
247+
248+
item_actual_type = self.chk.check_simple_assignment(
249+
lvalue_type=item_expected_type, rvalue=item_value, context=item_value,
250+
msg=messages.INCOMPATIBLE_TYPES,
251+
lvalue_name='TypedDict item "{}"'.format(item_name),
252+
rvalue_name='expression')
253+
items[item_name] = item_actual_type
254+
255+
mapping_value_type = join.join_type_list(list(items.values()))
256+
fallback = self.chk.named_generic_type('typing.Mapping',
257+
[self.chk.str_type(), mapping_value_type])
258+
return TypedDictType(items, fallback)
259+
181260
# Types and methods that can be used to infer partial types.
182261
item_args = {'builtins.list': ['append'],
183262
'builtins.set': ['add', 'discard'],

mypy/checkmember.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import cast, Callable, List, Optional, TypeVar
44

55
from mypy.types import (
6-
Type, Instance, AnyType, TupleType, CallableType, FunctionLike, TypeVarDef,
6+
Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarDef,
77
Overloaded, TypeVarType, UnionType, PartialType,
88
DeletedType, NoneTyp, TypeType, function_type
99
)
@@ -116,6 +116,11 @@ def analyze_member_access(name: str,
116116
return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
117117
is_operator, builtin_type, not_ready_callback, msg,
118118
original_type=original_type, chk=chk)
119+
elif isinstance(typ, TypedDictType):
120+
# Actually look up from the fallback instance type.
121+
return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
122+
is_operator, builtin_type, not_ready_callback, msg,
123+
original_type=original_type, chk=chk)
119124
elif isinstance(typ, FunctionLike) and typ.is_type_obj():
120125
# Class attribute.
121126
# TODO super?

mypy/constraints.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Type inference constraints."""
22

3-
from typing import List, Optional
3+
from typing import Iterable, List, Optional
44

55
from mypy.types import (
66
CallableType, Type, TypeVisitor, UnboundType, AnyType, Void, NoneTyp, TypeVarType,
7-
Instance, TupleType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
8-
UninhabitedType, TypeType, TypeVarId, is_named_instance
7+
Instance, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType,
8+
DeletedType, UninhabitedType, TypeType, TypeVarId, is_named_instance
99
)
1010
from mypy.maptype import map_instance_to_supertype
1111
from mypy import nodes
@@ -342,11 +342,27 @@ def visit_tuple_type(self, template: TupleType) -> List[Constraint]:
342342
else:
343343
return []
344344

345+
def visit_typeddict_type(self, template: TypedDictType) -> List[Constraint]:
346+
actual = self.actual
347+
if isinstance(actual, TypedDictType):
348+
res = [] # type: List[Constraint]
349+
# NOTE: Non-matching keys are ignored. Compatibility is checked
350+
# elsewhere so this shouldn't be unsafe.
351+
for (item_name, template_item_type, actual_item_type) in template.zip(actual):
352+
res.extend(infer_constraints(template_item_type,
353+
actual_item_type,
354+
self.direction))
355+
return res
356+
elif isinstance(actual, AnyType):
357+
return self.infer_against_any(template.items.values())
358+
else:
359+
return []
360+
345361
def visit_union_type(self, template: UnionType) -> List[Constraint]:
346362
assert False, ("Unexpected UnionType in ConstraintBuilderVisitor"
347363
" (should have been handled in infer_constraints)")
348364

349-
def infer_against_any(self, types: List[Type]) -> List[Constraint]:
365+
def infer_against_any(self, types: Iterable[Type]) -> List[Constraint]:
350366
res = [] # type: List[Constraint]
351367
for t in types:
352368
res.extend(infer_constraints(t, AnyType(), self.direction))

mypy/erasetype.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from mypy.types import (
44
Type, TypeVisitor, UnboundType, ErrorType, AnyType, Void, NoneTyp, TypeVarId,
5-
Instance, TypeVarType, CallableType, TupleType, UnionType, Overloaded, ErasedType,
6-
PartialType, DeletedType, TypeTranslator, TypeList, UninhabitedType, TypeType
5+
Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded,
6+
ErasedType, PartialType, DeletedType, TypeTranslator, TypeList, UninhabitedType, TypeType
77
)
88
from mypy import experiments
99

@@ -78,6 +78,9 @@ def visit_overloaded(self, t: Overloaded) -> Type:
7878
def visit_tuple_type(self, t: TupleType) -> Type:
7979
return t.fallback.accept(self)
8080

81+
def visit_typeddict_type(self, t: TypedDictType) -> Type:
82+
return t.fallback.accept(self)
83+
8184
def visit_union_type(self, t: UnionType) -> Type:
8285
return AnyType() # XXX: return underlying type if only one?
8386

mypy/expandtype.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import Dict, List
1+
from typing import Dict, Iterable, List
22

33
from mypy.types import (
44
Type, Instance, CallableType, TypeVisitor, UnboundType, ErrorType, AnyType,
5-
Void, NoneTyp, TypeVarType, Overloaded, TupleType, UnionType, ErasedType, TypeList,
6-
PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId
5+
Void, NoneTyp, TypeVarType, Overloaded, TupleType, TypedDictType, UnionType,
6+
ErasedType, TypeList, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId
77
)
88

99

@@ -93,6 +93,9 @@ def visit_overloaded(self, t: Overloaded) -> Type:
9393
def visit_tuple_type(self, t: TupleType) -> Type:
9494
return t.copy_modified(items=self.expand_types(t.items))
9595

96+
def visit_typeddict_type(self, t: TypedDictType) -> Type:
97+
return t.copy_modified(item_types=self.expand_types(t.items.values()))
98+
9699
def visit_union_type(self, t: UnionType) -> Type:
97100
# After substituting for type variables in t.items,
98101
# some of the resulting types might be subtypes of others.
@@ -108,7 +111,7 @@ def visit_type_type(self, t: TypeType) -> Type:
108111
item = t.item.accept(self)
109112
return TypeType(item)
110113

111-
def expand_types(self, types: List[Type]) -> List[Type]:
114+
def expand_types(self, types: Iterable[Type]) -> List[Type]:
112115
a = [] # type: List[Type]
113116
for t in types:
114117
a.append(t.accept(self))

mypy/fixup.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22

33
from typing import Any, Dict, Optional
44

5-
from mypy.nodes import (MypyFile, SymbolNode, SymbolTable, SymbolTableNode,
6-
TypeInfo, FuncDef, OverloadedFuncDef, Decorator, Var,
7-
TypeVarExpr, ClassDef,
8-
LDEF, MDEF, GDEF)
9-
from mypy.types import (CallableType, EllipsisType, Instance, Overloaded, TupleType,
10-
TypeList, TypeVarType, UnboundType, UnionType, TypeVisitor,
11-
TypeType)
5+
from mypy.nodes import (
6+
MypyFile, SymbolNode, SymbolTable, SymbolTableNode,
7+
TypeInfo, FuncDef, OverloadedFuncDef, Decorator, Var,
8+
TypeVarExpr, ClassDef,
9+
LDEF, MDEF, GDEF
10+
)
11+
from mypy.types import (
12+
CallableType, EllipsisType, Instance, Overloaded, TupleType, TypedDictType,
13+
TypeList, TypeVarType, UnboundType, UnionType, TypeVisitor,
14+
TypeType
15+
)
1216
from mypy.visitor import NodeVisitor
1317

1418

@@ -192,6 +196,13 @@ def visit_tuple_type(self, tt: TupleType) -> None:
192196
if tt.fallback is not None:
193197
tt.fallback.accept(self)
194198

199+
def visit_typeddict_type(self, tdt: TypedDictType) -> None:
200+
if tdt.items:
201+
for it in tdt.items.values():
202+
it.accept(self)
203+
if tdt.fallback is not None:
204+
tdt.fallback.accept(self)
205+
195206
def visit_type_list(self, tl: TypeList) -> None:
196207
for t in tl.items:
197208
t.accept(self)

mypy/indirection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ def visit_overloaded(self, t: types.Overloaded) -> Set[str]:
8787
def visit_tuple_type(self, t: types.TupleType) -> Set[str]:
8888
return self._visit(*t.items) | self._visit(t.fallback)
8989

90+
def visit_typeddict_type(self, t: types.TypedDictType) -> Set[str]:
91+
return self._visit(*t.items.values()) | self._visit(t.fallback)
92+
9093
def visit_star_type(self, t: types.StarType) -> Set[str]:
9194
return set()
9295

mypy/join.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Calculation of the least upper bound types (joins)."""
22

3-
from typing import List
3+
from collections import OrderedDict
4+
from typing import cast, List
45

56
from mypy.types import (
67
Type, AnyType, NoneTyp, Void, TypeVisitor, Instance, UnboundType,
7-
ErrorType, TypeVarType, CallableType, TupleType, ErasedType, TypeList,
8+
ErrorType, TypeVarType, CallableType, TupleType, TypedDictType, ErasedType, TypeList,
89
UnionType, FunctionLike, Overloaded, PartialType, DeletedType,
910
UninhabitedType, TypeType, true_or_false
1011
)
@@ -170,6 +171,8 @@ def visit_instance(self, t: Instance) -> Type:
170171
return join_types(t, self.s.fallback)
171172
elif isinstance(self.s, TypeType):
172173
return join_types(t, self.s)
174+
elif isinstance(self.s, TypedDictType):
175+
return join_types(t, self.s)
173176
else:
174177
return self.default(self.s)
175178

@@ -234,13 +237,27 @@ def visit_tuple_type(self, t: TupleType) -> Type:
234237
items = [] # type: List[Type]
235238
for i in range(t.length()):
236239
items.append(self.join(t.items[i], self.s.items[i]))
237-
# join fallback types if they are different
238240
fallback = join_instances(self.s.fallback, t.fallback)
239241
assert isinstance(fallback, Instance)
240242
return TupleType(items, fallback)
241243
else:
242244
return self.default(self.s)
243245

246+
def visit_typeddict_type(self, t: TypedDictType) -> Type:
247+
if isinstance(self.s, TypedDictType):
248+
items = OrderedDict([
249+
(item_name, s_item_type)
250+
for (item_name, s_item_type, t_item_type) in self.s.zip(t)
251+
if is_equivalent(s_item_type, t_item_type)
252+
])
253+
mapping_value_type = join_type_list(list(items.values()))
254+
fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type)
255+
return TypedDictType(items, fallback)
256+
elif isinstance(self.s, Instance):
257+
return join_instances(self.s, t.fallback)
258+
else:
259+
return self.default(self.s)
260+
244261
def visit_partial_type(self, t: PartialType) -> Type:
245262
# We only have partial information so we can't decide the join result. We should
246263
# never get here.
@@ -266,6 +283,8 @@ def default(self, typ: Type) -> Type:
266283
return ErrorType()
267284
elif isinstance(typ, TupleType):
268285
return self.default(typ.fallback)
286+
elif isinstance(typ, TypedDictType):
287+
return self.default(typ.fallback)
269288
elif isinstance(typ, FunctionLike):
270289
return self.default(typ.fallback)
271290
elif isinstance(typ, TypeVarType):

mypy/meet.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from typing import List
1+
from collections import OrderedDict
2+
from typing import List, Optional
23

3-
from mypy.join import is_similar_callables, combine_similar_callables
4+
from mypy.join import is_similar_callables, combine_similar_callables, join_type_list
45
from mypy.types import (
56
Type, AnyType, TypeVisitor, UnboundType, Void, ErrorType, NoneTyp, TypeVarType,
6-
Instance, CallableType, TupleType, ErasedType, TypeList, UnionType, PartialType,
7+
Instance, CallableType, TupleType, TypedDictType, ErasedType, TypeList, UnionType, PartialType,
78
DeletedType, UninhabitedType, TypeType
89
)
9-
from mypy.subtypes import is_subtype
10+
from mypy.subtypes import is_equivalent, is_subtype
1011

1112
from mypy import experiments
1213

@@ -252,6 +253,21 @@ def visit_tuple_type(self, t: TupleType) -> Type:
252253
else:
253254
return self.default(self.s)
254255

256+
def visit_typeddict_type(self, t: TypedDictType) -> Type:
257+
if isinstance(self.s, TypedDictType):
258+
for (_, l, r) in self.s.zip(t):
259+
if not is_equivalent(l, r):
260+
return self.default(self.s)
261+
items = OrderedDict([
262+
(item_name, s_item_type or t_item_type)
263+
for (item_name, s_item_type, t_item_type) in self.s.zipall(t)
264+
])
265+
mapping_value_type = join_type_list(list(items.values()))
266+
fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type)
267+
return TypedDictType(items, fallback)
268+
else:
269+
return self.default(self.s)
270+
255271
def visit_partial_type(self, t: PartialType) -> Type:
256272
# We can't determine the meet of partial types. We should never get here.
257273
assert False, 'Internal error'

0 commit comments

Comments
 (0)