Skip to content

Commit 8453dda

Browse files
davidfstrJukkaL
authored andcommitted
TypedDict: Recognize __getitem__ and __setitem__ access. (#2526)
* TypedDict: Recognize __getitem__ and __setitem__ access.
1 parent eef1a15 commit 8453dda

File tree

5 files changed

+113
-56
lines changed

5 files changed

+113
-56
lines changed

mypy/checker.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525
DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr,
2626
RefExpr, YieldExpr, BackquoteExpr, ImportFrom, ImportAll, ImportBase,
2727
AwaitExpr,
28+
ARG_POS,
2829
CONTRAVARIANT, COVARIANT)
2930
from mypy import nodes
3031
from mypy.types import (
31-
Type, AnyType, CallableType, Void, FunctionLike, Overloaded, TupleType,
32+
Type, AnyType, CallableType, Void, FunctionLike, Overloaded, TupleType, TypedDictType,
3233
Instance, NoneTyp, ErrorType, strip_type, TypeType,
3334
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType,
3435
true_only, false_only, function_type
@@ -1546,8 +1547,18 @@ def check_indexed_assignment(self, lvalue: IndexExpr,
15461547
"""
15471548
self.try_infer_partial_type_from_indexed_assignment(lvalue, rvalue)
15481549
basetype = self.accept(lvalue.base)
1549-
method_type = self.expr_checker.analyze_external_member_access(
1550-
'__setitem__', basetype, context)
1550+
if isinstance(basetype, TypedDictType):
1551+
item_type = self.expr_checker.visit_typeddict_index_expr(basetype, lvalue.index)
1552+
method_type = CallableType(
1553+
arg_types=[self.named_type('builtins.str'), item_type],
1554+
arg_kinds=[ARG_POS, ARG_POS],
1555+
arg_names=[None, None],
1556+
ret_type=NoneTyp(),
1557+
fallback=self.named_type('builtins.function')
1558+
) # type: Type
1559+
else:
1560+
method_type = self.expr_checker.analyze_external_member_access(
1561+
'__setitem__', basetype, context)
15511562
lvalue.method_type = method_type
15521563
self.expr_checker.check_call(method_type, [lvalue.index, rvalue],
15531564
[nodes.ARG_POS, nodes.ARG_POS],

mypy/checkexpr.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,6 +1435,8 @@ def visit_index_expr_helper(self, e: IndexExpr) -> Type:
14351435
else:
14361436
self.chk.fail(messages.TUPLE_INDEX_MUST_BE_AN_INT_LITERAL, e)
14371437
return AnyType()
1438+
elif isinstance(left_type, TypedDictType):
1439+
return self.visit_typeddict_index_expr(left_type, e.index)
14381440
else:
14391441
result, method_type = self.check_op('__getitem__', left_type, e.index, e)
14401442
e.method_type = method_type
@@ -1481,6 +1483,18 @@ def _get_value(self, index: Expression) -> Optional[int]:
14811483
return -1 * operand.value
14821484
return None
14831485

1486+
def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression):
1487+
if not isinstance(index, (StrExpr, UnicodeExpr)):
1488+
self.msg.typeddict_item_name_must_be_string_literal(td_type, index)
1489+
return AnyType()
1490+
item_name = index.value
1491+
1492+
item_type = td_type.items.get(item_name)
1493+
if item_type is None:
1494+
self.msg.typeddict_item_name_not_found(td_type, item_name, index)
1495+
return AnyType()
1496+
return item_type
1497+
14841498
def visit_cast_expr(self, expr: CastExpr) -> Type:
14851499
"""Type check a cast expression."""
14861500
source_type = self.accept(expr.expr, context=AnyType())

mypy/checkmember.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
Overloaded, TypeVarType, UnionType, PartialType,
88
DeletedType, NoneTyp, TypeType, function_type
99
)
10-
from mypy.nodes import TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context, MypyFile, TypeVarExpr
11-
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2
12-
from mypy.nodes import Decorator, OverloadedFuncDef
10+
from mypy.nodes import (
11+
TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context, MypyFile, TypeVarExpr,
12+
ARG_POS, ARG_STAR, ARG_STAR2,
13+
Decorator, OverloadedFuncDef,
14+
)
1315
from mypy.messages import MessageBuilder
1416
from mypy.maptype import map_instance_to_supertype
1517
from mypy.expandtype import expand_type_by_instance, expand_type

mypy/messages.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,19 @@ def typeddict_instantiated_with_unexpected_items(self,
819819
self.fail('Expected items {} but found {}.'.format(
820820
expected_item_names, actual_item_names), context)
821821

822+
def typeddict_item_name_must_be_string_literal(self,
823+
typ: TypedDictType,
824+
context: Context):
825+
self.fail('Cannot prove expression is a valid item name; expected one of {}'.format(
826+
format_item_name_list(typ.items.keys())), context)
827+
828+
def typeddict_item_name_not_found(self,
829+
typ: TypedDictType,
830+
item_name: str,
831+
context: Context):
832+
self.fail('\'{}\' is not a valid item name; expected one of {}'.format(
833+
item_name, format_item_name_list(typ.items.keys())), context)
834+
822835

823836
def capitalize(s: str) -> str:
824837
"""Capitalize the first character of a string."""
@@ -862,6 +875,14 @@ def format_string_list(s: Iterable[str]) -> str:
862875
return '%s, ... and %s (%i methods suppressed)' % (', '.join(l[:2]), l[-1], len(l) - 3)
863876

864877

878+
def format_item_name_list(s: Iterable[str]) -> str:
879+
l = list(s)
880+
if len(l) <= 5:
881+
return '[' + ', '.join(["'%s'" % name for name in l]) + ']'
882+
else:
883+
return '[' + ', '.join(["'%s'" % name for name in l[:5]]) + ', ...]'
884+
885+
865886
def callable_name(type: CallableType) -> str:
866887
if type.name:
867888
return type.name

test-data/unit/check-typeddict.test

Lines changed: 59 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -360,64 +360,73 @@ reveal_type(f(g)) # E: Revealed type is '<uninhabited>'
360360

361361
-- Special Method: __getitem__
362362

363-
-- TODO: Implement support for this case.
364-
--[case testCanGetItemOfTypedDictWithValidStringLiteralKey]
365-
--from mypy_extensions import TypedDict
366-
--TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
367-
--p = TaggedPoint(type='2d', x=42, y=1337)
368-
--def get_x(p: TaggedPoint) -> int:
369-
-- return p['x']
370-
--[builtins fixtures/dict.pyi]
363+
[case testCanGetItemOfTypedDictWithValidStringLiteralKey]
364+
from mypy_extensions import TypedDict
365+
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
366+
p = TaggedPoint(type='2d', x=42, y=1337)
367+
reveal_type(p['type']) # E: Revealed type is 'builtins.str'
368+
reveal_type(p['x']) # E: Revealed type is 'builtins.int'
369+
reveal_type(p['y']) # E: Revealed type is 'builtins.int'
370+
[builtins fixtures/dict.pyi]
371371

372-
-- TODO: Implement support for this case.
373-
--[case testCannotGetItemOfTypedDictWithInvalidStringLiteralKey]
374-
--from mypy_extensions import TypedDict
375-
--TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
376-
--p = TaggedPoint(type='2d', x=42, y=1337)
377-
--def get_z(p: TaggedPoint) -> int:
378-
-- return p['z'] # E: ... 'z' is not a valid key for Point. Expected one of {'x', 'y'}.
379-
--[builtins fixtures/dict.pyi]
372+
[case testCanGetItemOfTypedDictWithValidBytesOrUnicodeLiteralKey]
373+
# flags: --python-version 2.7
374+
from mypy_extensions import TypedDict
375+
Cell = TypedDict('Cell', {'value': int})
376+
c = Cell(value=42)
377+
reveal_type(c['value']) # E: Revealed type is 'builtins.int'
378+
reveal_type(c[u'value']) # E: Revealed type is 'builtins.int'
379+
[builtins_py2 fixtures/dict.pyi]
380380

381-
-- TODO: Implement support for this case.
382-
--[case testCannotGetItemOfTypedDictWithNonLiteralKey]
383-
--from mypy_extensions import TypedDict
384-
--from typing import Union
385-
--TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
386-
--p = TaggedPoint(type='2d', x=42, y=1337)
387-
--def get_coordinate(p: TaggedPoint, key: str) -> Union[str, int]:
388-
-- return p[key] # E: ... Cannot prove 'key' is a valid key for Point. Expected one of {'x', 'y'}
389-
--[builtins fixtures/dict.pyi]
381+
[case testCannotGetItemOfTypedDictWithInvalidStringLiteralKey]
382+
from mypy_extensions import TypedDict
383+
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
384+
p = TaggedPoint(type='2d', x=42, y=1337)
385+
p['z'] # E: 'z' is not a valid item name; expected one of ['type', 'x', 'y']
386+
[builtins fixtures/dict.pyi]
387+
388+
[case testCannotGetItemOfTypedDictWithNonLiteralKey]
389+
from mypy_extensions import TypedDict
390+
from typing import Union
391+
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
392+
p = TaggedPoint(type='2d', x=42, y=1337)
393+
def get_coordinate(p: TaggedPoint, key: str) -> Union[str, int]:
394+
return p[key] # E: Cannot prove expression is a valid item name; expected one of ['type', 'x', 'y']
395+
[builtins fixtures/dict.pyi]
390396

391397

392398
-- Special Method: __setitem__
393399

394-
-- TODO: Implement support for this case.
395-
--[case testCanSetItemOfTypedDictWithValidStringLiteralKey]
396-
--from mypy_extensions import TypedDict
397-
--TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
398-
--p = TaggedPoint(type='2d', x=42, y=1337)
399-
--def set_x(p: TaggedPoint, x: int) -> None:
400-
-- p['x'] = x
401-
--[builtins fixtures/dict.pyi]
400+
[case testCanSetItemOfTypedDictWithValidStringLiteralKeyAndCompatibleValueType]
401+
from mypy_extensions import TypedDict
402+
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
403+
p = TaggedPoint(type='2d', x=42, y=1337)
404+
p['type'] = 'two_d'
405+
p['x'] = 1
406+
[builtins fixtures/dict.pyi]
402407

403-
-- TODO: Implement support for this case.
404-
--[case testCannotSetItemOfTypedDictWithInvalidStringLiteralKey]
405-
--from mypy_extensions import TypedDict
406-
--TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
407-
--p = TaggedPoint(type='2d', x=42, y=1337)
408-
--def set_z(p: TaggedPoint, z: int) -> None:
409-
-- p['z'] = z # E: ... 'z' is not a valid key for Point. Expected one of {'x', 'y'}.
410-
--[builtins fixtures/dict.pyi]
408+
[case testCannotSetItemOfTypedDictWithIncompatibleValueType]
409+
from mypy_extensions import TypedDict
410+
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
411+
p = TaggedPoint(type='2d', x=42, y=1337)
412+
p['x'] = 'y' # E: Argument 2 has incompatible type "str"; expected "int"
413+
[builtins fixtures/dict.pyi]
411414

412-
-- TODO: Implement support for this case.
413-
--[case testCannotSetItemOfTypedDictWithNonLiteralKey]
414-
--from mypy_extensions import TypedDict
415-
--from typing import Union
416-
--TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
417-
--p = TaggedPoint(type='2d', x=42, y=1337)
418-
--def set_coordinate(p: TaggedPoint, key: str, value: Union[str, int]) -> None:
419-
-- p[key] = value # E: ... Cannot prove 'key' is a valid key for Point. Expected one of {'x', 'y'}
420-
--[builtins fixtures/dict.pyi]
415+
[case testCannotSetItemOfTypedDictWithInvalidStringLiteralKey]
416+
from mypy_extensions import TypedDict
417+
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
418+
p = TaggedPoint(type='2d', x=42, y=1337)
419+
p['z'] = 1 # E: 'z' is not a valid item name; expected one of ['type', 'x', 'y']
420+
[builtins fixtures/dict.pyi]
421+
422+
[case testCannotSetItemOfTypedDictWithNonLiteralKey]
423+
from mypy_extensions import TypedDict
424+
from typing import Union
425+
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
426+
p = TaggedPoint(type='2d', x=42, y=1337)
427+
def set_coordinate(p: TaggedPoint, key: str, value: int) -> None:
428+
p[key] = value # E: Cannot prove expression is a valid item name; expected one of ['type', 'x', 'y']
429+
[builtins fixtures/dict.pyi]
421430

422431

423432
-- Special Method: get

0 commit comments

Comments
 (0)