Skip to content

Commit d198f38

Browse files
committed
Add union math for intelligent indexing
This pull request hacks on supports for union math when using Unions of Literals to index into tuples, NamedTuples, and TypedDicts. It also fixes a bug I apparently introduced. Currently, mypy correctly reports an error with this code: class Test(TypedDict): foo: int t: Test # Error: int can't be assigned a str value t.setdefault("foo", "unrelated value") ...but does not report an error with: key: Literal["foo"] t.setdefault(key, "unrelated value") This diff should make mypy report an error in both cases. Resolves #6262.
1 parent dfcce2b commit d198f38

File tree

6 files changed

+259
-91
lines changed

6 files changed

+259
-91
lines changed

mypy/checkexpr.py

Lines changed: 65 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections import OrderedDict
44
from contextlib import contextmanager
5+
import itertools
56
from typing import (
67
cast, Dict, Set, List, Tuple, Callable, Union, Optional, Iterable,
78
Sequence, Iterator
@@ -2465,15 +2466,18 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
24652466
if isinstance(index, SliceExpr):
24662467
return self.visit_tuple_slice_helper(left_type, index)
24672468

2468-
n = self._get_value(index)
2469-
if n is not None:
2470-
if n < 0:
2471-
n += len(left_type.items)
2472-
if 0 <= n < len(left_type.items):
2473-
return left_type.items[n]
2474-
else:
2475-
self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, e)
2476-
return AnyType(TypeOfAny.from_error)
2469+
ns = self._get_values(index)
2470+
if ns is not None:
2471+
out = []
2472+
for n in ns:
2473+
if n < 0:
2474+
n += len(left_type.items)
2475+
if 0 <= n < len(left_type.items):
2476+
out.append(left_type.items[n])
2477+
else:
2478+
self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, e)
2479+
return AnyType(TypeOfAny.from_error)
2480+
return UnionType.make_simplified_union(out)
24772481
else:
24782482
return self.nonliteral_tuple_index_helper(left_type, index)
24792483
elif isinstance(left_type, TypedDictType):
@@ -2489,26 +2493,32 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
24892493
return result
24902494

24912495
def visit_tuple_slice_helper(self, left_type: TupleType, slic: SliceExpr) -> Type:
2492-
begin = None
2493-
end = None
2494-
stride = None
2496+
begin = [None] # type: Sequence[Optional[int]]
2497+
end = [None] # type: Sequence[Optional[int]]
2498+
stride = [None] # type: Sequence[Optional[int]]
24952499

24962500
if slic.begin_index:
2497-
begin = self._get_value(slic.begin_index)
2498-
if begin is None:
2501+
begin_raw = self._get_values(slic.begin_index)
2502+
if begin_raw is None:
24992503
return self.nonliteral_tuple_index_helper(left_type, slic)
2504+
begin = begin_raw
25002505

25012506
if slic.end_index:
2502-
end = self._get_value(slic.end_index)
2503-
if end is None:
2507+
end_raw = self._get_values(slic.end_index)
2508+
if end_raw is None:
25042509
return self.nonliteral_tuple_index_helper(left_type, slic)
2510+
end = end_raw
25052511

25062512
if slic.stride:
2507-
stride = self._get_value(slic.stride)
2508-
if stride is None:
2513+
stride_raw = self._get_values(slic.stride)
2514+
if stride_raw is None:
25092515
return self.nonliteral_tuple_index_helper(left_type, slic)
2516+
stride = stride_raw
25102517

2511-
return left_type.slice(begin, stride, end)
2518+
items = [] # type: List[Type]
2519+
for b, e, s in itertools.product(begin, end, stride):
2520+
items.append(left_type.slice(b, e, s))
2521+
return UnionType.make_simplified_union(items)
25122522

25132523
def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression) -> Type:
25142524
index_type = self.accept(index)
@@ -2521,40 +2531,60 @@ def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression)
25212531
else:
25222532
return UnionType.make_simplified_union(left_type.items)
25232533

