Skip to content

Commit 608de81

Browse files
authored
Handle interactions between recursive aliases and recursive instances (#13328)
This is a follow-up for #13297 The fix for infinite recursion is kind of simple, but it is hard to make inference infer something useful. Currently we handle all most common cases, but it is quite fragile (I however have few tricks left if people will complain about inference).
1 parent b3eebe3 commit 608de81

File tree

8 files changed

+194
-55
lines changed

8 files changed

+194
-55
lines changed

mypy/checkexpr.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@
154154
is_optional,
155155
remove_optional,
156156
)
157+
from mypy.typestate import TypeState
157158
from mypy.typevars import fill_typevars
158159
from mypy.util import split_module_names
159160
from mypy.visitor import ExpressionVisitor
@@ -1429,6 +1430,22 @@ def infer_arg_types_in_empty_context(self, args: List[Expression]) -> List[Type]
14291430
res.append(arg_type)
14301431
return res
14311432

1433+
@contextmanager
1434+
def allow_unions(self, type_context: Type) -> Iterator[None]:
1435+
# This is a hack to better support inference for recursive types.
1436+
# When the outer context for a function call is known to be recursive,
1437+
# we solve type constraints inferred from arguments using unions instead
1438+
# of joins. This is a bit arbitrary, but in practice it works for most
1439+
# cases. A cleaner alternative would be to switch to single bin type
1440+
# inference, but this is a lot of work.
1441+
old = TypeState.infer_unions
1442+
if has_recursive_types(type_context):
1443+
TypeState.infer_unions = True
1444+
try:
1445+
yield
1446+
finally:
1447+
TypeState.infer_unions = old
1448+
14321449
def infer_arg_types_in_context(
14331450
self,
14341451
callee: CallableType,
@@ -1448,7 +1465,8 @@ def infer_arg_types_in_context(
14481465
for i, actuals in enumerate(formal_to_actual):
14491466
for ai in actuals:
14501467
if not arg_kinds[ai].is_star():
1451-
res[ai] = self.accept(args[ai], callee.arg_types[i])
1468+
with self.allow_unions(callee.arg_types[i]):
1469+
res[ai] = self.accept(args[ai], callee.arg_types[i])
14521470

14531471
# Fill in the rest of the argument types.
14541472
for i, t in enumerate(res):
@@ -1568,25 +1586,13 @@ def infer_function_type_arguments(
15681586
else:
15691587
pass1_args.append(arg)
15701588

1571-
# This is a hack to better support inference for recursive types.
1572-
# When the outer context for a function call is known to be recursive,
1573-
# we solve type constraints inferred from arguments using unions instead
1574-
# of joins. This is a bit arbitrary, but in practice it works for most
1575-
# cases. A cleaner alternative would be to switch to single bin type
1576-
# inference, but this is a lot of work.
1577-
ctx = self.type_context[-1]
1578-
if ctx and has_recursive_types(ctx):
1579-
infer_unions = True
1580-
else:
1581-
infer_unions = False
15821589
inferred_args = infer_function_type_arguments(
15831590
callee_type,
15841591
pass1_args,
15851592
arg_kinds,
15861593
formal_to_actual,
15871594
context=self.argument_infer_context(),
15881595
strict=self.chk.in_checked_function(),
1589-
infer_unions=infer_unions,
15901596
)
15911597

15921598
if 2 in arg_pass_nums:

mypy/constraints.py

+30-5
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
UnpackType,
4343
callable_with_ellipsis,
4444
get_proper_type,
45+
has_recursive_types,
46+
has_type_vars,
4547
is_named_instance,
4648
is_union_with_any,
4749
)
@@ -141,14 +143,19 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> List[Cons
141143
The constraints are represented as Constraint objects.
142144
"""
143145
if any(
144-
get_proper_type(template) == get_proper_type(t) for t in reversed(TypeState._inferring)
146+
get_proper_type(template) == get_proper_type(t)
147+
and get_proper_type(actual) == get_proper_type(a)
148+
for (t, a) in reversed(TypeState.inferring)
145149
):
146150
return []
147-
if isinstance(template, TypeAliasType) and template.is_recursive:
151+
if has_recursive_types(template):
148152
# This case requires special care because it may cause infinite recursion.
149-
TypeState._inferring.append(template)
153+
if not has_type_vars(template):
154+
# Return early on an empty branch.
155+
return []
156+
TypeState.inferring.append((template, actual))
150157
res = _infer_constraints(template, actual, direction)
151-
TypeState._inferring.pop()
158+
TypeState.inferring.pop()
152159
return res
153160
return _infer_constraints(template, actual, direction)
154161

@@ -216,13 +223,18 @@ def _infer_constraints(template: Type, actual: Type, direction: int) -> List[Con
216223
# When the template is a union, we are okay with leaving some
217224
# type variables indeterminate. This helps with some special
218225
# cases, though this isn't very principled.
219-
return any_constraints(
226+
result = any_constraints(
220227
[
221228
infer_constraints_if_possible(t_item, actual, direction)
222229
for t_item in template.items
223230
],
224231
eager=False,
225232
)
233+
if result:
234+
return result
235+
elif has_recursive_types(template) and not has_recursive_types(actual):
236+
return handle_recursive_union(template, actual, direction)
237+
return []
226238

227239
# Remaining cases are handled by ConstraintBuilderVisitor.
228240
return template.accept(ConstraintBuilderVisitor(actual, direction))
@@ -279,6 +291,19 @@ def merge_with_any(constraint: Constraint) -> Constraint:
279291
)
280292

281293

294+
def handle_recursive_union(template: UnionType, actual: Type, direction: int) -> List[Constraint]:
295+
# This is a hack to special-case things like Union[T, Inst[T]] in recursive types. Although
296+
# it is quite arbitrary, it is a relatively common pattern, so we should handle it well.
297+
# This function may be called when inferring against such union resulted in different
298+
# constraints for each item. Normally we give up in such case, but here we instead split
299+
# the union in two parts, and try inferring sequentially.
300+
non_type_var_items = [t for t in template.items if not isinstance(t, TypeVarType)]
301+
type_var_items = [t for t in template.items if isinstance(t, TypeVarType)]
302+
return infer_constraints(
303+
UnionType.make_union(non_type_var_items), actual, direction
304+
) or infer_constraints(UnionType.make_union(type_var_items), actual, direction)
305+
306+
282307
def any_constraints(options: List[Optional[List[Constraint]]], eager: bool) -> List[Constraint]:
283308
"""Deduce what we can from a collection of constraint lists.
284309

mypy/infer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def infer_function_type_arguments(
3434
formal_to_actual: List[List[int]],
3535
context: ArgumentInferContext,
3636
strict: bool = True,
37-
infer_unions: bool = False,
3837
) -> List[Optional[Type]]:
3938
"""Infer the type arguments of a generic function.
4039
@@ -56,7 +55,7 @@ def infer_function_type_arguments(
5655

5756
# Solve constraints.
5857
type_vars = callee_type.type_var_ids()
59-
return solve_constraints(type_vars, constraints, strict, infer_unions=infer_unions)
58+
return solve_constraints(type_vars, constraints, strict)
6059

6160

6261
def infer_type_arguments(

mypy/solve.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,11 @@
1717
UnionType,
1818
get_proper_type,
1919
)
20+
from mypy.typestate import TypeState
2021

2122

2223
def solve_constraints(
23-
vars: List[TypeVarId],
24-
constraints: List[Constraint],
25-
strict: bool = True,
26-
infer_unions: bool = False,
24+
vars: List[TypeVarId], constraints: List[Constraint], strict: bool = True
2725
) -> List[Optional[Type]]:
2826
"""Solve type constraints.
2927
@@ -55,7 +53,7 @@ def solve_constraints(
5553
if bottom is None:
5654
bottom = c.target
5755
else:
58-
if infer_unions:
56+
if TypeState.infer_unions:
5957
# This deviates from the general mypy semantics because
6058
# recursive types are union-heavy in 95% of cases.
6159
bottom = UnionType.make_union([bottom, c.target])

mypy/subtypes.py

+3-15
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,7 @@ def is_subtype(
145145
), "Don't pass both context and individual flags"
146146
if TypeState.is_assumed_subtype(left, right):
147147
return True
148-
if (
149-
# TODO: recursive instances like `class str(Sequence[str])` can also cause
150-
# issues, so we also need to include them in the assumptions stack
151-
isinstance(left, TypeAliasType)
152-
and isinstance(right, TypeAliasType)
153-
and left.is_recursive
154-
and right.is_recursive
155-
):
148+
if mypy.typeops.is_recursive_pair(left, right):
156149
# This case requires special care because it may cause infinite recursion.
157150
# Our view on recursive types is known under a fancy name of iso-recursive mu-types.
158151
# Roughly this means that a recursive type is defined as an alias where right hand side
@@ -205,12 +198,7 @@ def is_proper_subtype(
205198
), "Don't pass both context and individual flags"
206199
if TypeState.is_assumed_proper_subtype(left, right):
207200
return True
208-
if (
209-
isinstance(left, TypeAliasType)
210-
and isinstance(right, TypeAliasType)
211-
and left.is_recursive
212-
and right.is_recursive
213-
):
201+
if mypy.typeops.is_recursive_pair(left, right):
214202
# Same as for non-proper subtype, see detailed comment there for explanation.
215203
with pop_on_exit(TypeState.get_assumptions(is_proper=True), left, right):
216204
return _is_subtype(left, right, subtype_context, proper_subtype=True)
@@ -874,7 +862,7 @@ def visit_type_alias_type(self, left: TypeAliasType) -> bool:
874862
assert False, f"This should be never called, got {left}"
875863

876864

877-
T = TypeVar("T", Instance, TypeAliasType)
865+
T = TypeVar("T", bound=Type)
878866

879867

880868
@contextmanager

mypy/typeops.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,25 @@
6363

6464

6565
def is_recursive_pair(s: Type, t: Type) -> bool:
66-
"""Is this a pair of recursive type aliases?"""
67-
return (
68-
isinstance(s, TypeAliasType)
69-
and isinstance(t, TypeAliasType)
70-
and s.is_recursive
71-
and t.is_recursive
72-
)
66+
"""Is this a pair of recursive types?
67+
68+
There may be more cases, and we may be forced to use e.g. has_recursive_types()
69+
here, but this function is called in very hot code, so we try to keep it simple
70+
and return True only in cases we know may have problems.
71+
"""
72+
if isinstance(s, TypeAliasType) and s.is_recursive:
73+
return (
74+
isinstance(get_proper_type(t), Instance)
75+
or isinstance(t, TypeAliasType)
76+
and t.is_recursive
77+
)
78+
if isinstance(t, TypeAliasType) and t.is_recursive:
79+
return (
80+
isinstance(get_proper_type(s), Instance)
81+
or isinstance(s, TypeAliasType)
82+
and s.is_recursive
83+
)
84+
return False
7385

7486

7587
def tuple_fallback(typ: TupleType) -> Instance:
@@ -81,9 +93,8 @@ def tuple_fallback(typ: TupleType) -> Instance:
8193
return typ.partial_fallback
8294
items = []
8395
for item in typ.items:
84-
proper_type = get_proper_type(item)
85-
if isinstance(proper_type, UnpackType):
86-
unpacked_type = get_proper_type(proper_type.type)
96+
if isinstance(item, UnpackType):
97+
unpacked_type = get_proper_type(item.type)
8798
if isinstance(unpacked_type, TypeVarTupleType):
8899
items.append(unpacked_type.upper_bound)
89100
elif isinstance(unpacked_type, TupleType):

mypy/typestate.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from mypy.nodes import TypeInfo
1111
from mypy.server.trigger import make_trigger
12-
from mypy.types import Instance, Type, TypeAliasType, get_proper_type
12+
from mypy.types import Instance, Type, get_proper_type
1313

1414
# Represents that the 'left' instance is a subtype of the 'right' instance
1515
SubtypeRelationship: _TypeAlias = Tuple[Instance, Instance]
@@ -80,10 +80,12 @@ class TypeState:
8080
# recursive type aliases. Normally, one would pass type assumptions as an additional
8181
# arguments to is_subtype(), but this would mean updating dozens of related functions
8282
# threading this through all callsites (see also comment for TypeInfo.assuming).
83-
_assuming: Final[List[Tuple[TypeAliasType, TypeAliasType]]] = []
84-
_assuming_proper: Final[List[Tuple[TypeAliasType, TypeAliasType]]] = []
83+
_assuming: Final[List[Tuple[Type, Type]]] = []
84+
_assuming_proper: Final[List[Tuple[Type, Type]]] = []
8585
# Ditto for inference of generic constraints against recursive type aliases.
86-
_inferring: Final[List[TypeAliasType]] = []
86+
inferring: Final[List[Tuple[Type, Type]]] = []
87+
# Whether to use joins or unions when solving constraints, see checkexpr.py for details.
88+
infer_unions: ClassVar = False
8789

8890
# N.B: We do all of the accesses to these properties through
8991
# TypeState, instead of making these classmethods and accessing
@@ -109,7 +111,7 @@ def is_assumed_proper_subtype(left: Type, right: Type) -> bool:
109111
return False
110112

111113
@staticmethod
112-
def get_assumptions(is_proper: bool) -> List[Tuple[TypeAliasType, TypeAliasType]]:
114+
def get_assumptions(is_proper: bool) -> List[Tuple[Type, Type]]:
113115
if is_proper:
114116
return TypeState._assuming_proper
115117
return TypeState._assuming

0 commit comments

Comments
 (0)