Skip to content

Commit e7ddba1

Browse files
authored
Add union math for intelligent indexing (#6558)
This pull request hacks on supports for union math when using Unions of Literals to index into tuples, NamedTuples, and TypedDicts. It also fixes a bug I apparently introduced. Currently, mypy correctly reports an error with this code: class Test(TypedDict): foo: int t: Test # Error: int can't be assigned a str value t.setdefault("foo", "unrelated value") ...but does not report an error with: key: Literal["foo"] t.setdefault(key, "unrelated value") This diff should make mypy report an error in both cases. Resolves #6262.
1 parent 2f613b8 commit e7ddba1

File tree

6 files changed

+318
-103
lines changed

6 files changed

+318
-103
lines changed

mypy/checkexpr.py

Lines changed: 87 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections import OrderedDict
44
from contextlib import contextmanager
5+
import itertools
56
from typing import (
67
cast, Dict, Set, List, Tuple, Callable, Union, Optional, Sequence, Iterator
78
)
@@ -2554,15 +2555,18 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
25542555
if isinstance(index, SliceExpr):
25552556
return self.visit_tuple_slice_helper(left_type, index)
25562557

2557-
n = self._get_value(index)
2558-
if n is not None:
2559-
if n < 0:
2560-
n += len(left_type.items)
2561-
if 0 <= n < len(left_type.items):
2562-
return left_type.items[n]
2563-
else:
2564-
self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, e)
2565-
return AnyType(TypeOfAny.from_error)
2558+
ns = self.try_getting_int_literals(index)
2559+
if ns is not None:
2560+
out = []
2561+
for n in ns:
2562+
if n < 0:
2563+
n += len(left_type.items)
2564+
if 0 <= n < len(left_type.items):
2565+
out.append(left_type.items[n])
2566+
else:
2567+
self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, e)
2568+
return AnyType(TypeOfAny.from_error)
2569+
return UnionType.make_simplified_union(out)
25662570
else:
25672571
return self.nonliteral_tuple_index_helper(left_type, index)
25682572
elif isinstance(left_type, TypedDictType):
@@ -2578,26 +2582,66 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
25782582
return result
25792583

25802584
def visit_tuple_slice_helper(self, left_type: TupleType, slic: SliceExpr) -> Type:
2581-
begin = None
2582-
end = None
2583-
stride = None
2585+
begin = [None] # type: Sequence[Optional[int]]
2586+
end = [None] # type: Sequence[Optional[int]]
2587+
stride = [None] # type: Sequence[Optional[int]]
25842588

25852589
if slic.begin_index:
2586-
begin = self._get_value(slic.begin_index)
2587-
if begin is None:
2590+
begin_raw = self.try_getting_int_literals(slic.begin_index)
2591+
if begin_raw is None:
25882592
return self.nonliteral_tuple_index_helper(left_type, slic)
2593+
begin = begin_raw
25892594

25902595
if slic.end_index:
2591-
end = self._get_value(slic.end_index)
2592-
if end is None:
2596+
end_raw = self.try_getting_int_literals(slic.end_index)
2597+
if end_raw is None:
25932598
return self.nonliteral_tuple_index_helper(left_type, slic)
2599+
end = end_raw
25942600

25952601
if slic.stride:
2596-
stride = self._get_value(slic.stride)
2597-
if stride is None:
2602+
stride_raw = self.try_getting_int_literals(slic.stride)
2603+
if stride_raw is None:
25982604
return self.nonliteral_tuple_index_helper(left_type, slic)
2605+
stride = stride_raw
2606+
2607+
items = [] # type: List[Type]
2608+
for b, e, s in itertools.product(begin, end, stride):
2609+
items.append(left_type.slice(b, e, s))
2610+
return UnionType.make_simplified_union(items)
25992611

2600-
return left_type.slice(begin, stride, end)
2612+
def try_getting_int_literals(self, index: Expression) -> Optional[List[int]]:
2613+
"""If the given expression or type corresponds to an int literal
2614+
or a union of int literals, returns a list of the underlying ints.
2615+
Otherwise, returns None.
2616+
2617+
Specifically, this function is guaranteed to return a list with
2618+
one or more ints if one one the following is true:
2619+
2620+
1. 'expr' is a IntExpr or a UnaryExpr backed by an IntExpr
2621+
2. 'typ' is a LiteralType containing an int
2622+
3. 'typ' is a UnionType containing only LiteralType of ints
2623+
"""
2624+
if isinstance(index, IntExpr):
2625+
return [index.value]
2626+
elif isinstance(index, UnaryExpr):
2627+
if index.op == '-':
2628+
operand = index.expr
2629+
if isinstance(operand, IntExpr):
2630+
return [-1 * operand.value]
2631+
typ = self.accept(index)
2632+
if isinstance(typ, Instance) and typ.last_known_value is not None:
2633+
typ = typ.last_known_value
2634+
if isinstance(typ, LiteralType) and isinstance(typ.value, int):
2635+
return [typ.value]
2636+
if isinstance(typ, UnionType):
2637+
out = []
2638+
for item in typ.items:
2639+
if isinstance(item, LiteralType) and isinstance(item.value, int):
2640+
out.append(item.value)
2641+
else:
2642+
return None
2643+
return out
2644+
return None
26012645

26022646
def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression) -> Type:
26032647
index_type = self.accept(index)
@@ -2614,40 +2658,36 @@ def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression)
26142658
else:
26152659
return union
26162660