2524-
def _get_value(self, index: Expression) -> Optional[int]:
2534+
def _get_values(self, index: Expression) -> Optional[List[int]]:
25252535
if isinstance(index, IntExpr):
2526-
return index.value
2536+
return [index.value]
25272537
elif isinstance(index, UnaryExpr):
25282538
if index.op == '-':
25292539
operand = index.expr
25302540
if isinstance(operand, IntExpr):
2531-
return -1 * operand.value
2541+
return [-1 * operand.value]
25322542
typ = self.accept(index)
25332543
if isinstance(typ, Instance) and typ.final_value is not None:
25342544
typ = typ.final_value
25352545
if isinstance(typ, LiteralType) and isinstance(typ.value, int):
2536-
return typ.value
2546+
return [typ.value]
2547+
if isinstance(typ, UnionType):
2548+
out = []
2549+
for item in typ.items:
2550+
if isinstance(item, LiteralType) and isinstance(item.value, int):
2551+
out.append(item.value)
2552+
else:
2553+
return None
2554+
return out
25372555
return None
25382556

25392557
def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type:
25402558
if isinstance(index, (StrExpr, UnicodeExpr)):
2541-
item_name = index.value
2559+
key_names = [index.value]
25422560
else:
25432561
typ = self.accept(index)
2544-
if isinstance(typ, Instance) and typ.final_value is not None:
2545-
typ = typ.final_value
25462562

2547-
if isinstance(typ, LiteralType) and isinstance(typ.value, str):
2548-
item_name = typ.value
2563+
if isinstance(typ, UnionType):
2564+
key_types = typ.items
25492565
else:
2550-
self.msg.typeddict_key_must_be_string_literal(td_type, index)
2551-
return AnyType(TypeOfAny.from_error)
2566+
key_types = [typ]
25522567

2553-
item_type = td_type.items.get(item_name)
2554-
if item_type is None:
2555-
self.msg.typeddict_key_not_found(td_type, item_name, index)
2556-
return AnyType(TypeOfAny.from_error)
2557-
return item_type
2568+
key_names = []
2569+
for key_type in key_types:
2570+
if isinstance(key_type, Instance) and key_type.final_value is not None:
2571+
key_type = key_type.final_value
2572+
2573+
if isinstance(key_type, LiteralType) and isinstance(key_type.value, str):
2574+
key_names.append(key_type.value)
2575+
else:
2576+
self.msg.typeddict_key_must_be_string_literal(td_type, index)
2577+
return AnyType(TypeOfAny.from_error)
2578+
2579+
value_types = []
2580+
for key_name in key_names:
2581+
value_type = td_type.items.get(key_name)
2582+
if value_type is None:
2583+
self.msg.typeddict_key_not_found(td_type, key_name, index)
2584+
return AnyType(TypeOfAny.from_error)
2585+
else:
2586+
value_types.append(value_type)
2587+
return UnionType.make_simplified_union(value_types)
25582588

25592589
def visit_enum_index_expr(self, enum_type: TypeInfo, index: Expression,
25602590
context: Context) -> Type:

mypy/messages.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,14 @@ def typeddict_key_cannot_be_deleted(
11431143
self.fail("Key '{}' of TypedDict {} cannot be deleted".format(
11441144
item_name, self.format(typ)), context)
11451145

1146+
def typeddict_setdefault_arguments_inconsistent(
1147+
self,
1148+
default: Type,
1149+
expected: Type,
1150+
context: Context) -> None:
1151+
msg = 'Argument 2 to "setdefault" of "TypedDict" has incompatible type {}; expected {}'
1152+
self.fail(msg.format(self.format(default), self.format(expected)), context)
1153+
11461154
def type_arguments_not_allowed(self, context: Context) -> None:
11471155
self.fail('Parameterized generics cannot be used with class or instance checks', context)
11481156

mypy/plugins/common.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
)
77
from mypy.plugin import ClassDefContext
88
from mypy.semanal import set_callable_name
9-
from mypy.types import CallableType, Overloaded, Type, TypeVarDef, LiteralType, Instance
9+
from mypy.types import CallableType, Overloaded, Type, TypeVarDef, LiteralType, Instance, UnionType
1010
from mypy.typevars import fill_typevars
1111

