Skip to content

Commit 0820e95

Browse files
authored
[PEP 695] Support recursive type aliases (#17268)
The implementation follows the approach used for old-style type aliases. Work on #15238.
1 parent 7032f8c commit 0820e95

File tree

3 files changed

+74
-5
lines changed

3 files changed

+74
-5
lines changed

mypy/nodes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1647,19 +1647,21 @@ def accept(self, visitor: StatementVisitor[T]) -> T:
16471647

16481648

16491649
class TypeAliasStmt(Statement):
1650-
__slots__ = ("name", "type_args", "value")
1650+
__slots__ = ("name", "type_args", "value", "invalid_recursive_alias")
16511651

16521652
__match_args__ = ("name", "type_args", "value")
16531653

16541654
name: NameExpr
16551655
type_args: list[TypeParam]
16561656
value: Expression # Will get translated into a type
1657+
invalid_recursive_alias: bool
16571658

16581659
def __init__(self, name: NameExpr, type_args: list[TypeParam], value: Expression) -> None:
16591660
super().__init__()
16601661
self.name = name
16611662
self.type_args = type_args
16621663
self.value = value
1664+
self.invalid_recursive_alias = False
16631665

16641666
def accept(self, visitor: StatementVisitor[T]) -> T:
16651667
return visitor.visit_type_alias_stmt(self)

mypy/semanal.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3961,7 +3961,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
39613961
alias_node.normalized = rvalue.node.normalized
39623962
current_node = existing.node if existing else alias_node
39633963
assert isinstance(current_node, TypeAlias)
3964-
self.disable_invalid_recursive_aliases(s, current_node)
3964+
self.disable_invalid_recursive_aliases(s, current_node, s.rvalue)
39653965
if self.is_class_scope():
39663966
assert self.type is not None
39673967
if self.type.is_protocol:
@@ -4057,7 +4057,7 @@ def analyze_type_alias_type_params(
40574057
return declared_tvars, all_declared_tvar_names
40584058

40594059
def disable_invalid_recursive_aliases(
4060-
self, s: AssignmentStmt, current_node: TypeAlias
4060+
self, s: AssignmentStmt | TypeAliasStmt, current_node: TypeAlias, ctx: Context
40614061
) -> None:
40624062
"""Prohibit and fix recursive type aliases that are invalid/unsupported."""
40634063
messages = []
@@ -4074,7 +4074,7 @@ def disable_invalid_recursive_aliases(
40744074
current_node.target = AnyType(TypeOfAny.from_error)
40754075
s.invalid_recursive_alias = True
40764076
for msg in messages:
4077-
self.fail(msg, s.rvalue)
4077+
self.fail(msg, ctx)
40784078

40794079
def analyze_lvalue(
40804080
self,
@@ -5304,6 +5304,8 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
53045304
self.visit_block(s.bodies[i])
53055305

53065306
def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
5307+
if s.invalid_recursive_alias:
5308+
return
53075309
self.statement = s
53085310
type_params = self.push_type_args(s.type_args, s)
53095311
if type_params is None:
@@ -5369,10 +5371,32 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
53695371
and isinstance(existing.node, (PlaceholderNode, TypeAlias))
53705372
and existing.node.line == s.line
53715373
):
5372-
existing.node = alias_node
5374+
updated = False
5375+
if isinstance(existing.node, TypeAlias):
5376+
if existing.node.target != res:
5377+
# Copy expansion to the existing alias, this matches how we update base classes
5378+
# for a TypeInfo _in place_ if there are nested placeholders.
5379+
existing.node.target = res
5380+
existing.node.alias_tvars = alias_tvars
5381+
updated = True
5382+
else:
5383+
# Otherwise just replace existing placeholder with type alias.
5384+
existing.node = alias_node
5385+
updated = True
5386+
5387+
if updated:
5388+
if self.final_iteration:
5389+
self.cannot_resolve_name(s.name.name, "name", s)
5390+
return
5391+
else:
5392+
# We need to defer so that this change can get propagated to base classes.
5393+
self.defer(s, force_progress=True)
53735394
else:
53745395
self.add_symbol(s.name.name, alias_node, s)
53755396

5397+
current_node = existing.node if existing else alias_node
5398+
assert isinstance(current_node, TypeAlias)
5399+
self.disable_invalid_recursive_aliases(s, current_node, s.value)
53765400
finally:
53775401
self.pop_type_args(s.type_args)
53785402

test-data/unit/check-python312.test

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,6 +1162,49 @@ def decorator(x: str) -> Any: ...
11621162
class C[T]:
11631163
pass
11641164

1165+
[case testPEP695RecursiceTypeAlias]
1166+
# mypy: enable-incomplete-feature=NewGenericSyntax
1167+
1168+
type A = str | list[A]
1169+
a: A
1170+
reveal_type(a) # N: Revealed type is "Union[builtins.str, builtins.list[...]]"
1171+
1172+
class C[T]: pass
1173+
1174+
type B[T] = C[T] | list[B[T]]
1175+
b: B[int]
1176+
reveal_type(b) # N: Revealed type is "Union[__main__.C[builtins.int], builtins.list[...]]"
1177+
1178+
[case testPEP695BadRecursiveTypeAlias]
1179+
# mypy: enable-incomplete-feature=NewGenericSyntax
1180+
1181+
type A = A # E: Cannot resolve name "A" (possible cyclic definition)
1182+
type B = B | int # E: Invalid recursive alias: a union item of itself
1183+
a: A
1184+
reveal_type(a) # N: Revealed type is "Any"
1185+
b: B
1186+
reveal_type(b) # N: Revealed type is "Any"
1187+
1188+
[case testPEP695RecursiveTypeAliasForwardReference]
1189+
# mypy: enable-incomplete-feature=NewGenericSyntax
1190+
1191+
def f(a: A) -> None:
1192+
if isinstance(a, str):
1193+
reveal_type(a) # N: Revealed type is "builtins.str"
1194+
else:
1195+
reveal_type(a) # N: Revealed type is "__main__.C[Union[builtins.str, __main__.C[...]]]"
1196+
1197+
type A = str | C[A]
1198+
1199+
class C[T]: pass
1200+
1201+
f('x')
1202+
f(C[str]())
1203+
f(C[C[str]]())
1204+
f(1) # E: Argument 1 to "f" has incompatible type "int"; expected "A"
1205+
f(C[int]()) # E: Argument 1 to "f" has incompatible type "C[int]"; expected "A"
1206+
[builtins fixtures/isinstance.pyi]
1207+
11651208
[case testPEP695InvalidGenericOrProtocolBaseClass]
11661209
# mypy: enable-incomplete-feature=NewGenericSyntax
11671210
from typing import Generic, Protocol, TypeVar

0 commit comments

Comments
 (0)