Skip to content

Commit e818a96

Browse files
Michael0x2a authored and ilevkivskyi committed
Generalize reachability checks to support enums (#7000)
Fixes #1803

This diff adds support for performing reachability and narrowing analysis when doing certain enum checks. For example, given the following enum:

    class Foo(Enum):
        A = 1
        B = 2

...this pull request will make mypy do the following:

    x: Foo
    if x is Foo.A:
        reveal_type(x)  # type: Literal[Foo.A]
    elif x is Foo.B:
        reveal_type(x)  # type: Literal[Foo.B]
    else:
        reveal_type(x)  # No output: branch inferred as unreachable

This diff does not attempt to perform this same sort of narrowing for equality checks: I suspect implementing those will be harder due to their overridable nature (e.g. you can define custom `__eq__` methods within Enum subclasses).

This pull request also finally adds support for the enum behavior [described in PEP 484][0] and also partially addresses #6366.

[0]: https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
1 parent 028f202 commit e818a96

File tree

2 files changed

+344
-11
lines changed

2 files changed

+344
-11
lines changed

mypy/checker.py

Lines changed: 100 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import itertools
44
import fnmatch
5+
import sys
56
from contextlib import contextmanager
67

78
from typing import (
@@ -3535,21 +3536,34 @@ def find_isinstance_check(self, node: Expression
35353536
vartype = type_map[expr]
35363537
return self.conditional_callable_type_map(expr, vartype)
35373538
elif isinstance(node, ComparisonExpr):
3538-
# Check for `x is None` and `x is not None`.
3539+
operand_types = [coerce_to_literal(type_map[expr])
3540+
for expr in node.operands if expr in type_map]
3541+
35393542
is_not = node.operators == ['is not']
3540-
if any(is_literal_none(n) for n in node.operands) and (
3541-
is_not or node.operators == ['is']):
3543+
if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands):
35423544
if_vars = {} # type: TypeMap
35433545
else_vars = {} # type: TypeMap
3544-
for expr in node.operands:
3545-
if (literal(expr) == LITERAL_TYPE and not is_literal_none(expr)
3546-
and expr in type_map):
3546+
3547+
for i, expr in enumerate(node.operands):
3548+
var_type = operand_types[i]
3549+
other_type = operand_types[1 - i]
3550+
3551+
if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type):
35473552
# This should only be true at most once: there should be
3548-
# two elements in node.operands, and at least one of them
3549-
# should represent a None.
3550-
vartype = type_map[expr]
3551-
none_typ = [TypeRange(NoneType(), is_upper_bound=False)]
3552-
if_vars, else_vars = conditional_type_map(expr, vartype, none_typ)
3553+
# exactly two elements in node.operands and if the 'other type' is
3554+
# a singleton type, it by definition does not need to be narrowed:
3555+
# it already has the most precise type possible so does not need to
3556+
# be narrowed/included in the output map.
3557+
#
3558+
# TODO: Generalize this to handle the case where 'other_type' is
3559+
# a union of singleton types.
3560+
3561+
if isinstance(other_type, LiteralType) and other_type.is_enum_literal():
3562+
fallback_name = other_type.fallback.type.fullname()
3563+
var_type = try_expanding_enum_to_union(var_type, fallback_name)
3564+
3565+
target_type = [TypeRange(other_type, is_upper_bound=False)]
3566+
if_vars, else_vars = conditional_type_map(expr, var_type, target_type)
35533567
break
35543568

35553569
if is_not:
@@ -4489,3 +4503,78 @@ def is_overlapping_types_no_promote(left: Type, right: Type) -> bool:
44894503
def is_private(node_name: str) -> bool:
44904504
"""Check if node is private to class definition."""
44914505
return node_name.startswith('__') and not node_name.endswith('__')
4506+
4507+
4508+
def is_singleton_type(typ: Type) -> bool:
    """Tell whether 'typ' is a "singleton type": a type inhabited by exactly one
    runtime value.

    Put differently: for any two values 'a' and 'b' of the same type 't',
    'is_singleton_type(t)' is True if and only if 'a is b' always holds.

    At the moment only NoneTypes and enum LiteralTypes are recognized.

    Other kinds of LiteralTypes deliberately do not qualify. For instance, after
    'a = 100000 + 1' and 'b = 100001' nothing guarantees that 'a is b' is true --
    some implementations of Python will build two distinct instances of 100001.
    """
    # TODO: Also make this return True if the type is a bool LiteralType.
    # Also make this return True if the type corresponds to ... (ellipsis) or NotImplemented?
    if isinstance(typ, NoneType):
        return True
    return isinstance(typ, LiteralType) and typ.is_enum_literal()
4526+
4527+
4528+
def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> Type:
    """Recursively replace every enum Instance named 'target_fullname' inside 'typ'
    with a Union of LiteralTypes, one per Var entry in that enum's symbol table.

    Unions are walked item by item; any other type is handed back untouched.

    For example, given:

        class Color(Enum):
            RED = 1
            BLUE = 2
            YELLOW = 3

        class Status(Enum):
            SUCCESS = 1
            FAILURE = 2
            UNKNOWN = 3

    ...the call `try_expanding_enum_to_union(Union[Color, Status], 'module.Color')`
    yields Union[Literal[Color.RED], Literal[Color.BLUE], Literal[Color.YELLOW], Status]:
    only the enum whose full name matches is expanded.
    """
    if isinstance(typ, UnionType):
        expanded = [try_expanding_enum_to_union(item, target_fullname)
                    for item in typ.items]
        return UnionType.make_simplified_union(expanded)

    is_target_enum = (isinstance(typ, Instance)
                      and typ.type.is_enum
                      and typ.type.fullname() == target_fullname)
    if not is_target_enum:
        return typ

    # Only symbols backed by a Var are turned into literals; everything else
    # in the enum's namespace is skipped.
    members = [LiteralType(name, typ)
               for name, symbol in typ.type.names.items()
               if isinstance(symbol.node, Var)]

    # SymbolTables are really just dicts, and dicts only promise to keep
    # insertion order from Python 3.7 onwards. On older interpreters we sort
    # so that the result (and therefore test output) is deterministic.
    #
    # Skipping the sort on Python 3.6 would probably be fine since mypy is
    # most likely run on CPython only, but sorting keeps us fully correct.
    if sys.version_info < (3, 7):
        members.sort(key=lambda lit: lit.value)
    return UnionType.make_simplified_union(members)
4568+
4569+
4570+
def coerce_to_literal(typ: Type) -> Type:
    """Recursively swap each Instance carrying a last_known_value for that value.

    Unions are processed item by item and re-simplified; an Instance with a
    truthy 'last_known_value' is replaced by it; every other type is returned
    unchanged.
    """
    if isinstance(typ, UnionType):
        coerced_items = [coerce_to_literal(item) for item in typ.items]
        return UnionType.make_simplified_union(coerced_items)
    if isinstance(typ, Instance) and typ.last_known_value:
        return typ.last_known_value
    return typ

test-data/unit/check-enum.test

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,3 +610,247 @@ class SomeEnum(Enum):
610610
main:2: note: Revealed type is 'builtins.int'
611611
[out2]
612612
main:2: note: Revealed type is 'builtins.str'
613+
614+
[case testEnumReachabilityChecksBasic]
615+
from enum import Enum
616+
from typing_extensions import Literal
617+
618+
class Foo(Enum):
619+
A = 1
620+
B = 2
621+
C = 3
622+
623+
x: Literal[Foo.A, Foo.B, Foo.C]
624+
if x is Foo.A:
625+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
626+
elif x is Foo.B:
627+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
628+
elif x is Foo.C:
629+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
630+
else:
631+
reveal_type(x) # No output here: this branch is unreachable
632+
633+
if Foo.A is x:
634+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
635+
elif Foo.B is x:
636+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
637+
elif Foo.C is x:
638+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
639+
else:
640+
reveal_type(x) # No output here: this branch is unreachable
641+
642+
y: Foo
643+
if y is Foo.A:
644+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
645+
elif y is Foo.B:
646+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
647+
elif y is Foo.C:
648+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
649+
else:
650+
reveal_type(y) # No output here: this branch is unreachable
651+
652+
if Foo.A is y:
653+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
654+
elif Foo.B is y:
655+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
656+
elif Foo.C is y:
657+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
658+
else:
659+
reveal_type(y) # No output here: this branch is unreachable
660+
[builtins fixtures/bool.pyi]
661+
662+
[case testEnumReachabilityChecksIndirect]
663+
from enum import Enum
664+
from typing_extensions import Literal, Final
665+
666+
class Foo(Enum):
667+
A = 1
668+
B = 2
669+
C = 3
670+
671+
def accepts_foo_a(x: Literal[Foo.A]) -> None: ...
672+
673+
x: Foo
674+
y: Literal[Foo.A]
675+
z: Final = Foo.A
676+
677+
if x is y:
678+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
679+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
680+
else:
681+
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
682+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
683+
if y is x:
684+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
685+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
686+
else:
687+
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
688+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
689+
690+
if x is z:
691+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
692+
reveal_type(z) # N: Revealed type is '__main__.Foo'
693+
accepts_foo_a(z)
694+
else:
695+
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
696+
reveal_type(z) # N: Revealed type is '__main__.Foo'
697+
accepts_foo_a(z)
698+
if z is x:
699+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
700+
reveal_type(z) # N: Revealed type is '__main__.Foo'
701+
accepts_foo_a(z)
702+
else:
703+
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
704+
reveal_type(z) # N: Revealed type is '__main__.Foo'
705+
accepts_foo_a(z)
706+
707+
if y is z:
708+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
709+
reveal_type(z) # N: Revealed type is '__main__.Foo'
710+
accepts_foo_a(z)
711+
else:
712+
reveal_type(y) # No output: this branch is unreachable
713+
reveal_type(z) # No output: this branch is unreachable
714+
if z is y:
715+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
716+
reveal_type(z) # N: Revealed type is '__main__.Foo'
717+
accepts_foo_a(z)
718+
else:
719+
reveal_type(y) # No output: this branch is unreachable
720+
reveal_type(z) # No output: this branch is unreachable
721+
[builtins fixtures/bool.pyi]
722+
723+
[case testEnumReachabilityNoNarrowingForUnionMessiness]
724+
from enum import Enum
725+
from typing_extensions import Literal
726+
727+
class Foo(Enum):
728+
A = 1
729+
B = 2
730+
C = 3
731+
732+
x: Foo
733+
y: Literal[Foo.A, Foo.B]
734+
z: Literal[Foo.B, Foo.C]
735+
736+
# For the sake of simplicity, no narrowing is done when the narrower type is a Union.
737+
if x is y:
738+
reveal_type(x) # N: Revealed type is '__main__.Foo'
739+
reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
740+
else:
741+
reveal_type(x) # N: Revealed type is '__main__.Foo'
742+
reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
743+
744+
if y is z:
745+
reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
746+
reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
747+
else:
748+
reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
749+
reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
750+
[builtins fixtures/bool.pyi]
751+
752+
[case testEnumReachabilityWithNone]
753+
# flags: --strict-optional
754+
from enum import Enum
755+
from typing import Optional
756+
757+
class Foo(Enum):
758+
A = 1
759+
B = 2
760+
C = 3
761+
762+
x: Optional[Foo]
763+
if x:
764+
reveal_type(x) # N: Revealed type is '__main__.Foo'
765+
else:
766+
reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]'
767+
768+
if x is not None:
769+
reveal_type(x) # N: Revealed type is '__main__.Foo'
770+
else:
771+
reveal_type(x) # N: Revealed type is 'None'
772+
773+
if x is Foo.A:
774+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
775+
else:
776+
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]'
777+
[builtins fixtures/bool.pyi]
778+
779+
[case testEnumReachabilityWithMultipleEnums]
780+
from enum import Enum
781+
from typing import Union
782+
from typing_extensions import Literal
783+
784+
class Foo(Enum):
785+
A = 1
786+
B = 2
787+
class Bar(Enum):
788+
A = 1
789+
B = 2
790+
791+
x1: Union[Foo, Bar]
792+
if x1 is Foo.A:
793+
reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]'
794+
else:
795+
reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]'
796+
797+
x2: Union[Foo, Bar]
798+
if x2 is Bar.A:
799+
reveal_type(x2) # N: Revealed type is 'Literal[__main__.Bar.A]'
800+
else:
801+
reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]'
802+
803+
x3: Union[Foo, Bar]
804+
if x3 is Foo.A or x3 is Bar.A:
805+
reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]'
806+
else:
807+
reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]'
808+
809+
[builtins fixtures/bool.pyi]
810+
811+
[case testEnumReachabilityPEP484Example1]
812+
# flags: --strict-optional
813+
from typing import Union
814+
from typing_extensions import Final
815+
from enum import Enum
816+
817+
class Empty(Enum):
818+
token = 0
819+
_empty: Final = Empty.token
820+
821+
def func(x: Union[int, None, Empty] = _empty) -> int:
822+
boom = x + 42 # E: Unsupported left operand type for + ("None") \
823+
# E: Unsupported left operand type for + ("Empty") \
824+
# N: Left operand is of type "Union[int, None, Empty]"
825+
if x is _empty:
826+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]'
827+
return 0
828+
elif x is None:
829+
reveal_type(x) # N: Revealed type is 'None'
830+
return 1
831+
else: # At this point typechecker knows that x can only have type int
832+
reveal_type(x) # N: Revealed type is 'builtins.int'
833+
return x + 2
834+
[builtins fixtures/primitives.pyi]
835+
836+
[case testEnumReachabilityPEP484Example2]
837+
from typing import Union
838+
from enum import Enum
839+
840+
class Reason(Enum):
841+
timeout = 1
842+
error = 2
843+
844+
def process(response: Union[str, Reason] = '') -> str:
845+
if response is Reason.timeout:
846+
reveal_type(response) # N: Revealed type is 'Literal[__main__.Reason.timeout]'
847+
return 'TIMEOUT'
848+
elif response is Reason.error:
849+
reveal_type(response) # N: Revealed type is 'Literal[__main__.Reason.error]'
850+
return 'ERROR'
851+
else:
852+
# response can be only str, all other possible values exhausted
853+
reveal_type(response) # N: Revealed type is 'builtins.str'
854+
return 'PROCESSED: ' + response
855+
856+
[builtins fixtures/primitives.pyi]

0 commit comments

Comments
 (0)