Skip to content

Commit ad2d4ba

Browse files
authored
Make literal exprs have inferred type of 'Literal' based on context (#5990)
This pull request modifies the type checking logic so that literal expressions will have an inferred type of 'Literal' if the context asks for a literal type. That is, it implements support for this: x: Literal[1] = 1 y = 1 reveal_type(x) # E: Revealed type is 'Literal[1]' reveal_type(y) # E: Revealed type is 'builtins.int' This pull requests also implements the `visit_literal_type` method in the `constraints.ConstraintBuilderVisitor` and `join.TypeJoinVisitor` methods. Both visitors are exercised indirectly through the "let's use literal types in collection contexts" code, but only the latter is tested directly: I wasn't really sure how to directly test `ConstraintBuilderVisitor`. The implementation is simple though -- I'm pretty sure literal types count as a "leaf type" so it's fine to return an empty list (no constraints).
1 parent 1c824b6 commit ad2d4ba

File tree

6 files changed

+559
-24
lines changed

6 files changed

+559
-24
lines changed

mypy/checkexpr.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,10 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
210210

211211
def analyze_var_ref(self, var: Var, context: Context) -> Type:
212212
if var.type:
213-
return var.type
213+
if is_literal_type_like(self.type_context[-1]) and var.name() in {'True', 'False'}:
214+
return LiteralType(var.name() == 'True', self.named_type('builtins.bool'))
215+
else:
216+
return var.type
214217
else:
215218
if not var.is_ready and self.chk.in_checked_function():
216219
self.chk.handle_cannot_determine_type(var.name(), context)
@@ -1721,11 +1724,17 @@ def analyze_external_member_access(self, member: str, base_type: Type,
17211724

17221725
def visit_int_expr(self, e: IntExpr) -> Type:
17231726
"""Type check an integer literal (trivial)."""
1724-
return self.named_type('builtins.int')
1727+
typ = self.named_type('builtins.int')
1728+
if is_literal_type_like(self.type_context[-1]):
1729+
return LiteralType(value=e.value, fallback=typ)
1730+
return typ
17251731

17261732
def visit_str_expr(self, e: StrExpr) -> Type:
17271733
"""Type check a string literal (trivial)."""
1728-
return self.named_type('builtins.str')
1734+
typ = self.named_type('builtins.str')
1735+
if is_literal_type_like(self.type_context[-1]):
1736+
return LiteralType(value=e.value, fallback=typ)
1737+
return typ
17291738

17301739
def visit_bytes_expr(self, e: BytesExpr) -> Type:
17311740
"""Type check a bytes literal (trivial)."""
@@ -3583,3 +3592,17 @@ def merge_typevars_in_callables_by_name(
35833592
output.append(target)
35843593

35853594
return output, variables
3595+
3596+
3597+
def is_literal_type_like(t: Optional[Type]) -> bool:
3598+
"""Returns 'true' if the given type context is potentially either a LiteralType,
3599+
a Union of LiteralType, or something similar.
3600+
"""
3601+
if t is None:
3602+
return False
3603+
elif isinstance(t, LiteralType):
3604+
return True
3605+
elif isinstance(t, UnionType):
3606+
return any(is_literal_type_like(item) for item in t.items)
3607+
else:
3608+
return False

mypy/constraints.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,9 @@ def visit_erased_type(self, template: ErasedType) -> List[Constraint]:
260260
def visit_deleted_type(self, template: DeletedType) -> List[Constraint]:
261261
return []
262262

263+
def visit_literal_type(self, template: LiteralType) -> List[Constraint]:
264+
return []
265+
263266
# Errors
264267

265268
def visit_partial_type(self, template: PartialType) -> List[Constraint]:
@@ -472,9 +475,6 @@ def visit_typeddict_type(self, template: TypedDictType) -> List[Constraint]:
472475
else:
473476
return []
474477

475-
def visit_literal_type(self, template: LiteralType) -> List[Constraint]:
476-
raise NotImplementedError()
477-
478478
def visit_union_type(self, template: UnionType) -> List[Constraint]:
479479
assert False, ("Unexpected UnionType in ConstraintBuilderVisitor"
480480
" (should have been handled in infer_constraints)")

mypy/join.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ def visit_instance(self, t: Instance) -> Type:
163163
return join_types(t, self.s)
164164
elif isinstance(self.s, TypedDictType):
165165
return join_types(t, self.s)
166+
elif isinstance(self.s, LiteralType):
167+
return join_types(t, self.s)
166168
else:
167169
return self.default(self.s)
168170

@@ -268,7 +270,13 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
268270
return self.default(self.s)
269271

270272
def visit_literal_type(self, t: LiteralType) -> Type:
271-
raise NotImplementedError()
273+
if isinstance(self.s, LiteralType):
274+
if t == self.s:
275+
return t
276+
else:
277+
return join_types(self.s.fallback, t.fallback)
278+
else:
279+
return join_types(self.s, t.fallback)
272280

273281
def visit_partial_type(self, t: PartialType) -> Type:
274282
# We only have partial information so we can't decide the join result. We should

mypy/test/testtypes.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -249,11 +249,11 @@ def test_is_proper_subtype_and_subtype_literal_types(self) -> None:
249249
fx = self.fx
250250

251251
lit1 = LiteralType(1, fx.a)
252-
lit2 = LiteralType("foo", fx.b)
253-
lit3 = LiteralType("bar", fx.b)
252+
lit2 = LiteralType("foo", fx.d)
253+
lit3 = LiteralType("bar", fx.d)
254254

255255
assert_true(is_proper_subtype(lit1, fx.a))
256-
assert_false(is_proper_subtype(lit1, fx.b))
256+
assert_false(is_proper_subtype(lit1, fx.d))
257257
assert_false(is_proper_subtype(fx.a, lit1))
258258
assert_true(is_proper_subtype(fx.uninhabited, lit1))
259259
assert_false(is_proper_subtype(lit1, fx.uninhabited))
@@ -262,7 +262,7 @@ def test_is_proper_subtype_and_subtype_literal_types(self) -> None:
262262
assert_false(is_proper_subtype(lit2, lit3))
263263

264264
assert_true(is_subtype(lit1, fx.a))
265-
assert_false(is_subtype(lit1, fx.b))
265+
assert_false(is_subtype(lit1, fx.d))
266266
assert_false(is_subtype(fx.a, lit1))
267267
assert_true(is_subtype(fx.uninhabited, lit1))
268268
assert_false(is_subtype(lit1, fx.uninhabited))
@@ -621,6 +621,41 @@ def test_type_type(self) -> None:
621621
self.assert_join(self.fx.type_type, self.fx.type_any, self.fx.type_type)
622622
self.assert_join(self.fx.type_b, self.fx.anyt, self.fx.anyt)
623623

624+
def test_literal_type(self) -> None:
625+
a = self.fx.a
626+
d = self.fx.d
627+
lit1 = LiteralType(1, a)
628+
lit2 = LiteralType(2, a)
629+
lit3 = LiteralType("foo", d)
630+
631+
self.assert_join(lit1, lit1, lit1)
632+
self.assert_join(lit1, a, a)
633+
self.assert_join(lit1, d, self.fx.o)
634+
self.assert_join(lit1, lit2, a)
635+
self.assert_join(lit1, lit3, self.fx.o)
636+
self.assert_join(lit1, self.fx.anyt, self.fx.anyt)
637+
self.assert_join(UnionType([lit1, lit2]), lit2, UnionType([lit1, lit2]))
638+
self.assert_join(UnionType([lit1, lit2]), a, a)
639+
self.assert_join(UnionType([lit1, lit3]), a, UnionType([a, lit3]))
640+
self.assert_join(UnionType([d, lit3]), lit3, UnionType([d, lit3]))
641+
self.assert_join(UnionType([d, lit3]), d, UnionType([d, lit3]))
642+
self.assert_join(UnionType([a, lit1]), lit1, UnionType([a, lit1]))
643+
self.assert_join(UnionType([a, lit1]), lit2, UnionType([a, lit1]))
644+
self.assert_join(UnionType([lit1, lit2]),
645+
UnionType([lit1, lit2]),
646+
UnionType([lit1, lit2]))
647+
648+
# The order in which we try joining two unions influences the
649+
# ordering of the items in the final produced unions. So, we
650+
# manually call 'assert_simple_join' and tune the output
651+
# after swapping the arguments here.
652+
self.assert_simple_join(UnionType([lit1, lit2]),
653+
UnionType([lit2, lit3]),
654+
UnionType([lit1, lit2, lit3]))
655+
self.assert_simple_join(UnionType([lit2, lit3]),
656+
UnionType([lit1, lit2]),
657+
UnionType([lit2, lit3, lit1]))
658+
624659
# There are additional test cases in check-inference.test.
625660

626661
# TODO: Function types + varargs and default args.

mypy/typeanal.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,8 @@ def visit_unbound_type(self, t: UnboundType) -> TypeVarList:
10861086
return [(name, node.node)]
10871087
elif not self.include_callables and self._seems_like_callable(t):
10881088
return []
1089+
elif node and node.fullname in ('typing_extensions.Literal', 'typing.Literal'):
1090+
return []
10891091
else:
10901092
return super().visit_unbound_type(t)
10911093

0 commit comments

Comments
 (0)