Skip to content

Commit 7b6fbb7

Browse files
elazarggvanrossum
authored andcommitted
assert/remove casts in semanal.py (#2341)
1 parent ba85545 commit 7b6fbb7

File tree

1 file changed

+58
-57
lines changed

1 file changed

+58
-57
lines changed

mypy/semanal.py

Lines changed: 58 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"""
4545

4646
from typing import (
47-
List, Dict, Set, Tuple, cast, Any, TypeVar, Union, Optional, Callable
47+
List, Dict, Set, Tuple, cast, TypeVar, Union, Optional, Callable
4848
)
4949

5050
from mypy.nodes import (
@@ -280,9 +280,7 @@ def visit_func_def(self, defn: FuncDef) -> None:
280280
if defn.name() in self.type.names:
281281
# Redefinition. Conditional redefinition is okay.
282282
n = self.type.names[defn.name()].node
283-
if self.is_conditional_func(n, defn):
284-
defn.original_def = cast(FuncDef, n)
285-
else:
283+
if not self.set_original_def(n, defn):
286284
self.name_already_defined(defn.name(), defn)
287285
self.type.names[defn.name()] = SymbolTableNode(MDEF, defn)
288286
self.prepare_method_signature(defn)
@@ -292,9 +290,7 @@ def visit_func_def(self, defn: FuncDef) -> None:
292290
if defn.name() in self.locals[-1]:
293291
# Redefinition. Conditional redefinition is okay.
294292
n = self.locals[-1][defn.name()].node
295-
if self.is_conditional_func(n, defn):
296-
defn.original_def = cast(FuncDef, n)
297-
else:
293+
if not self.set_original_def(n, defn):
298294
self.name_already_defined(defn.name(), defn)
299295
else:
300296
self.add_local(defn, defn)
@@ -304,11 +300,7 @@ def visit_func_def(self, defn: FuncDef) -> None:
304300
symbol = self.globals.get(defn.name())
305301
if isinstance(symbol.node, FuncDef) and symbol.node != defn:
306302
# This is redefinition. Conditional redefinition is okay.
307-
original_def = symbol.node
308-
if self.is_conditional_func(original_def, defn):
309-
# Conditional function definition -- multiple defs are ok.
310-
defn.original_def = original_def
311-
else:
303+
if not self.set_original_def(symbol.node, defn):
312304
# Report error.
313305
self.check_no_global(defn.name(), defn, True)
314306
if phase_info == FUNCTION_FIRST_PHASE_POSTPONE_SECOND:
@@ -341,19 +333,22 @@ def prepare_method_signature(self, func: FuncDef) -> None:
341333
leading_type = self.class_type(self.type)
342334
else:
343335
leading_type = fill_typevars(self.type)
344-
sig = cast(FunctionLike, func.type)
345-
func.type = replace_implicit_first_type(sig, leading_type)
336+
func.type = replace_implicit_first_type(functype, leading_type)
346337

347-
def is_conditional_func(self, previous: Node, new: FuncDef) -> bool:
348-
"""Does 'new' conditionally redefine 'previous'?
338+
def set_original_def(self, previous: Node, new: FuncDef) -> bool:
339+
"""If 'new' conditionally redefine 'previous', set 'previous' as original
349340
350341
We reject straight redefinitions of functions, as they are usually
351342
a programming error. For example:
352343
353344
. def f(): ...
354345
. def f(): ... # Error: 'f' redefined
355346
"""
356-
return isinstance(previous, (FuncDef, Var)) and new.is_conditional
347+
if isinstance(previous, (FuncDef, Var)) and new.is_conditional:
348+
new.original_def = previous
349+
return True
350+
else:
351+
return False
357352

358353
def update_function_type_variables(self, defn: FuncDef) -> None:
359354
"""Make any type variables in the signature of defn explicit.
@@ -362,8 +357,8 @@ def update_function_type_variables(self, defn: FuncDef) -> None:
362357
if defn is generic.
363358
"""
364359
if defn.type:
365-
functype = cast(CallableType, defn.type)
366-
typevars = self.infer_type_variables(functype)
360+
assert isinstance(defn.type, CallableType)
361+
typevars = self.infer_type_variables(defn.type)
367362
# Do not define a new type variable if already defined in scope.
368363
typevars = [(name, tvar) for name, tvar in typevars
369364
if not self.is_defined_type_var(name, defn)]
@@ -373,7 +368,7 @@ def update_function_type_variables(self, defn: FuncDef) -> None:
373368
tvar[1].values, tvar[1].upper_bound,
374369
tvar[1].variance)
375370
for i, tvar in enumerate(typevars)]
376-
functype.variables = defs
371+
defn.type.variables = defs
377372

