Skip to content

Commit cba07d7

Browse files
authored
Support recursive TypedDicts (#13373)
This is a continuation of #13297 Depends on #13371 It was actually quite easy, essentially just a 1-to-1 mapping from the other PR.
1 parent 601802c commit cba07d7

11 files changed

+354
-48
lines changed

mypy/fixup.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -339,10 +339,17 @@ def lookup_fully_qualified_alias(
339339
if isinstance(node, TypeAlias):
340340
return node
341341
elif isinstance(node, TypeInfo):
342-
if node.tuple_alias:
343-
return node.tuple_alias
344-
alias = TypeAlias.from_tuple_type(node)
345-
node.tuple_alias = alias
342+
if node.special_alias:
343+
# Already fixed up.
344+
return node.special_alias
345+
if node.tuple_type:
346+
alias = TypeAlias.from_tuple_type(node)
347+
elif node.typeddict_type:
348+
alias = TypeAlias.from_typeddict_type(node)
349+
else:
350+
assert allow_missing
351+
return missing_alias()
352+
node.special_alias = alias
346353
return alias
347354
else:
348355
# Looks like a missing TypeAlias during an initial daemon load, put something there

mypy/nodes.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2656,7 +2656,7 @@ class is generic then it will be a type constructor of higher kind.
26562656
"bases",
26572657
"_promote",
26582658
"tuple_type",
2659-
"tuple_alias",
2659+
"special_alias",
26602660
"is_named_tuple",
26612661
"typeddict_type",
26622662
"is_newtype",
@@ -2795,8 +2795,16 @@ class is generic then it will be a type constructor of higher kind.
27952795
# It is useful for plugins to add their data to save in the cache.
27962796
metadata: Dict[str, JsonDict]
27972797

2798-
# Store type alias representing this type (for named tuples).
2799-
tuple_alias: Optional["TypeAlias"]
2798+
# Store type alias representing this type (for named tuples and TypedDicts).
2799+
# Although definitions of these types are stored in symbol tables as TypeInfo,
2800+
# when a type analyzer will find them, it should construct a TupleType, or
2801+
# a TypedDict type. However, we can't use the plain types, since if the definition
2802+
# is recursive, this will create an actual recursive structure of types (i.e. as
2803+
# internal Python objects) causing infinite recursions everywhere during type checking.
2804+
# To overcome this, we create a TypeAlias node, that will point to these types.
2805+
# We store this node in the `special_alias` attribute, because it must be the same node
2806+
# in case we are doing multiple semantic analysis passes.
2807+
special_alias: Optional["TypeAlias"]
28002808

28012809
FLAGS: Final = [
28022810
"is_abstract",
@@ -2844,7 +2852,7 @@ def __init__(self, names: "SymbolTable", defn: ClassDef, module_name: str) -> No
28442852
self._promote = []
28452853
self.alt_promote = None
28462854
self.tuple_type = None
2847-
self.tuple_alias = None
2855+
self.special_alias = None
28482856
self.is_named_tuple = False
28492857
self.typeddict_type = None
28502858
self.is_newtype = False
@@ -2976,13 +2984,22 @@ def direct_base_classes(self) -> "List[TypeInfo]":
29762984
return [base.type for base in self.bases]
29772985

29782986
def update_tuple_type(self, typ: "mypy.types.TupleType") -> None:
2979-
"""Update tuple_type and tuple_alias as needed."""
2987+
"""Update tuple_type and special_alias as needed."""
29802988
self.tuple_type = typ
29812989
alias = TypeAlias.from_tuple_type(self)
2982-
if not self.tuple_alias:
2983-
self.tuple_alias = alias
2990+
if not self.special_alias:
2991+
self.special_alias = alias
29842992
else:
2985-
self.tuple_alias.target = alias.target
2993+
self.special_alias.target = alias.target
2994+
2995+
def update_typeddict_type(self, typ: "mypy.types.TypedDictType") -> None:
2996+
"""Update typeddict_type and special_alias as needed."""
2997+
self.typeddict_type = typ
2998+
alias = TypeAlias.from_typeddict_type(self)
2999+
if not self.special_alias:
3000+
self.special_alias = alias
3001+
else:
3002+
self.special_alias.target = alias.target
29863003

29873004
def __str__(self) -> str:
29883005
"""Return a string representation of the type.
@@ -3283,6 +3300,17 @@ def from_tuple_type(cls, info: TypeInfo) -> "TypeAlias":
32833300
info.column,
32843301
)
32853302

3303+
@classmethod
3304+
def from_typeddict_type(cls, info: TypeInfo) -> "TypeAlias":
3305+
"""Generate an alias to the TypedDict type described by a given TypeInfo."""
3306+
assert info.typeddict_type
3307+
return TypeAlias(
3308+
info.typeddict_type.copy_modified(fallback=mypy.types.Instance(info, [])),
3309+
info.fullname,
3310+
info.line,
3311+
info.column,
3312+
)
3313+
32863314
@property
32873315
def name(self) -> str:
32883316
return self._fullname.split(".")[-1]

mypy/semanal.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,17 +1378,7 @@ def analyze_class(self, defn: ClassDef) -> None:
13781378
self.mark_incomplete(defn.name, defn)
13791379
return
13801380

1381-
is_typeddict, info = self.typed_dict_analyzer.analyze_typeddict_classdef(defn)
1382-
if is_typeddict:
1383-
for decorator in defn.decorators:
1384-
decorator.accept(self)
1385-
if isinstance(decorator, RefExpr):
1386-
if decorator.fullname in FINAL_DECORATOR_NAMES:
1387-
self.fail("@final cannot be used with TypedDict", decorator)
1388-
if info is None:
1389-
self.mark_incomplete(defn.name, defn)
1390-
else:
1391-
self.prepare_class_def(defn, info)
1381+
if self.analyze_typeddict_classdef(defn):
13921382
return
13931383

13941384
if self.analyze_namedtuple_classdef(defn):
@@ -1423,6 +1413,28 @@ def analyze_class_body_common(self, defn: ClassDef) -> None:
14231413
self.apply_class_plugin_hooks(defn)
14241414
self.leave_class()
14251415

1416+
def analyze_typeddict_classdef(self, defn: ClassDef) -> bool:
1417+
if (
1418+
defn.info
1419+
and defn.info.typeddict_type
1420+
and not has_placeholder(defn.info.typeddict_type)
1421+
):
1422+
# This is a valid TypedDict, and it is fully analyzed.
1423+
return True
1424+
is_typeddict, info = self.typed_dict_analyzer.analyze_typeddict_classdef(defn)
1425+
if is_typeddict:
1426+
for decorator in defn.decorators:
1427+
decorator.accept(self)
1428+
if isinstance(decorator, RefExpr):
1429+
if decorator.fullname in FINAL_DECORATOR_NAMES:
1430+
self.fail("@final cannot be used with TypedDict", decorator)
1431+
if info is None:
1432+
self.mark_incomplete(defn.name, defn)
1433+
else:
1434+
self.prepare_class_def(defn, info)
1435+
return True
1436+
return False
1437+
14261438
def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:
14271439
"""Check if this class can define a named tuple."""
14281440
if (
@@ -1840,7 +1852,7 @@ def configure_tuple_base_class(self, defn: ClassDef, base: TupleType) -> Instanc
18401852
if info.tuple_type and info.tuple_type != base and not has_placeholder(info.tuple_type):
18411853
self.fail("Class has two incompatible bases derived from tuple", defn)
18421854
defn.has_incompatible_baseclass = True
1843-
if info.tuple_alias and has_placeholder(info.tuple_alias.target):
1855+
if info.special_alias and has_placeholder(info.special_alias.target):
18441856
self.defer(force_progress=True)
18451857
info.update_tuple_type(base)
18461858

@@ -2660,7 +2672,11 @@ def analyze_namedtuple_assign(self, s: AssignmentStmt) -> bool:
26602672
def analyze_typeddict_assign(self, s: AssignmentStmt) -> bool:
26612673
"""Check if s defines a typed dict."""
26622674
if isinstance(s.rvalue, CallExpr) and isinstance(s.rvalue.analyzed, TypedDictExpr):
2663-
return True # This is a valid and analyzed typed dict definition, nothing to do here.
2675+
if s.rvalue.analyzed.info.typeddict_type and not has_placeholder(
2676+
s.rvalue.analyzed.info.typeddict_type
2677+
):
2678+
# This is a valid and analyzed typed dict definition, nothing to do here.
2679+
return True
26642680
if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], (NameExpr, MemberExpr)):
26652681
return False
26662682
lvalue = s.lvalues[0]
@@ -5504,6 +5520,11 @@ def defer(self, debug_context: Optional[Context] = None, force_progress: bool =
55045520
"""
55055521
assert not self.final_iteration, "Must not defer during final iteration"
55065522
if force_progress:
5523+
# Usually, we report progress if we have replaced a placeholder node
5524+
# with an actual valid node. However, sometimes we need to update an
5525+
# existing node *in-place*. For example, this is used by type aliases
5526+
# in context of forward references and/or recursive aliases, and in
5527+
# similar situations (recursive named tuples etc).
55075528
self.progress = True
55085529
self.deferred = True
55095530
# Store debug info for this deferral.

mypy/semanal_namedtuple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def build_namedtuple_typeinfo(
478478
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
479479
info.is_named_tuple = True
480480
tuple_base = TupleType(types, fallback)
481-
if info.tuple_alias and has_placeholder(info.tuple_alias.target):
481+
if info.special_alias and has_placeholder(info.special_alias.target):
482482
self.api.defer(force_progress=True)
483483
info.update_tuple_type(tuple_base)
484484
info.line = line

mypy/semanal_newtype.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,10 @@ def process_newtype_declaration(self, s: AssignmentStmt) -> bool:
7979

8080
old_type, should_defer = self.check_newtype_args(var_name, call, s)
8181
old_type = get_proper_type(old_type)
82-
if not call.analyzed:
82+
if not isinstance(call.analyzed, NewTypeExpr):
8383
call.analyzed = NewTypeExpr(var_name, old_type, line=call.line, column=call.column)
84+
else:
85+
call.analyzed.old_type = old_type
8486
if old_type is None:
8587
if should_defer:
8688
# Base type is not ready.
@@ -230,6 +232,7 @@ def build_newtype_typeinfo(
230232
existing_info: Optional[TypeInfo],
231233
) -> TypeInfo:
232234
info = existing_info or self.api.basic_new_typeinfo(name, base_type, line)
235+
info.bases = [base_type] # Update in case there were nested placeholders.
233236
info.is_newtype = True
234237

235238
# Add __init__ method
@@ -250,7 +253,7 @@ def build_newtype_typeinfo(
250253
init_func._fullname = info.fullname + ".__init__"
251254
info.names["__init__"] = SymbolTableNode(MDEF, init_func)
252255

253-
if info.tuple_type and has_placeholder(info.tuple_type):
256+
if has_placeholder(old_type) or info.tuple_type and has_placeholder(info.tuple_type):
254257
self.api.defer(force_progress=True)
255258
return info
256259

mypy/semanal_typeddict.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
TypeInfo,
2828
)
2929
from mypy.options import Options
30-
from mypy.semanal_shared import SemanticAnalyzerInterface
30+
from mypy.semanal_shared import SemanticAnalyzerInterface, has_placeholder
3131
from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type
3232
from mypy.types import TPDICT_NAMES, AnyType, RequiredType, Type, TypedDictType, TypeOfAny
3333

@@ -66,6 +66,9 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> Tuple[bool, Optional[Typ
6666
if base_expr.fullname in TPDICT_NAMES or self.is_typeddict(base_expr):
6767
possible = True
6868
if possible:
69+
existing_info = None
70+
if isinstance(defn.analyzed, TypedDictExpr):
71+
existing_info = defn.analyzed.info
6972
if (
7073
len(defn.base_type_exprs) == 1
7174
and isinstance(defn.base_type_exprs[0], RefExpr)
@@ -76,7 +79,7 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> Tuple[bool, Optional[Typ
7679
if fields is None:
7780
return True, None # Defer
7881
info = self.build_typeddict_typeinfo(
79-
defn.name, fields, types, required_keys, defn.line
82+
defn.name, fields, types, required_keys, defn.line, existing_info
8083
)
8184
defn.analyzed = TypedDictExpr(info)
8285
defn.analyzed.line = defn.line
@@ -128,7 +131,9 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> Tuple[bool, Optional[Typ
128131
keys.extend(new_keys)
129132
types.extend(new_types)
130133
required_keys.update(new_required_keys)
131-
info = self.build_typeddict_typeinfo(defn.name, keys, types, required_keys, defn.line)
134+
info = self.build_typeddict_typeinfo(
135+
defn.name, keys, types, required_keys, defn.line, existing_info
136+
)
132137
defn.analyzed = TypedDictExpr(info)
133138
defn.analyzed.line = defn.line
134139
defn.analyzed.column = defn.column
@@ -173,7 +178,12 @@ def analyze_typeddict_classdef_fields(
173178
if stmt.type is None:
174179
types.append(AnyType(TypeOfAny.unannotated))
175180
else:
176-
analyzed = self.api.anal_type(stmt.type, allow_required=True)
181+
analyzed = self.api.anal_type(
182+
stmt.type,
183+
allow_required=True,
184+
allow_placeholder=self.options.enable_recursive_aliases
185+
and not self.api.is_func_scope(),
186+
)
177187
if analyzed is None:
178188
return None, [], set() # Need to defer
179189
types.append(analyzed)
@@ -232,7 +242,7 @@ def check_typeddict(
232242
name, items, types, total, ok = res
233243
if not ok:
234244
# Error. Construct dummy return value.
235-
info = self.build_typeddict_typeinfo("TypedDict", [], [], set(), call.line)
245+
info = self.build_typeddict_typeinfo("TypedDict", [], [], set(), call.line, None)
236246
else:
237247
if var_name is not None and name != var_name:
238248
self.fail(
@@ -254,7 +264,12 @@ def check_typeddict(
254264
types = [ # unwrap Required[T] to just T
255265
t.item if isinstance(t, RequiredType) else t for t in types # type: ignore[misc]
256266
]
257-
info = self.build_typeddict_typeinfo(name, items, types, required_keys, call.line)
267+
existing_info = None
268+
if isinstance(node.analyzed, TypedDictExpr):
269+
existing_info = node.analyzed.info
270+
info = self.build_typeddict_typeinfo(
271+
name, items, types, required_keys, call.line, existing_info
272+
)
258273
info.line = node.line
259274
# Store generated TypeInfo under both names, see semanal_namedtuple for more details.
260275
if name != var_name or is_func_scope:
@@ -357,7 +372,12 @@ def parse_typeddict_fields_with_types(
357372
else:
358373
self.fail_typeddict_arg("Invalid field type", field_type_expr)
359374
return [], [], False
360-
analyzed = self.api.anal_type(type, allow_required=True)
375+
analyzed = self.api.anal_type(
376+
type,
377+
allow_required=True,
378+
allow_placeholder=self.options.enable_recursive_aliases
379+
and not self.api.is_func_scope(),
380+
)
361381
if analyzed is None:
362382
return None
363383
types.append(analyzed)
@@ -370,7 +390,13 @@ def fail_typeddict_arg(
370390
return "", [], [], True, False
371391

372392
def build_typeddict_typeinfo(
373-
self, name: str, items: List[str], types: List[Type], required_keys: Set[str], line: int
393+
self,
394+
name: str,
395+
items: List[str],
396+
types: List[Type],
397+
required_keys: Set[str],
398+
line: int,
399+
existing_info: Optional[TypeInfo],
374400
) -> TypeInfo:
375401
# Prefer typing then typing_extensions if available.
376402
fallback = (
@@ -379,8 +405,11 @@ def build_typeddict_typeinfo(
379405
or self.api.named_type_or_none("mypy_extensions._TypedDict", [])
380406
)
381407
assert fallback is not None
382-
info = self.api.basic_new_typeinfo(name, fallback, line)
383-
info.typeddict_type = TypedDictType(dict(zip(items, types)), required_keys, fallback)
408+
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
409+
typeddict_type = TypedDictType(dict(zip(items, types)), required_keys, fallback)
410+
if info.special_alias and has_placeholder(info.special_alias.target):
411+
self.api.defer(force_progress=True)
412+
info.update_typeddict_type(typeddict_type)
384413
return info
385414

386415
# Helpers

mypy/server/astmerge.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,8 @@ def replacement_map_from_symbol_table(
172172
node.node.names, new_node.node.names, prefix
173173
)
174174
replacements.update(type_repl)
175-
if node.node.tuple_alias and new_node.node.tuple_alias:
176-
replacements[new_node.node.tuple_alias] = node.node.tuple_alias
175+
if node.node.special_alias and new_node.node.special_alias:
176+
replacements[new_node.node.special_alias] = node.node.special_alias
177177
return replacements
178178

179179

@@ -338,10 +338,10 @@ def fixup(self, node: SN) -> SN:
338338
new = self.replacements[node]
339339
skip_slots: Tuple[str, ...] = ()
340340
if isinstance(node, TypeInfo) and isinstance(new, TypeInfo):
341-
# Special case: tuple_alias is not exposed in symbol tables, but may appear
341+
# Special case: special_alias is not exposed in symbol tables, but may appear
342342
# in external types (e.g. named tuples), so we need to update it manually.
343-
skip_slots = ("tuple_alias",)
344-
replace_object_state(new.tuple_alias, node.tuple_alias)
343+
skip_slots = ("special_alias",)
344+
replace_object_state(new.special_alias, node.special_alias)
345345
replace_object_state(new, node, skip_slots=skip_slots)
346346
return cast(SN, new)
347347
return node
@@ -372,8 +372,8 @@ def process_type_info(self, info: Optional[TypeInfo]) -> None:
372372
self.fixup_type(target)
373373
self.fixup_type(info.tuple_type)
374374
self.fixup_type(info.typeddict_type)
375-
if info.tuple_alias:
376-
self.fixup_type(info.tuple_alias.target)
375+
if info.special_alias:
376+
self.fixup_type(info.special_alias.target)
377377
info.defn.info = self.fixup(info)
378378
replace_nodes_in_symbol_table(info.names, self.replacements)
379379
for i, item in enumerate(info.mro):
@@ -547,7 +547,7 @@ def replace_nodes_in_symbol_table(
547547
new = replacements[node.node]
548548
old = node.node
549549
# Needed for TypeInfo, see comment in fixup() above.
550-
replace_object_state(new, old, skip_slots=("tuple_alias",))
550+
replace_object_state(new, old, skip_slots=("special_alias",))
551551
node.node = new
552552
if isinstance(node.node, (Var, TypeAlias)):
553553
# Handle them here just in case these aren't exposed through the AST.

0 commit comments

Comments
 (0)