26
26
27
27
import mypy .checkexpr
28
28
from mypy import errorcodes as codes , message_registry , nodes , operators
29
- from mypy .binder import ConditionalTypeBinder , get_declaration
29
+ from mypy .binder import ConditionalTypeBinder , Frame , get_declaration
30
30
from mypy .checkmember import (
31
31
MemberContext ,
32
32
analyze_decorator_or_funcbase_access ,
41
41
from mypy .errors import Errors , ErrorWatcher , report_internal_error
42
42
from mypy .expandtype import expand_self_type , expand_type , expand_type_by_instance
43
43
from mypy .join import join_types
44
- from mypy .literals import Key , literal , literal_hash
44
+ from mypy .literals import Key , extract_var_from_literal_hash , literal , literal_hash
45
45
from mypy .maptype import map_instance_to_supertype
46
46
from mypy .meet import is_overlapping_erased_types , is_overlapping_types
47
47
from mypy .message_registry import ErrorMessage
134
134
is_final_node ,
135
135
)
136
136
from mypy .options import Options
137
+ from mypy .patterns import AsPattern , StarredPattern
137
138
from mypy .plugin import CheckerPluginInterface , Plugin
138
139
from mypy .scope import Scope
139
140
from mypy .semanal import is_trivial_body , refers_to_fullname , set_callable_name
151
152
restrict_subtype_away ,
152
153
unify_generic_callable ,
153
154
)
154
- from mypy .traverser import all_return_statements , has_return_statement
155
+ from mypy .traverser import TraverserVisitor , all_return_statements , has_return_statement
155
156
from mypy .treetransform import TransformVisitor
156
157
from mypy .typeanal import check_for_explicit_any , has_any_from_unimported_type , make_optional_type
157
158
from mypy .typeops import (
@@ -1207,6 +1208,20 @@ def check_func_def(
1207
1208
1208
1209
# Type check body in a new scope.
1209
1210
with self .binder .top_frame_context ():
1211
+ # Copy some type narrowings from an outer function when it seems safe enough
1212
+ # (i.e. we can't find an assignment that might change the type of the
1213
+ # variable afterwards).
1214
+ new_frame : Frame | None = None
1215
+ for frame in old_binder .frames :
1216
+ for key , narrowed_type in frame .types .items ():
1217
+ key_var = extract_var_from_literal_hash (key )
1218
+ if key_var is not None and not self .is_var_redefined_in_outer_context (
1219
+ key_var , defn .line
1220
+ ):
1221
+ # It seems safe to propagate the type narrowing to a nested scope.
1222
+ if new_frame is None :
1223
+ new_frame = self .binder .push_frame ()
1224
+ new_frame .types [key ] = narrowed_type
1210
1225
with self .scope .push_function (defn ):
1211
1226
# We suppress reachability warnings when we use TypeVars with value
1212
1227
# restrictions: we only want to report a warning if a certain statement is
@@ -1218,6 +1233,8 @@ def check_func_def(
1218
1233
self .binder .suppress_unreachable_warnings ()
1219
1234
self .accept (item .body )
1220
1235
unreachable = self .binder .is_unreachable ()
1236
+ if new_frame is not None :
1237
+ self .binder .pop_frame (True , 0 )
1221
1238
1222
1239
if not unreachable :
1223
1240
if defn .is_generator or is_named_instance (
@@ -1310,6 +1327,23 @@ def check_func_def(
1310
1327
1311
1328
self .binder = old_binder
1312
1329
1330
+ def is_var_redefined_in_outer_context (self , v : Var , after_line : int ) -> bool :
1331
+ """Can the variable be assigned to at module top level or outer function?
1332
+
1333
+ Note that this doesn't do a full CFG analysis but uses a line number based
1334
+ heuristic that isn't correct in some (rare) cases.
1335
+ """
1336
+ outers = self .tscope .outer_functions ()
1337
+ if not outers :
1338
+ # Top-level function -- outer context is top level, and we can't reason about
1339
+ # globals
1340
+ return True
1341
+ for outer in outers :
1342
+ if isinstance (outer , FuncDef ):
1343
+ if find_last_var_assignment_line (outer .body , v ) >= after_line :
1344
+ return True
1345
+ return False
1346
+
1313
1347
def check_unbound_return_typevar (self , typ : CallableType ) -> None :
1314
1348
"""Fails when the return typevar is not defined in arguments."""
1315
1349
if isinstance (typ .ret_type , TypeVarType ) and typ .ret_type in typ .variables :
@@ -7629,3 +7663,80 @@ def collapse_walrus(e: Expression) -> Expression:
7629
7663
if isinstance (e , AssignmentExpr ):
7630
7664
return e .target
7631
7665
return e
7666
+
7667
+
7668
+ def find_last_var_assignment_line (n : Node , v : Var ) -> int :
7669
+ """Find the highest line number of a potential assignment to variable within node.
7670
+
7671
+ This supports local and global variables.
7672
+
7673
+ Return -1 if no assignment was found.
7674
+ """
7675
+ visitor = VarAssignVisitor (v )
7676
+ n .accept (visitor )
7677
+ return visitor .last_line
7678
+
7679
+
7680
+ class VarAssignVisitor (TraverserVisitor ):
7681
+ def __init__ (self , v : Var ) -> None :
7682
+ self .last_line = - 1
7683
+ self .lvalue = False
7684
+ self .var_node = v
7685
+
7686
+ def visit_assignment_stmt (self , s : AssignmentStmt ) -> None :
7687
+ self .lvalue = True
7688
+ for lv in s .lvalues :
7689
+ lv .accept (self )
7690
+ self .lvalue = False
7691
+
7692
+ def visit_name_expr (self , e : NameExpr ) -> None :
7693
+ if self .lvalue and e .node is self .var_node :
7694
+ self .last_line = max (self .last_line , e .line )
7695
+
7696
+ def visit_member_expr (self , e : MemberExpr ) -> None :
7697
+ old_lvalue = self .lvalue
7698
+ self .lvalue = False
7699
+ super ().visit_member_expr (e )
7700
+ self .lvalue = old_lvalue
7701
+
7702
+ def visit_index_expr (self , e : IndexExpr ) -> None :
7703
+ old_lvalue = self .lvalue
7704
+ self .lvalue = False
7705
+ super ().visit_index_expr (e )
7706
+ self .lvalue = old_lvalue
7707
+
7708
+ def visit_with_stmt (self , s : WithStmt ) -> None :
7709
+ self .lvalue = True
7710
+ for lv in s .target :
7711
+ if lv is not None :
7712
+ lv .accept (self )
7713
+ self .lvalue = False
7714
+ s .body .accept (self )
7715
+
7716
+ def visit_for_stmt (self , s : ForStmt ) -> None :
7717
+ self .lvalue = True
7718
+ s .index .accept (self )
7719
+ self .lvalue = False
7720
+ s .body .accept (self )
7721
+ if s .else_body :
7722
+ s .else_body .accept (self )
7723
+
7724
+ def visit_assignment_expr (self , e : AssignmentExpr ) -> None :
7725
+ self .lvalue = True
7726
+ e .target .accept (self )
7727
+ self .lvalue = False
7728
+ e .value .accept (self )
7729
+
7730
+ def visit_as_pattern (self , p : AsPattern ) -> None :
7731
+ if p .pattern is not None :
7732
+ p .pattern .accept (self )
7733
+ if p .name is not None :
7734
+ self .lvalue = True
7735
+ p .name .accept (self )
7736
+ self .lvalue = False
7737
+
7738
+ def visit_starred_pattern (self , p : StarredPattern ) -> None :
7739
+ if p .capture is not None :
7740
+ self .lvalue = True
7741
+ p .capture .accept (self )
7742
+ self .lvalue = False
0 commit comments