Skip to content

Commit 4ff8d04

Browse files
authored
Fix union simplification performance regression (#12519)
#11962 can generate large unions with many Instance types with last_known_value set. This caused our union simplification algorithm to be extremely slow, as it hit an O(n**2) code path. We already had a fast code path for unions of regular literal types. This generalizes it for unions containing Instance types with last known values (which behave similarly to literals in a literal type context). Also fix a union simplification bug that I encountered while writing tests for this change. Work on #12408.
1 parent 0e8a03c commit 4ff8d04

File tree

3 files changed

+78
-20
lines changed

3 files changed

+78
-20
lines changed

mypy/test/testtypes.py

+39
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,12 @@ def test_simplified_union(self) -> None:
465465
self.assert_simplified_union([fx.a, UnionType([fx.a])], fx.a)
466466
self.assert_simplified_union([fx.b, UnionType([fx.c, UnionType([fx.d])])],
467467
UnionType([fx.b, fx.c, fx.d]))
468+
469+
def test_simplified_union_with_literals(self) -> None:
470+
fx = self.fx
471+
468472
self.assert_simplified_union([fx.lit1, fx.a], fx.a)
473+
self.assert_simplified_union([fx.lit1, fx.lit2, fx.a], fx.a)
469474
self.assert_simplified_union([fx.lit1, fx.lit1], fx.lit1)
470475
self.assert_simplified_union([fx.lit1, fx.lit2], UnionType([fx.lit1, fx.lit2]))
471476
self.assert_simplified_union([fx.lit1, fx.lit3], UnionType([fx.lit1, fx.lit3]))
@@ -481,6 +486,40 @@ def test_simplified_union(self) -> None:
481486
self.assert_simplified_union([fx.lit1, fx.lit2_inst], UnionType([fx.lit1, fx.lit2_inst]))
482487
self.assert_simplified_union([fx.lit1, fx.lit3_inst], UnionType([fx.lit1, fx.lit3_inst]))
483488

489+
def test_simplified_union_with_str_literals(self) -> None:
490+
fx = self.fx
491+
492+
self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.str_type], fx.str_type)
493+
self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1], fx.lit_str1)
494+
self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.lit_str3],
495+
UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3]))
496+
self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.uninhabited],
497+
UnionType([fx.lit_str1, fx.lit_str2]))
498+
499+
def test_simplified_union_with_str_instance_literals(self) -> None:
500+
fx = self.fx
501+
502+
self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str2_inst, fx.str_type],
503+
fx.str_type)
504+
self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str1_inst, fx.lit_str1_inst],
505+
fx.lit_str1_inst)
506+
self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str2_inst, fx.lit_str3_inst],
507+
UnionType([fx.lit_str1_inst,
508+
fx.lit_str2_inst,
509+
fx.lit_str3_inst]))
510+
self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str2_inst, fx.uninhabited],
511+
UnionType([fx.lit_str1_inst, fx.lit_str2_inst]))
512+
513+
def test_simplified_union_with_mixed_str_literals(self) -> None:
514+
fx = self.fx
515+
516+
self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst],
517+
UnionType([fx.lit_str1,
518+
fx.lit_str2,
519+
fx.lit_str3_inst]))
520+
self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1_inst],
521+
UnionType([fx.lit_str1, fx.lit_str1_inst]))
522+
484523
def assert_simplified_union(self, original: List[Type], union: Type) -> None:
485524
assert_equal(make_simplified_union(original), union)
486525
assert_equal(make_simplified_union(list(reversed(original))), union)

mypy/test/typefixture.py

+9
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type,
6565
variances=[COVARIANT]) # class tuple
6666
self.type_typei = self.make_type_info('builtins.type') # class type
6767
self.bool_type_info = self.make_type_info('builtins.bool')
68+
self.str_type_info = self.make_type_info('builtins.str')
6869
self.functioni = self.make_type_info('builtins.function') # function TODO
6970
self.ai = self.make_type_info('A', mro=[self.oi]) # class A
7071
self.bi = self.make_type_info('B', mro=[self.ai, self.oi]) # class B(A)
@@ -109,6 +110,7 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type,
109110
self.std_tuple = Instance(self.std_tuplei, [self.anyt]) # tuple
110111
self.type_type = Instance(self.type_typei, []) # type
111112
self.function = Instance(self.functioni, []) # function TODO
113+
self.str_type = Instance(self.str_type_info, [])
112114
self.a = Instance(self.ai, []) # A
113115
self.b = Instance(self.bi, []) # B
114116
self.c = Instance(self.ci, []) # C
@@ -163,6 +165,13 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type,
163165
self.lit3_inst = Instance(self.di, [], last_known_value=self.lit3)
164166
self.lit4_inst = Instance(self.ai, [], last_known_value=self.lit4)
165167

