2
2
3
3
from collections import OrderedDict
4
4
from contextlib import contextmanager
5
+ import itertools
5
6
from typing import (
6
7
cast , Dict , Set , List , Tuple , Callable , Union , Optional , Sequence , Iterator
7
8
)
@@ -2554,15 +2555,18 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
2554
2555
if isinstance (index , SliceExpr ):
2555
2556
return self .visit_tuple_slice_helper (left_type , index )
2556
2557
2557
- n = self ._get_value (index )
2558
- if n is not None :
2559
- if n < 0 :
2560
- n += len (left_type .items )
2561
- if 0 <= n < len (left_type .items ):
2562
- return left_type .items [n ]
2563
- else :
2564
- self .chk .fail (message_registry .TUPLE_INDEX_OUT_OF_RANGE , e )
2565
- return AnyType (TypeOfAny .from_error )
2558
+ ns = self .try_getting_int_literals (index )
2559
+ if ns is not None :
2560
+ out = []
2561
+ for n in ns :
2562
+ if n < 0 :
2563
+ n += len (left_type .items )
2564
+ if 0 <= n < len (left_type .items ):
2565
+ out .append (left_type .items [n ])
2566
+ else :
2567
+ self .chk .fail (message_registry .TUPLE_INDEX_OUT_OF_RANGE , e )
2568
+ return AnyType (TypeOfAny .from_error )
2569
+ return UnionType .make_simplified_union (out )
2566
2570
else :
2567
2571
return self .nonliteral_tuple_index_helper (left_type , index )
2568
2572
elif isinstance (left_type , TypedDictType ):
@@ -2578,26 +2582,66 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
2578
2582
return result
2579
2583
2580
2584
def visit_tuple_slice_helper (self , left_type : TupleType , slic : SliceExpr ) -> Type :
2581
- begin = None
2582
- end = None
2583
- stride = None
2585
+ begin = [ None ] # type: Sequence[Optional[int]]
2586
+ end = [ None ] # type: Sequence[Optional[int]]
2587
+ stride = [ None ] # type: Sequence[Optional[int]]
2584
2588
2585
2589
if slic .begin_index :
2586
- begin = self ._get_value (slic .begin_index )
2587
- if begin is None :
2590
+ begin_raw = self .try_getting_int_literals (slic .begin_index )
2591
+ if begin_raw is None :
2588
2592
return self .nonliteral_tuple_index_helper (left_type , slic )
2593
+ begin = begin_raw
2589
2594
2590
2595
if slic .end_index :
2591
- end = self ._get_value (slic .end_index )
2592
- if end is None :
2596
+ end_raw = self .try_getting_int_literals (slic .end_index )
2597
+ if end_raw is None :
2593
2598
return self .nonliteral_tuple_index_helper (left_type , slic )
2599
+ end = end_raw
2594
2600
2595
2601
if slic .stride :
2596
- stride = self ._get_value (slic .stride )
2597
- if stride is None :
2602
+ stride_raw = self .try_getting_int_literals (slic .stride )
2603
+ if stride_raw is None :
2598
2604
return self .nonliteral_tuple_index_helper (left_type , slic )
2605
+ stride = stride_raw
2606
+
2607
+ items = [] # type: List[Type]
2608
+ for b , e , s in itertools .product (begin , end , stride ):
2609
+ items .append (left_type .slice (b , e , s ))
2610
+ return UnionType .make_simplified_union (items )
2599
2611
2600
- return left_type .slice (begin , stride , end )
2612
+ def try_getting_int_literals (self , index : Expression ) -> Optional [List [int ]]:
2613
+ """If the given expression or type corresponds to an int literal
2614
+ or a union of int literals, returns a list of the underlying ints.
2615
+ Otherwise, returns None.
2616
+
2617
+ Specifically, this function is guaranteed to return a list with
2618
+ one or more ints if one one the following is true:
2619
+
2620
+ 1. 'expr' is a IntExpr or a UnaryExpr backed by an IntExpr
2621
+ 2. 'typ' is a LiteralType containing an int
2622
+ 3. 'typ' is a UnionType containing only LiteralType of ints
2623
+ """
2624
+ if isinstance (index , IntExpr ):
2625
+ return [index .value ]
2626
+ elif isinstance (index , UnaryExpr ):
2627
+ if index .op == '-' :
2628
+ operand = index .expr
2629
+ if isinstance (operand , IntExpr ):
2630
+ return [- 1 * operand .value ]
2631
+ typ = self .accept (index )
2632
+ if isinstance (typ , Instance ) and typ .last_known_value is not None :
2633
+ typ = typ .last_known_value
2634
+ if isinstance (typ , LiteralType ) and isinstance (typ .value , int ):
2635
+ return [typ .value ]
2636
+ if isinstance (typ , UnionType ):
2637
+ out = []
2638
+ for item in typ .items :
2639
+ if isinstance (item , LiteralType ) and isinstance (item .value , int ):
2640
+ out .append (item .value )
2641
+ else :
2642
+ return None
2643
+ return out
2644
+ return None
2601
2645
2602
2646
def nonliteral_tuple_index_helper (self , left_type : TupleType , index : Expression ) -> Type :
2603
2647
index_type = self .accept (index )
@@ -2614,40 +2658,36 @@ def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression)
2614
2658
else :
2615
2659
return union
2616
2660
2617
- def _get_value (self , index : Expression ) -> Optional [int ]:
2618
- if isinstance (index , IntExpr ):
2619
- return index .value
2620
- elif isinstance (index , UnaryExpr ):
2621
- if index .op == '-' :
2622
- operand = index .expr
2623
- if isinstance (operand , IntExpr ):
2624
- return - 1 * operand .value
2625
- typ = self .accept (index )
2626
- if isinstance (typ , Instance ) and typ .last_known_value is not None :
2627
- typ = typ .last_known_value
2628
- if isinstance (typ , LiteralType ) and isinstance (typ .value , int ):
2629
- return typ .value
2630
- return None
2631
-
2632
2661
def visit_typeddict_index_expr (self , td_type : TypedDictType , index : Expression ) -> Type :
2633
2662
if isinstance (index , (StrExpr , UnicodeExpr )):
2634
- item_name = index .value
2663
+ key_names = [ index .value ]
2635
2664
else :
2636
2665
typ = self .accept (index )
2637
- if isinstance (typ , Instance ) and typ .last_known_value is not None :
2638
- typ = typ .last_known_value
2639
-
2640
- if isinstance (typ , LiteralType ) and isinstance (typ .value , str ):
2641
- item_name = typ .value
2666
+ if isinstance (typ , UnionType ):
2667
+ key_types = typ .items
2642
2668
else :
2643
- self .msg .typeddict_key_must_be_string_literal (td_type , index )
2644
- return AnyType (TypeOfAny .from_error )
2669
+ key_types = [typ ]
2645
2670
2646
- item_type = td_type .items .get (item_name )
2647
- if item_type is None :
2648
- self .msg .typeddict_key_not_found (td_type , item_name , index )
2649
- return AnyType (TypeOfAny .from_error )
2650
- return item_type
2671
+ key_names = []
2672
+ for key_type in key_types :
2673
+ if isinstance (key_type , Instance ) and key_type .last_known_value is not None :
2674
+ key_type = key_type .last_known_value
2675
+
2676
+ if isinstance (key_type , LiteralType ) and isinstance (key_type .value , str ):
2677
+ key_names .append (key_type .value )
2678
+ else :
2679
+ self .msg .typeddict_key_must_be_string_literal (td_type , index )
2680
+ return AnyType (TypeOfAny .from_error )
2681
+
2682
+ value_types = []
2683
+ for key_name in key_names :
2684
+ value_type = td_type .items .get (key_name )
2685
+ if value_type is None :
2686
+ self .msg .typeddict_key_not_found (td_type , key_name , index )
2687
+ return AnyType (TypeOfAny .from_error )
2688
+ else :
2689
+ value_types .append (value_type )
2690
+ return UnionType .make_simplified_union (value_types )
2651
2691
2652
2692
def visit_enum_index_expr (self , enum_type : TypeInfo , index : Expression ,
2653
2693
context : Context ) -> Type :
0 commit comments