@@ -299,15 +299,23 @@ def callable_corresponding_argument(typ: CallableType,
299
299
return by_name if by_name is not None else by_pos
300
300
301
301
302
- def is_simple_literal (t : ProperType ) -> bool :
303
- """
304
- Whether a type is a simple enough literal to allow for fast Union simplification
302
+ def simple_literal_value_key (t : ProperType ) -> Optional [Tuple [str , ...]]:
303
+ """Return a hashable description of simple literal type.
304
+
305
+ Return None if not a simple literal type.
305
306
306
- For now this means enum or string
307
+ The return value can be used to simplify away duplicate types in
308
+ unions by comparing keys for equality. For now enum, string or
309
+ Instance with string last_known_value are supported.
307
310
"""
308
- return isinstance (t , LiteralType ) and (
309
- t .fallback .type .is_enum or t .fallback .type .fullname == 'builtins.str'
310
- )
311
+ if isinstance (t , LiteralType ):
312
+ if t .fallback .type .is_enum or t .fallback .type .fullname == 'builtins.str' :
313
+ assert isinstance (t .value , str )
314
+ return 'literal' , t .value , t .fallback .type .fullname
315
+ if isinstance (t , Instance ):
316
+ if t .last_known_value is not None and isinstance (t .last_known_value .value , str ):
317
+ return 'instance' , t .last_known_value .value , t .type .fullname
318
+ return None
311
319
312
320
313
321
def make_simplified_union (items : Sequence [Type ],
@@ -341,10 +349,20 @@ def make_simplified_union(items: Sequence[Type],
341
349
all_items .append (typ )
342
350
items = all_items
343
351
352
+ simplified_set = _remove_redundant_union_items (items , keep_erased )
353
+
354
+ # If more than one literal exists in the union, try to simplify
355
+ if (contract_literals and sum (isinstance (item , LiteralType ) for item in simplified_set ) > 1 ):
356
+ simplified_set = try_contracting_literals_in_union (simplified_set )
357
+
358
+ return UnionType .make_union (simplified_set , line , column )
359
+
360
+
361
+ def _remove_redundant_union_items (items : List [ProperType ], keep_erased : bool ) -> List [ProperType ]:
344
362
from mypy .subtypes import is_proper_subtype
345
363
346
364
removed : Set [int ] = set ()
347
- seen : Set [Tuple [str , str ]] = set ()
365
+ seen : Set [Tuple [str , ... ]] = set ()
348
366
349
367
# NB: having a separate fast path for Union of Literal and slow path for other things
350
368
# would arguably be cleaner, however it breaks down when simplifying the Union of two
@@ -354,10 +372,8 @@ def make_simplified_union(items: Sequence[Type],
354
372
if i in removed :
355
373
continue
356
374
# Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169)
357
- if is_simple_literal (item ):
358
- assert isinstance (item , LiteralType )
359
- assert isinstance (item .value , str )
360
- k = (item .value , item .fallback .type .fullname )
375
+ k = simple_literal_value_key (item )
376
+ if k is not None :
361
377
if k in seen :
362
378
removed .add (i )
363
379
continue
@@ -373,13 +389,13 @@ def make_simplified_union(items: Sequence[Type],
373
389
seen .add (k )
374
390
if safe_skip :
375
391
continue
392
+
376
393
# Keep track of the truishness info for deleted subtypes which can be relevant
377
394
cbt = cbf = False
378
395
for j , tj in enumerate (items ):
379
396
# NB: we don't need to check literals as the fast path above takes care of that
380
397
if (
381
398
i != j
382
- and not is_simple_literal (tj )
383
399
and is_proper_subtype (tj , item , keep_erased_types = keep_erased )
384
400
and is_redundant_literal_instance (item , tj ) # XXX?
385
401
):
@@ -393,13 +409,7 @@ def make_simplified_union(items: Sequence[Type],
393
409
elif not item .can_be_false and cbf :
394
410
items [i ] = true_or_false (item )
395
411
396
- simplified_set = [items [i ] for i in range (len (items )) if i not in removed ]
397
-
398
- # If more than one literal exists in the union, try to simplify
399
- if (contract_literals and sum (isinstance (item , LiteralType ) for item in simplified_set ) > 1 ):
400
- simplified_set = try_contracting_literals_in_union (simplified_set )
401
-
402
- return UnionType .make_union (simplified_set , line , column )
412
+ return [items [i ] for i in range (len (items )) if i not in removed ]
403
413
404
414
405
415
def _get_type_special_method_bool_ret_type (t : Type ) -> Optional [Type ]:
0 commit comments