Skip to content

Commit a503132

Browse files
authored
Recombine complete union of enum literals into original type (#9063) (#9097)
Closes #9063
1 parent f0a2c9f commit a503132

File tree

3 files changed

+76
-6
lines changed

3 files changed

+76
-6
lines changed

mypy/join.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,9 @@ def visit_literal_type(self, t: LiteralType) -> ProperType:
310310
if isinstance(self.s, LiteralType):
311311
if t == self.s:
312312
return t
313-
else:
314-
return join_types(self.s.fallback, t.fallback)
313+
if self.s.fallback.type.is_enum and t.fallback.type.is_enum:
314+
return mypy.typeops.make_simplified_union([self.s, t])
315+
return join_types(self.s.fallback, t.fallback)
315316
else:
316317
return join_types(self.s, t.fallback)
317318

mypy/typeops.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
since these may assume that MROs are ready.
66
"""
77

8-
from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar
8+
from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any
99
from typing_extensions import Type as TypingType
10+
import itertools
1011
import sys
1112

1213
from mypy.types import (
@@ -315,7 +316,8 @@ def callable_corresponding_argument(typ: CallableType,
315316

316317
def make_simplified_union(items: Sequence[Type],
317318
line: int = -1, column: int = -1,
318-
*, keep_erased: bool = False) -> ProperType:
319+
*, keep_erased: bool = False,
320+
contract_literals: bool = True) -> ProperType:
319321
"""Build union type with redundant union items removed.
320322
321323
If only a single item remains, this may return a non-union type.
@@ -377,6 +379,11 @@ def make_simplified_union(items: Sequence[Type],
377379
items[i] = true_or_false(ti)
378380

379381
simplified_set = [items[i] for i in range(len(items)) if i not in removed]
382+
383+
# If more than one literal exists in the union, try to simplify
384+
if (contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1):
385+
simplified_set = try_contracting_literals_in_union(simplified_set)
386+
380387
return UnionType.make_union(simplified_set, line, column)
381388

382389

@@ -684,7 +691,7 @@ class Status(Enum):
684691

685692
if isinstance(typ, UnionType):
686693
items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items]
687-
return make_simplified_union(items)
694+
return make_simplified_union(items, contract_literals=False)
688695
elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname == target_fullname:
689696
new_items = []
690697
for name, symbol in typ.type.names.items():
@@ -702,11 +709,39 @@ class Status(Enum):
702709
# only using CPython, but we might as well for the sake of full correctness.
703710
if sys.version_info < (3, 7):
704711
new_items.sort(key=lambda lit: lit.value)
705-
return make_simplified_union(new_items)
712+
return make_simplified_union(new_items, contract_literals=False)
706713
else:
707714
return typ
708715

709716

717+
def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperType]:
718+
"""Contracts any literal types back into a sum type if possible.
719+
720+
Will replace the first instance of the literal with the sum type and
721+
remove all others.
722+
723+
if we call `try_contracting_union(Literal[Color.RED, Color.BLUE, Color.YELLOW])`,
724+
this function will return Color.
725+
"""
726+
sum_types = {} # type: Dict[str, Tuple[Set[Any], List[int]]]
727+
marked_for_deletion = set()
728+
for idx, typ in enumerate(types):
729+
if isinstance(typ, LiteralType):
730+
fullname = typ.fallback.type.fullname
731+
if typ.fallback.type.is_enum:
732+
if fullname not in sum_types:
733+
sum_types[fullname] = (set(get_enum_values(typ.fallback)), [])
734+
literals, indexes = sum_types[fullname]
735+
literals.discard(typ.value)
736+
indexes.append(idx)
737+
if not literals:
738+
first, *rest = indexes
739+
types[first] = typ.fallback
740+
marked_for_deletion |= set(rest)
741+
return list(itertools.compress(types, [(i not in marked_for_deletion)
742+
for i in range(len(types))]))
743+
744+
710745
def coerce_to_literal(typ: Type) -> Type:
711746
"""Recursively converts any Instances that have a last_known_value or are
712747
instances of enum types with a single value into the corresponding LiteralType.

test-data/unit/check-enum.test

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ elif x is Foo.C:
713713
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
714714
else:
715715
reveal_type(x) # No output here: this branch is unreachable
716+
reveal_type(x) # N: Revealed type is '__main__.Foo'
716717

717718
if Foo.A is x:
718719
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -722,6 +723,7 @@ elif Foo.C is x:
722723
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
723724
else:
724725
reveal_type(x) # No output here: this branch is unreachable
726+
reveal_type(x) # N: Revealed type is '__main__.Foo'
725727

726728
y: Foo
727729
if y is Foo.A:
@@ -732,6 +734,7 @@ elif y is Foo.C:
732734
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
733735
else:
734736
reveal_type(y) # No output here: this branch is unreachable
737+
reveal_type(y) # N: Revealed type is '__main__.Foo'
735738

736739
if Foo.A is y:
737740
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -741,6 +744,7 @@ elif Foo.C is y:
741744
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
742745
else:
743746
reveal_type(y) # No output here: this branch is unreachable
747+
reveal_type(y) # N: Revealed type is '__main__.Foo'
744748
[builtins fixtures/bool.pyi]
745749

746750
[case testEnumReachabilityChecksWithOrdering]
@@ -815,12 +819,14 @@ if x is y:
815819
else:
816820
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
817821
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
822+
reveal_type(x) # N: Revealed type is '__main__.Foo'
818823
if y is x:
819824
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
820825
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
821826
else:
822827
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
823828
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
829+
reveal_type(x) # N: Revealed type is '__main__.Foo'
824830

825831
if x is z:
826832
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -830,6 +836,7 @@ else:
830836
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
831837
reveal_type(z) # N: Revealed type is 'Literal[__main__.Foo.A]?'
832838
accepts_foo_a(z)
839+
reveal_type(x) # N: Revealed type is '__main__.Foo'
833840
if z is x:
834841
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
835842
reveal_type(z) # N: Revealed type is 'Literal[__main__.Foo.A]?'
@@ -838,6 +845,7 @@ else:
838845
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
839846
reveal_type(z) # N: Revealed type is 'Literal[__main__.Foo.A]?'
840847
accepts_foo_a(z)
848+
reveal_type(x) # N: Revealed type is '__main__.Foo'
841849

842850
if y is z:
843851
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -909,6 +917,7 @@ if x is Foo.A:
909917
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
910918
else:
911919
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]'
920+
reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]'
912921
[builtins fixtures/bool.pyi]
913922

914923
[case testEnumReachabilityWithMultipleEnums]
@@ -928,18 +937,21 @@ if x1 is Foo.A:
928937
reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]'
929938
else:
930939
reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]'
940+
reveal_type(x1) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
931941

932942
x2: Union[Foo, Bar]
933943
if x2 is Bar.A:
934944
reveal_type(x2) # N: Revealed type is 'Literal[__main__.Bar.A]'
935945
else:
936946
reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]'
947+
reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
937948

938949
x3: Union[Foo, Bar]
939950
if x3 is Foo.A or x3 is Bar.A:
940951
reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]'
941952
else:
942953
reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]'
954+
reveal_type(x3) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
943955

944956
[builtins fixtures/bool.pyi]
945957

@@ -1299,3 +1311,25 @@ reveal_type(a._value_) # N: Revealed type is 'Any'
12991311
[builtins fixtures/__new__.pyi]
13001312
[builtins fixtures/primitives.pyi]
13011313
[typing fixtures/typing-medium.pyi]
1314+
1315+
[case testEnumNarrowedToTwoLiterals]
1316+
# Regression test: two literals of an enum would be joined
1317+
# as the full type, regardless of the amount of elements
1318+
# the enum contains.
1319+
from enum import Enum
1320+
from typing import Union
1321+
from typing_extensions import Literal
1322+
1323+
class Foo(Enum):
1324+
A = 1
1325+
B = 2
1326+
C = 3
1327+
1328+
def f(x: Foo):
1329+
if x is Foo.A:
1330+
return x
1331+
if x is Foo.B:
1332+
pass
1333+
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
1334+
1335+
[builtins fixtures/bool.pyi]

0 commit comments

Comments
 (0)