Skip to content

Commit 8c14cba

Browse files
JukkaL and AlexWaygood authored
Propagate type narrowing to nested functions (#15133)
Fixes #2608. Use the heuristic suggested in #2608 and allow narrowed types of variables (but not attributes) to be propagated to nested functions if the variable is not assigned to after the definition of the nested function in the outer function. Since we don't have a full control flow graph, we simply look for assignments that are textually after the nested function in the outer function. This can result in false negatives (at least in loops) and false positives (in if statements, and if the assigned type is narrow enough), but I expect these to be rare and not a significant issue. Type narrowing is already unsound, and the additional unsoundness seems minor, while the usability benefit is big. This doesn't do the right thing for nested classes yet. I'll create an issue to track that. --------- Co-authored-by: Alex Waygood <[email protected]>
1 parent d71ece8 commit 8c14cba

File tree

8 files changed

+518
-5
lines changed

8 files changed

+518
-5
lines changed

mypy/binder.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ def __init__(self, id: int, conditional_frame: bool = False) -> None:
5151
# need this field.
5252
self.suppress_unreachable_warnings = False
5353

54+
def __repr__(self) -> str:
    """Return a debugging representation of this frame.

    The string lists, in order, the frame's id, its types mapping, its
    unreachable flag and its conditional_frame flag.
    """
    return f"Frame({self.id}, {self.types}, {self.unreachable}, {self.conditional_frame})"
56+
5457

5558
Assigns = DefaultDict[Expression, List[Tuple[Type, Optional[Type]]]]
5659

@@ -63,7 +66,7 @@ class ConditionalTypeBinder:
6366
6467
```
6568
class A:
66-
a = None # type: Union[int, str]
69+
a: Union[int, str] = None
6770
x = A()
6871
lst = [x]
6972
reveal_type(x.a) # Union[int, str]
@@ -446,6 +449,7 @@ def top_frame_context(self) -> Iterator[Frame]:
446449
assert len(self.frames) == 1
447450
yield self.push_frame()
448451
self.pop_frame(True, 0)
452+
assert len(self.frames) == 1
449453

450454

451455
def get_declaration(expr: BindableExpression) -> Type | None:

mypy/checker.py

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
import mypy.checkexpr
2828
from mypy import errorcodes as codes, message_registry, nodes, operators
29-
from mypy.binder import ConditionalTypeBinder, get_declaration
29+
from mypy.binder import ConditionalTypeBinder, Frame, get_declaration
3030
from mypy.checkmember import (
3131
MemberContext,
3232
analyze_decorator_or_funcbase_access,
@@ -41,7 +41,7 @@
4141
from mypy.errors import Errors, ErrorWatcher, report_internal_error
4242
from mypy.expandtype import expand_self_type, expand_type, expand_type_by_instance
4343
from mypy.join import join_types
44-
from mypy.literals import Key, literal, literal_hash
44+
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash
4545
from mypy.maptype import map_instance_to_supertype
4646
from mypy.meet import is_overlapping_erased_types, is_overlapping_types
4747
from mypy.message_registry import ErrorMessage
@@ -134,6 +134,7 @@
134134
is_final_node,
135135
)
136136
from mypy.options import Options
137+
from mypy.patterns import AsPattern, StarredPattern
137138
from mypy.plugin import CheckerPluginInterface, Plugin
138139
from mypy.scope import Scope
139140
from mypy.semanal import is_trivial_body, refers_to_fullname, set_callable_name
@@ -151,7 +152,7 @@
151152
restrict_subtype_away,
152153
unify_generic_callable,
153154
)
154-
from mypy.traverser import all_return_statements, has_return_statement
155+
from mypy.traverser import TraverserVisitor, all_return_statements, has_return_statement
155156
from mypy.treetransform import TransformVisitor
156157
from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type, make_optional_type
157158
from mypy.typeops import (
@@ -1207,6 +1208,20 @@ def check_func_def(
12071208

12081209
# Type check body in a new scope.
12091210
with self.binder.top_frame_context():
1211+
# Copy some type narrowings from an outer function when it seems safe enough
1212+
# (i.e. we can't find an assignment that might change the type of the
1213+
# variable afterwards).
1214+
new_frame: Frame | None = None
1215+
for frame in old_binder.frames:
1216+
for key, narrowed_type in frame.types.items():
1217+
key_var = extract_var_from_literal_hash(key)
1218+
if key_var is not None and not self.is_var_redefined_in_outer_context(
1219+
key_var, defn.line
1220+
):
1221+
# It seems safe to propagate the type narrowing to a nested scope.
1222+
if new_frame is None:
1223+
new_frame = self.binder.push_frame()
1224+
new_frame.types[key] = narrowed_type
12101225
with self.scope.push_function(defn):
12111226
# We suppress reachability warnings when we use TypeVars with value
12121227
# restrictions: we only want to report a warning if a certain statement is
@@ -1218,6 +1233,8 @@ def check_func_def(
12181233
self.binder.suppress_unreachable_warnings()
12191234
self.accept(item.body)
12201235
unreachable = self.binder.is_unreachable()
1236+
if new_frame is not None:
1237+
self.binder.pop_frame(True, 0)
12211238

12221239
if not unreachable:
12231240
if defn.is_generator or is_named_instance(
@@ -1310,6 +1327,23 @@ def check_func_def(
13101327

13111328
self.binder = old_binder
13121329

1330+
def is_var_redefined_in_outer_context(self, v: Var, after_line: int) -> bool:
    """Report whether ``v`` might be reassigned in an enclosing function.

    The answer is based on line numbers only, not on control flow: the
    variable counts as redefined if any enclosing function definition
    contains a potential assignment to it at or after ``after_line``.
    This heuristic can be wrong in some (rare) cases.

    When there is no enclosing function the outer context is the module
    top level; globals cannot be reasoned about, so the result is True.
    """
    enclosing = self.tscope.outer_functions()
    if not enclosing:
        # Top-level function: be conservative about globals.
        return True
    return any(
        isinstance(func, FuncDef)
        and find_last_var_assignment_line(func.body, v) >= after_line
        for func in enclosing
    )
1346+
13131347
def check_unbound_return_typevar(self, typ: CallableType) -> None:
13141348
"""Fails when the return typevar is not defined in arguments."""
13151349
if isinstance(typ.ret_type, TypeVarType) and typ.ret_type in typ.variables:
@@ -7629,3 +7663,80 @@ def collapse_walrus(e: Expression) -> Expression:
76297663
if isinstance(e, AssignmentExpr):
76307664
return e.target
76317665
return e
7666+
7667+
7668+
def find_last_var_assignment_line(n: Node, v: Var) -> int:
    """Return the largest line number of a potential assignment to ``v`` in ``n``.

    Both local and global variables are supported. The result is -1 when
    no assignment to the variable is found within the node.
    """
    assignment_finder = VarAssignVisitor(v)
    n.accept(assignment_finder)
    return assignment_finder.last_line
7678+
7679+
7680+
class VarAssignVisitor(TraverserVisitor):
    """Record the highest line on which a given variable may be assigned.

    After the visitor has been applied to a node, ``last_line`` holds the
    largest line number of a name expression that refers to the tracked
    variable in an assignment-target position, or -1 if none was seen.

    Positions treated as assignment targets: lvalues of assignment
    statements, ``with ... as`` targets, ``for`` loop indexes, targets of
    assignment expressions, and ``as`` / starred captures in match patterns.
    Names that only appear inside a member or index expression are not
    treated as targets (``x.attr = ...`` and ``x[i] = ...`` do not rebind ``x``).
    """

    def __init__(self, v: Var) -> None:
        super().__init__()
        self.last_line = -1
        # True only while an assignment target is being traversed.
        self.lvalue = False
        self.var_node = v

    def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
        self.lvalue = True
        for lv in s.lvalues:
            lv.accept(self)
        self.lvalue = False
        # The right-hand side can contain an assignment expression that
        # rebinds the variable, e.g. ``y = (x := None)``.
        s.rvalue.accept(self)

    def visit_name_expr(self, e: NameExpr) -> None:
        if self.lvalue and e.node is self.var_node:
            self.last_line = max(self.last_line, e.line)

    def visit_member_expr(self, e: MemberExpr) -> None:
        # Assigning to an attribute does not rebind the base name.
        old_lvalue = self.lvalue
        self.lvalue = False
        super().visit_member_expr(e)
        self.lvalue = old_lvalue

    def visit_index_expr(self, e: IndexExpr) -> None:
        # Assigning to an item does not rebind the base or index names.
        old_lvalue = self.lvalue
        self.lvalue = False
        super().visit_index_expr(e)
        self.lvalue = old_lvalue

    def visit_with_stmt(self, s: WithStmt) -> None:
        # Context expressions are ordinary expressions but may still
        # contain an assignment expression.
        for context_expr in s.expr:
            context_expr.accept(self)
        self.lvalue = True
        for lv in s.target:
            if lv is not None:
                lv.accept(self)
        self.lvalue = False
        s.body.accept(self)

    def visit_for_stmt(self, s: ForStmt) -> None:
        self.lvalue = True
        s.index.accept(self)
        self.lvalue = False
        # The iterable may contain an assignment expression.
        s.expr.accept(self)
        s.body.accept(self)
        if s.else_body:
            s.else_body.accept(self)

    def visit_assignment_expr(self, e: AssignmentExpr) -> None:
        self.lvalue = True
        e.target.accept(self)
        self.lvalue = False
        e.value.accept(self)

    def visit_as_pattern(self, p: AsPattern) -> None:
        if p.pattern is not None:
            p.pattern.accept(self)
        if p.name is not None:
            # ``case ... as name`` binds ``name``.
            self.lvalue = True
            p.name.accept(self)
            self.lvalue = False

    def visit_starred_pattern(self, p: StarredPattern) -> None:
        if p.capture is not None:
            # ``case [*name]`` binds ``name``.
            self.lvalue = True
            p.capture.accept(self)
            self.lvalue = False

mypy/fastparse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1766,7 +1766,8 @@ def visit_MatchStar(self, n: MatchStar) -> StarredPattern:
17661766
if n.name is None:
17671767
node = StarredPattern(None)
17681768
else:
1769-
node = StarredPattern(NameExpr(n.name))
1769+
name = self.set_line(NameExpr(n.name), n)
1770+
node = StarredPattern(name)
17701771

17711772
return self.set_line(node, n)
17721773

mypy/literals.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,16 @@ def literal_hash(e: Expression) -> Key | None:
139139
return e.accept(_hasher)
140140

141141

142+
def extract_var_from_literal_hash(key: Key) -> Var | None:
    """Return the Var that a literal-hash key refers to, if any.

    A key refers to a variable when it has exactly two items, the first
    being the string "Var" and the second a Var instance. Any other key
    yields None.
    """
    if len(key) != 2:
        return None
    tag, node = key[0], key[1]
    if tag == "Var" and isinstance(node, Var):
        return node
    return None
150+
151+
142152
class _Hasher(ExpressionVisitor[Optional[Key]]):
143153
def visit_int_expr(self, e: IntExpr) -> Key:
144154
return ("Literal", e.value)

mypy/scope.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(self) -> None:
2121
self.module: str | None = None
2222
self.classes: list[TypeInfo] = []
2323
self.function: FuncBase | None = None
24+
self.functions: list[FuncBase] = []
2425
# Number of nested scopes ignored (that don't get their own separate targets)
2526
self.ignored = 0
2627

@@ -65,19 +66,24 @@ def module_scope(self, prefix: str) -> Iterator[None]:
6566

6667
@contextmanager
def function_scope(self, fdef: FuncBase) -> Iterator[None]:
    """Enter the scope of function ``fdef`` for the duration of a ``with`` block.

    On entry ``fdef`` is pushed onto ``self.functions``. If no function is
    current, ``fdef`` becomes ``self.function``; otherwise the nesting is
    counted in ``self.ignored``. On exit the bookkeeping is undone, also
    when the body raises, so the scope state stays consistent.
    """
    self.functions.append(fdef)
    if not self.function:
        self.function = fdef
    else:
        # Nested functions are part of the topmost function target.
        self.ignored += 1
    try:
        yield
    finally:
        # Restore state even if the body raised an exception.
        self.functions.pop()
        if self.ignored:
            # Leave a scope that's included in the enclosing target.
            self.ignored -= 1
        else:
            assert self.function
            self.function = None
8083

84+
def outer_functions(self) -> list[FuncBase]:
    """Return a new list of all entries in ``self.functions`` except the last one."""
    enclosing = self.functions[:-1]
    return enclosing
86+
8187
def enter_class(self, info: TypeInfo) -> None:
8288
"""Enter a class target scope."""
8389
if not self.function:

0 commit comments

Comments
 (0)