1212

@@ -113,18 +113,34 @@ def add_method(
113113
info.defn.defs.body.append(func)
114114

115115

116-
def try_getting_str_literal(expr: Expression, typ: Type) -> Optional[str]:
117-
"""If this expression is a string literal, or if the corresponding type
118-
is something like 'Literal["some string here"]', returns the underlying
119-
string value. Otherwise, returns None."""
116+
def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]]:
117+
"""If the given expression or type corresponds to a string literal
118+
or a union of string literals, returns a list of the underlying strings.
119+
Otherwise, returns None.
120+
121+
Specifically, this function is guaranteed to return a list with
122+
one or more strings if one one the following is true:
123+
124+
1. 'expr' is a StrExpr
125+
2. 'typ' is a LiteralType containing a string
126+
3. 'typ' is a UnionType containing only LiteralType of strings
127+
"""
128+
if isinstance(expr, StrExpr):
129+
return [expr.value]
130+
120131
if isinstance(typ, Instance) and typ.final_value is not None:
121-
typ = typ.final_value
122-
123-
if isinstance(typ, LiteralType) and typ.fallback.type.fullname() == 'builtins.str':
124-
val = typ.value
125-
assert isinstance(val, str)
126-
return val
127-
elif isinstance(expr, StrExpr):
128-
return expr.value
132+
possible_literals = [typ.final_value] # type: List[Type]
133+
elif isinstance(typ, UnionType):
134+
possible_literals = typ.items
129135
else:
130-
return None
136+
possible_literals = [typ]
137+
138+
strings = []
139+
for lit in possible_literals:
140+
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == 'builtins.str':
141+
val = lit.value
142+
assert isinstance(val, str)
143+
strings.append(val)
144+
else:
145+
return None
146+
return strings

mypy/plugins/default.py

Lines changed: 72 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
from mypy.plugin import (
77
Plugin, FunctionContext, MethodContext, MethodSigContext, AttributeContext, ClassDefContext
88
)
9-
from mypy.plugins.common import try_getting_str_literal
9+
from mypy.plugins.common import try_getting_str_literals
1010
from mypy.types import (
1111
Type, Instance, AnyType, TypeOfAny, CallableType, NoneTyp, UnionType, TypedDictType,
1212
TypeVarType
1313
)
14+
from mypy.subtypes import is_subtype
1415

1516

1617
class DefaultPlugin(Plugin):
@@ -171,26 +172,34 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type:
171172
if (isinstance(ctx.type, TypedDictType)
172173
and len(ctx.arg_types) >= 1
173174
and len(ctx.arg_types[0]) == 1):
174-
key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0])
175-
if key is None:
175+
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
176+
if keys is None:
176177
return ctx.default_return_type
177178

178-
value_type = ctx.type.items.get(key)
179-
if value_type:
179+
output_types = []
180+
for key in keys:
181+
value_type = ctx.type.items.get(key)
182+
if value_type is None:
183+
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
184+
return AnyType(TypeOfAny.from_error)
185+
180186
if len(ctx.arg_types) == 1:
181-
return UnionType.make_simplified_union([value_type, NoneTyp()])
187+
output_types.append(value_type)
182188
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
183189
and len(ctx.args[1]) == 1):
184190
default_arg = ctx.args[1][0]
185191
if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0
186192
and isinstance(value_type, TypedDictType)):
187193
# Special case '{}' as the default for a typed dict type.
188-
return value_type.copy_modified(required_keys=set())
194+
output_types.append(value_type.copy_modified(required_keys=set()))
189195
else:
190-
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
191-
else:
192-
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
193-
return AnyType(TypeOfAny.from_error)
196+
output_types.append(value_type)
197+
output_types.append(ctx.arg_types[1][0])
198+
199+
if len(ctx.arg_types) == 1:
200+
output_types.append(NoneTyp())
201+
202+
return UnionType.make_simplified_union(output_types)
194203
return ctx.default_return_type
195204