168+
self.lit_str1 = LiteralType("x", self.str_type)
169+
self.lit_str2 = LiteralType("y", self.str_type)
170+
self.lit_str3 = LiteralType("z", self.str_type)
171+
self.lit_str1_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str1)
172+
self.lit_str2_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str2)
173+
self.lit_str3_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str3)
174+
166175
self.type_a = TypeType.make_normalized(self.a)
167176
self.type_b = TypeType.make_normalized(self.b)
168177
self.type_c = TypeType.make_normalized(self.c)

mypy/typeops.py

+30-20
Original file line numberDiff line numberDiff line change
@@ -299,15 +299,23 @@ def callable_corresponding_argument(typ: CallableType,
299299
return by_name if by_name is not None else by_pos
300300

301301

302-
def is_simple_literal(t: ProperType) -> bool:
303-
"""
304-
Whether a type is a simple enough literal to allow for fast Union simplification
302+
def simple_literal_value_key(t: ProperType) -> Optional[Tuple[str, ...]]:
303+
"""Return a hashable description of simple literal type.
304+
305+
Return None if not a simple literal type.
305306
306-
For now this means enum or string
307+
The return value can be used to simplify away duplicate types in
308+
unions by comparing keys for equality. For now enum, string or
309+
Instance with string last_known_value are supported.
307310
"""
308-
return isinstance(t, LiteralType) and (
309-
t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str'
310-
)
311+
if isinstance(t, LiteralType):
312+
if t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str':
313+
assert isinstance(t.value, str)
314+
return 'literal', t.value, t.fallback.type.fullname
315+
if isinstance(t, Instance):
316+
if t.last_known_value is not None and isinstance(t.last_known_value.value, str):
317+
return 'instance', t.last_known_value.value, t.type.fullname
318+
return None
311319

312320

313321
def make_simplified_union(items: Sequence[Type],
@@ -341,10 +349,20 @@ def make_simplified_union(items: Sequence[Type],
341349
all_items.append(typ)
342350
items = all_items
343351

352+
simplified_set = _remove_redundant_union_items(items, keep_erased)
353+
354+
# If more than one literal exists in the union, try to simplify
355+
if (contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1):
356+
simplified_set = try_contracting_literals_in_union(simplified_set)
357+
358+
return UnionType.make_union(simplified_set, line, column)
359+
360+
361+
def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> List[ProperType]:
344362
from mypy.subtypes import is_proper_subtype
345363

346364
removed: Set[int] = set()
347-
seen: Set[Tuple[str, str]] = set()
365+
seen: Set[Tuple[str, ...]] = set()
348366

349367
# NB: having a separate fast path for Union of Literal and slow path for other things
350368
# would arguably be cleaner, however it breaks down when simplifying the Union of two
@@ -354,10 +372,8 @@ def make_simplified_union(items: Sequence[Type],
354372
if i in removed:
355373
continue
356374
# Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169)
357-
if is_simple_literal(item):
358-
assert isinstance(item, LiteralType)
359-
assert isinstance(item.value, str)
360-
k = (item.value, item.fallback.type.fullname)
375+
k = simple_literal_value_key(item)
376+
if k is not None:
361377
if k in seen:
362378
removed.add(i)
363379
continue
@@ -373,13 +389,13 @@ def make_simplified_union(items: Sequence[Type],
373389
seen.add(k)
374390
if safe_skip:
375391
continue
392+
376393
# Keep track of the truishness info for deleted subtypes which can be relevant
377394
cbt = cbf = False
378395
for j, tj in enumerate(items):
379396
# NB: we don't need to check literals as the fast path above takes care of that
380397
if (
381398
i != j
382-
and not is_simple_literal(tj)
383399
and is_proper_subtype(tj, item, keep_erased_types=keep_erased)
384400
and is_redundant_literal_instance(item, tj) # XXX?
385401
):
@@ -393,13 +409,7 @@ def make_simplified_union(items: Sequence[Type],
393409
elif not item.can_be_false and cbf:
394410
items[i] = true_or_false(item)
395411

396-
simplified_set = [items[i] for i in range(len(items)) if i not in removed]
397-
398-
# If more than one literal exists in the union, try to simplify
399-
if (contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1):
400-
simplified_set = try_contracting_literals_in_union(simplified_set)
401-
402-
return UnionType.make_union(simplified_set, line, column)
412+
return [items[i] for i in range(len(items)) if i not in removed]
403413

404414

405415
def _get_type_special_method_bool_ret_type(t: Type) -> Optional[Type]:

0 commit comments

Comments
 (0)