2617-
def _get_value(self, index: Expression) -> Optional[int]:
2618-
if isinstance(index, IntExpr):
2619-
return index.value
2620-
elif isinstance(index, UnaryExpr):
2621-
if index.op == '-':
2622-
operand = index.expr
2623-
if isinstance(operand, IntExpr):
2624-
return -1 * operand.value
2625-
typ = self.accept(index)
2626-
if isinstance(typ, Instance) and typ.last_known_value is not None:
2627-
typ = typ.last_known_value
2628-
if isinstance(typ, LiteralType) and isinstance(typ.value, int):
2629-
return typ.value
2630-
return None
2631-
26322661
def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type:
26332662
if isinstance(index, (StrExpr, UnicodeExpr)):
2634-
item_name = index.value
2663+
key_names = [index.value]
26352664
else:
26362665
typ = self.accept(index)
2637-
if isinstance(typ, Instance) and typ.last_known_value is not None:
2638-
typ = typ.last_known_value
2639-
2640-
if isinstance(typ, LiteralType) and isinstance(typ.value, str):
2641-
item_name = typ.value
2666+
if isinstance(typ, UnionType):
2667+
key_types = typ.items
26422668
else:
2643-
self.msg.typeddict_key_must_be_string_literal(td_type, index)
2644-
return AnyType(TypeOfAny.from_error)
2669+
key_types = [typ]
26452670

2646-
item_type = td_type.items.get(item_name)
2647-
if item_type is None:
2648-
self.msg.typeddict_key_not_found(td_type, item_name, index)
2649-
return AnyType(TypeOfAny.from_error)
2650-
return item_type
2671+
key_names = []
2672+
for key_type in key_types:
2673+
if isinstance(key_type, Instance) and key_type.last_known_value is not None:
2674+
key_type = key_type.last_known_value
2675+
2676+
if isinstance(key_type, LiteralType) and isinstance(key_type.value, str):
2677+
key_names.append(key_type.value)
2678+
else:
2679+
self.msg.typeddict_key_must_be_string_literal(td_type, index)
2680+
return AnyType(TypeOfAny.from_error)
2681+
2682+
value_types = []
2683+
for key_name in key_names:
2684+
value_type = td_type.items.get(key_name)
2685+
if value_type is None:
2686+
self.msg.typeddict_key_not_found(td_type, key_name, index)
2687+
return AnyType(TypeOfAny.from_error)
2688+
else:
2689+
value_types.append(value_type)
2690+
return UnionType.make_simplified_union(value_types)
26512691

26522692
def visit_enum_index_expr(self, enum_type: TypeInfo, index: Expression,
26532693
context: Context) -> Type:

mypy/messages.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,6 +1177,14 @@ def typeddict_key_cannot_be_deleted(
11771177
self.fail("Key '{}' of TypedDict {} cannot be deleted".format(
11781178
item_name, self.format(typ)), context)
11791179

1180+
def typeddict_setdefault_arguments_inconsistent(
1181+
self,
1182+
default: Type,
1183+
expected: Type,
1184+
context: Context) -> None:
1185+
msg = 'Argument 2 to "setdefault" of "TypedDict" has incompatible type {}; expected {}'
1186+
self.fail(msg.format(self.format(default), self.format(expected)), context)
1187+
11801188
def type_arguments_not_allowed(self, context: Context) -> None:
11811189
self.fail('Parameterized generics cannot be used with class or instance checks', context)
11821190

mypy/plugins/common.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
)
77
from mypy.plugin import ClassDefContext
88
from mypy.semanal import set_callable_name
9-
from mypy.types import CallableType, Overloaded, Type, TypeVarDef, LiteralType, Instance
9+
from mypy.types import CallableType, Overloaded, Type, TypeVarDef, LiteralType, Instance, UnionType
1010
from mypy.typevars import fill_typevars
1111
from mypy.util import get_unique_redefinition_name
1212

@@ -129,18 +129,34 @@ def add_method(
129129
info.defn.defs.body.append(func)
130130

131131

132-
def try_getting_str_literal(expr: Expression, typ: Type) -> Optional[str]:
133-
"""If this expression is a string literal, or if the corresponding type
134-
is something like 'Literal["some string here"]', returns the underlying
135-
string value. Otherwise, returns None."""
132+
def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]]:
133+
"""If the given expression or type corresponds to a string literal
134+
or a union of string literals, returns a list of the underlying strings.
135+
Otherwise, returns None.
136+
137+
Specifically, this function is guaranteed to return a list with
138+
one or more strings if one one the following is true:
139+
140+
1. 'expr' is a StrExpr
141+
2. 'typ' is a LiteralType containing a string
142+
3. 'typ' is a UnionType containing only LiteralType of strings
143+
"""
144+
if isinstance(expr, StrExpr):
145+
return [expr.value]
146+
136147
if isinstance(typ, Instance) and typ.last_known_value is not None:
137-
typ = typ.last_known_value
138-
139-
if isinstance(typ, LiteralType) and typ.fallback.type.fullname() == 'builtins.str':
140-
val = typ.value
141-
assert isinstance(val, str)
142-
return val
143-
elif isinstance(expr, StrExpr):
144-
return expr.value
148+
possible_literals = [typ.last_known_value] # type: List[Type]
149+
elif isinstance(typ, UnionType):
150+
possible_literals = typ.items
145151
else:
146-
return None
152+
possible_literals = [typ]
153+
154+
strings = []
155+
for lit in possible_literals:
156+
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == 'builtins.str':
157+
val = lit.value
158+
assert isinstance(val, str)
159+
strings.append(val)
160+
else:
161+
return None
162+
return strings

0 commit comments

Comments
 (0)