196205

@@ -228,23 +237,28 @@ def typed_dict_pop_callback(ctx: MethodContext) -> Type:
228237
if (isinstance(ctx.type, TypedDictType)
229238
and len(ctx.arg_types) >= 1
230239
and len(ctx.arg_types[0]) == 1):
231-
key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0])
232-
if key is None:
240+
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
241+
if keys is None:
233242
ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
234243
return AnyType(TypeOfAny.from_error)
235244

236-
if key in ctx.type.required_keys:
237-
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
238-
value_type = ctx.type.items.get(key)
239-
if value_type:
240-
if len(ctx.args[1]) == 0:
241-
return value_type
242-
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
243-
and len(ctx.args[1]) == 1):
244-
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
245-
else:
246-
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
247-
return AnyType(TypeOfAny.from_error)
245+
value_types = []
246+
for key in keys:
247+
if key in ctx.type.required_keys:
248+
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
249+
250+
value_type = ctx.type.items.get(key)
251+
if value_type:
252+
value_types.append(value_type)
253+
else:
254+
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
255+
return AnyType(TypeOfAny.from_error)
256+
257+
if len(ctx.args[1]) == 0:
258+
return UnionType.make_simplified_union(value_types)
259+
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
260+
and len(ctx.args[1]) == 1):
261+
return UnionType.make_simplified_union([*value_types, ctx.arg_types[1][0]])
248262
return ctx.default_return_type
249263

250264

@@ -273,18 +287,35 @@ def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
273287
"""Type check TypedDict.setdefault and infer a precise return type."""
274288
if (isinstance(ctx.type, TypedDictType)
275289
and len(ctx.arg_types) == 2
276-
and len(ctx.arg_types[0]) == 1):
277-
key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0])
278-
if key is None:
290+
and len(ctx.arg_types[0]) == 1
291+
and len(ctx.arg_types[1]) == 1):
292+
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
293+
if keys is None:
279294
ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
280295
return AnyType(TypeOfAny.from_error)
281296

282-
value_type = ctx.type.items.get(key)
283-
if value_type:
284-
return value_type
285-
else:
286-
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
287-
return AnyType(TypeOfAny.from_error)
297+
default_type = ctx.arg_types[1][0]
298+
299+
value_types = []
300+
for key in keys:
301+
value_type = ctx.type.items.get(key)
302+
303+
if value_type is None:
304+
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
305+
return AnyType(TypeOfAny.from_error)
306+
307+
# The signature_callback above can't always infer the right signature
308+
# (e.g. when the expression is a variable that happens to be a Literal str)
309+
# so we need to handle the check ourselves here and make sure the provided
310+
# default can be assigned to all key-value pairs we're updating.
311+
if not is_subtype(default_type, value_type):
312+
ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
313+
default_type, value_type, ctx.context)
314+
return AnyType(TypeOfAny.from_error)
315+
316+
value_types.append(value_type)
317+
318+
return UnionType.make_simplified_union(value_types)
288319
return ctx.default_return_type
289320

290321

@@ -299,15 +330,16 @@ def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
299330
if (isinstance(ctx.type, TypedDictType)
300331
and len(ctx.arg_types) == 1
301332
and len(ctx.arg_types[0]) == 1):
302-
key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0])
303-
if key is None:
333+
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
334+
if keys is None:
304335
ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
305336
return AnyType(TypeOfAny.from_error)
306337

307-
if key in ctx.type.required_keys:
308-
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
309-
elif key not in ctx.type.items:
310-
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
338+
for key in keys:
339+
if key in ctx.type.required_keys:
340+
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
341+
elif key not in ctx.type.items:
342+
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
311343
return ctx.default_return_type
312344

313345

0 commit comments

Comments
 (0)