378373
def infer_type_variables(self,
379374
type: CallableType) -> List[Tuple[str, TypeVarExpr]]:
@@ -387,8 +382,7 @@ def infer_type_variables(self,
387382
tvars.append(tvar_expr)
388383
return list(zip(names, tvars))
389384

390-
def find_type_variables_in_type(
391-
self, type: Type) -> List[Tuple[str, TypeVarExpr]]:
385+
def find_type_variables_in_type(self, type: Type) -> List[Tuple[str, TypeVarExpr]]:
392386
"""Return a list of all unique type variable references in type.
393387
394388
This effectively does partial name binding, results of which are mostly thrown away.
@@ -398,7 +392,8 @@ def find_type_variables_in_type(
398392
name = type.name
399393
node = self.lookup_qualified(name, type)
400394
if node and node.kind == UNBOUND_TVAR:
401-
result.append((name, cast(TypeVarExpr, node.node)))
395+
assert isinstance(node.node, TypeVarExpr)
396+
result.append((name, node.node))
402397
for arg in type.args:
403398
result.extend(self.find_type_variables_in_type(arg))
404399
elif isinstance(type, TypeList):
@@ -425,8 +420,9 @@ def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
425420
item.is_overload = True
426421
item.func.is_overload = True
427422
item.accept(self)
428-
t.append(cast(CallableType, function_type(item.func,
429-
self.builtin_type('builtins.function'))))
423+
callable = function_type(item.func, self.builtin_type('builtins.function'))
424+
assert isinstance(callable, CallableType)
425+
t.append(callable)
430426
if item.func.is_property and i == 0:
431427
# This defines a property, probably with a setter and/or deleter.
432428
self.analyze_property_with_multi_part_definition(defn)
@@ -524,8 +520,9 @@ def add_func_type_variables_to_symbol_table(
524520
nodes = [] # type: List[SymbolTableNode]
525521
if defn.type:
526522
tt = defn.type
523+
assert isinstance(tt, CallableType)
524+
items = tt.variables
527525
names = self.type_var_names()
528-
items = cast(CallableType, tt).variables
529526
for item in items:
530527
name = item.name
531528
if name in names:
@@ -549,7 +546,8 @@ def bind_type_var(self, fullname: str, tvar_def: TypeVarDef,
549546
return node
550547

551548
def check_function_signature(self, fdef: FuncItem) -> None:
552-
sig = cast(CallableType, fdef.type)
549+
sig = fdef.type
550+
assert isinstance(sig, CallableType)
553551
if len(sig.arg_types) < len(fdef.arguments):
554552
self.fail('Type signature has too few arguments', fdef)
555553
# Add dummy Any arguments to prevent crashes later.
@@ -725,7 +723,8 @@ def analyze_unbound_tvar(self, t: Type) -> Tuple[str, TypeVarExpr]:
725723
unbound = t
726724
sym = self.lookup_qualified(unbound.name, unbound)
727725
if sym is not None and sym.kind == UNBOUND_TVAR:
728-
return unbound.name, cast(TypeVarExpr, sym.node)
726+
assert isinstance(sym.node, TypeVarExpr)
727+
return unbound.name, sym.node
729728
return None
730729

731730
def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:
@@ -922,13 +921,15 @@ def class_type(self, info: TypeInfo) -> Type:
922921

923922
def named_type(self, qualified_name: str, args: List[Type] = None) -> Instance:
924923
sym = self.lookup_qualified(qualified_name, None)
925-
return Instance(cast(TypeInfo, sym.node), args or [])
924+
assert isinstance(sym.node, TypeInfo)
925+
return Instance(sym.node, args or [])
926926

927927
def named_type_or_none(self, qualified_name: str, args: List[Type] = None) -> Instance:
928928
sym = self.lookup_fully_qualified_or_none(qualified_name)
929929
if not sym:
930930
return None
931-
return Instance(cast(TypeInfo, sym.node), args or [])
931+
assert isinstance(sym.node, TypeInfo)
932+
return Instance(sym.node, args or [])
932933

933934
def bind_class_type_variables_in_symbol_table(
934935
self, info: TypeInfo) -> List[SymbolTableNode]:
@@ -1300,11 +1301,10 @@ def analyze_lvalue(self, lval: Lvalue, nested: bool = False,
13001301
lval.accept(self)
13011302
elif (isinstance(lval, TupleExpr) or
13021303
isinstance(lval, ListExpr)):
1303-
items = cast(Any, lval).items
1304+
items = lval.items
13041305
if len(items) == 0 and isinstance(lval, TupleExpr):
13051306
self.fail("Can't assign to ()", lval)
1306-
self.analyze_tuple_or_list_lvalue(cast(Union[ListExpr, TupleExpr], lval),
1307-
add_global, explicit_type)
1307+
self.analyze_tuple_or_list_lvalue(lval, add_global, explicit_type)
13081308
elif isinstance(lval, StarExpr):
13091309
if nested:
13101310
self.analyze_lvalue(lval.expr, nested, add_global, explicit_type)
@@ -1318,9 +1318,7 @@ def analyze_tuple_or_list_lvalue(self, lval: Union[ListExpr, TupleExpr],
13181318
explicit_type: bool = False) -> None:
13191319
"""Analyze an lvalue or assignment target that is a list or tuple."""
13201320
items = lval.items
1321-
star_exprs = [cast(StarExpr, item)
1322-
for item in items
1323-
if isinstance(item, StarExpr)]
1321+
star_exprs = [item for item in items if isinstance(item, StarExpr)]
13241322

13251323
if len(star_exprs) > 1:
13261324
self.fail('Two starred expressions in assignment', lval)
@@ -1452,14 +1450,14 @@ def check_newtype_args(self, name: str, call: CallExpr, context: Context) -> Opt
14521450
if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)):
14531451
self.fail("Argument 1 to NewType(...) must be a string literal", context)
14541452
has_failed = True
1455-
elif cast(StrExpr, call.args[0]).value != name:
1453+
elif args[0].value != name:
14561454
msg = "String argument 1 '{}' to NewType(...) does not match variable name '{}'"
1457-
self.fail(msg.format(cast(StrExpr, call.args[0]).value, name), context)
1455+
self.fail(msg.format(args[0].value, name), context)
14581456
has_failed = True
14591457

14601458
# Check second argument
14611459
try:
1462-
unanalyzed_type = expr_to_unanalyzed_type(call.args[1])
1460+
unanalyzed_type = expr_to_unanalyzed_type(args[1])
14631461
except TypeTranslationError:
14641462
self.fail("Argument 2 to NewType(...) must be a valid type", context)
14651463
return None
@@ -1497,7 +1495,8 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> None:
14971495
if not call:
14981496
return
14991497

1500-
lvalue = cast(NameExpr, s.lvalues[0])
1498+
lvalue = s.lvalues[0]
1499+
assert isinstance(lvalue, NameExpr)
15011500
name = lvalue.name
15021501
if not lvalue.is_def:
15031502
if s.type:
@@ -1538,9 +1537,9 @@ def check_typevar_name(self, call: CallExpr, name: str, context: Context) -> boo
15381537
or not call.arg_kinds[0] == ARG_POS):
15391538
self.fail("TypeVar() expects a string literal as first argument", context)
15401539
return False
1541-
if cast(StrExpr, call.args[0]).value != name:
1540+
elif call.args[0].value != name:
15421541
msg = "String argument 1 '{}' to TypeVar(...) does not match variable name '{}'"
1543-
self.fail(msg.format(cast(StrExpr, call.args[0]).value, name), context)
1542+
self.fail(msg.format(call.args[0].value, name), context)
15441543
return False
15451544
return True
15461545

@@ -2308,7 +2307,8 @@ def visit_member_expr(self, expr: MemberExpr) -> None:
23082307
# This branch handles the case foo.bar where foo is a module.
23092308
# In this case base.node is the module's MypyFile and we look up
23102309
# bar in its namespace. This must be done for all types of bar.
2311-
file = cast(MypyFile, base.node)
2310+
file = base.node
2311+
assert isinstance(file, MypyFile)
23122312
n = file.names.get(expr.name, None) if file is not None else None
23132313
if n:
23142314
n = self.normalize_type_alias(n, expr)
@@ -2513,7 +2513,8 @@ def lookup(self, name: str, ctx: Context) -> SymbolTableNode:
25132513
# 5. Builtins
25142514
b = self.globals.get('__builtins__', None)
25152515
if b:
2516-
table = cast(MypyFile, b.node).names
2516+
assert isinstance(b.node, MypyFile)
2517+
table = b.node.names
25172518
if name in table:
25182519
if name[0] == "_" and name[1] != "_":
25192520
self.name_not_defined(name, ctx)
@@ -2568,8 +2569,8 @@ def lookup_qualified(self, name: str, ctx: Context) -> SymbolTableNode:
25682569

25692570
def builtin_type(self, fully_qualified_name: str) -> Instance:
25702571
node = self.lookup_fully_qualified(fully_qualified_name)
2571-
info = cast(TypeInfo, node.node)
2572-
return Instance(info, [])
2572+
assert isinstance(node.node, TypeInfo)
2573+
return Instance(node.node, [])
25732574

25742575
def lookup_fully_qualified(self, name: str) -> SymbolTableNode:
25752576
"""Lookup a fully qualified name.
@@ -2581,10 +2582,12 @@ def lookup_fully_qualified(self, name: str) -> SymbolTableNode:
25812582
parts = name.split('.')
25822583
n = self.modules[parts[0]]
25832584
for i in range(1, len(parts) - 1):
2584-
n = cast(MypyFile, n.names[parts[i]].node)
2585-
return n.names[parts[-1]]
2585+
next_sym = n.names[parts[i]]
2586+
assert isinstance(next_sym.node, MypyFile)
2587+
n = next_sym.node
2588+
return n.names.get(parts[-1])
25862589

2587-
def lookup_fully_qualified_or_none(self, name: str) -> SymbolTableNode:
2590+
def lookup_fully_qualified_or_none(self, name: str) -> Optional[SymbolTableNode]:
25882591
"""Lookup a fully qualified name.
25892592
25902593
Assume that the name is defined. This happens in the global namespace -- the local
@@ -2597,7 +2600,8 @@ def lookup_fully_qualified_or_none(self, name: str) -> SymbolTableNode:
25972600
next_sym = n.names.get(parts[i])
25982601
if not next_sym:
25992602
return None
2600-
n = cast(MypyFile, next_sym.node)
2603+
assert isinstance(next_sym.node, MypyFile)
2604+
n = next_sym.node
26012605
return n.names.get(parts[-1])
26022606

26032607
def qualified_name(self, n: str) -> str:
@@ -2811,11 +2815,7 @@ def visit_func_def(self, func: FuncDef) -> None:
28112815
# Ah this is an imported name. We can't resolve them now, so we'll postpone
28122816
# this until the main phase of semantic analysis.
28132817
return
2814-
original_def = original_sym.node
2815-
if sem.is_conditional_func(original_def, func):
2816-
# Conditional function definition -- multiple defs are ok.
2817-
func.original_def = cast(FuncDef, original_def)
2818-
else:
2818+
if not sem.set_original_def(original_sym.node, func):
28192819
# Report error.
28202820
sem.check_no_global(func.name(), func)
28212821
else:
@@ -3055,10 +3055,11 @@ def fill_typevars(typ: TypeInfo) -> Union[Instance, TupleType]:
30553055
def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
30563056
if isinstance(sig, CallableType):
30573057
return sig.copy_modified(arg_types=[new] + sig.arg_types[1:])
3058-
else:
3059-
sig = cast(Overloaded, sig)
3058+
elif isinstance(sig, Overloaded):
30603059
return Overloaded([cast(CallableType, replace_implicit_first_type(i, new))
30613060
for i in sig.items()])
3061+
else:
3062+
assert False
30623063

30633064

30643065
def set_callable_name(sig: Type, fdef: FuncDef) -> Type:

0 commit comments

Comments
 (0)