@@ -177,6 +177,9 @@ class ExpressionChecker(ExpressionVisitor[Type]):
177
177
# Type context for type inference
178
178
type_context : List [Optional [Type ]]
179
179
180
+ # cache resolved types in some cases
181
+ resolved_type : Dict [Expression , ProperType ]
182
+
180
183
strfrm_checker : StringFormatterChecker
181
184
plugin : Plugin
182
185
@@ -197,6 +200,11 @@ def __init__(self,
197
200
self .type_overrides : Dict [Expression , Type ] = {}
198
201
self .strfrm_checker = StringFormatterChecker (self , self .chk , self .msg )
199
202
203
+ self .resolved_type = {}
204
+
205
+ def reset (self ) -> None :
206
+ self .resolved_type = {}
207
+
200
208
def visit_name_expr (self , e : NameExpr ) -> Type :
201
209
"""Type check a name expression.
202
210
@@ -3269,13 +3277,13 @@ def apply_type_arguments_to_callable(
3269
3277
3270
3278
def visit_list_expr (self , e : ListExpr ) -> Type :
3271
3279
"""Type check a list expression [...]."""
3272
- return self .check_lst_expr (e . items , 'builtins.list' , '<list>' , e )
3280
+ return self .check_lst_expr (e , 'builtins.list' , '<list>' )
3273
3281
3274
3282
def visit_set_expr (self , e : SetExpr ) -> Type :
3275
- return self .check_lst_expr (e . items , 'builtins.set' , '<set>' , e )
3283
+ return self .check_lst_expr (e , 'builtins.set' , '<set>' )
3276
3284
3277
3285
def fast_container_type (
3278
- self , items : List [ Expression ], container_fullname : str
3286
+ self , e : Union [ ListExpr , SetExpr , TupleExpr ], container_fullname : str
3279
3287
) -> Optional [Type ]:
3280
3288
"""
3281
3289
Fast path to determine the type of a list or set literal,
@@ -3290,21 +3298,28 @@ def fast_container_type(
3290
3298
ctx = self .type_context [- 1 ]
3291
3299
if ctx :
3292
3300
return None
3301
+ rt = self .resolved_type .get (e , None )
3302
+ if rt is not None :
3303
+ return rt if isinstance (rt , Instance ) else None
3293
3304
values : List [Type ] = []
3294
- for item in items :
3305
+ for item in e . items :
3295
3306
if isinstance (item , StarExpr ):
3296
3307
# fallback to slow path
3308
+ self .resolved_type [e ] = NoneType ()
3297
3309
return None
3298
3310
values .append (self .accept (item ))
3299
3311
vt = join .join_type_list (values )
3300
3312
if not allow_fast_container_literal (vt ):
3313
+ self .resolved_type [e ] = NoneType ()
3301
3314
return None
3302
- return self .chk .named_generic_type (container_fullname , [vt ])
3315
+ ct = self .chk .named_generic_type (container_fullname , [vt ])
3316
+ self .resolved_type [e ] = ct
3317
+ return ct
3303
3318
3304
- def check_lst_expr (self , items : List [ Expression ], fullname : str ,
3305
- tag : str , context : Context ) -> Type :
3319
+ def check_lst_expr (self , e : Union [ ListExpr , SetExpr , TupleExpr ], fullname : str ,
3320
+ tag : str ) -> Type :
3306
3321
# fast path
3307
- t = self .fast_container_type (items , fullname )
3322
+ t = self .fast_container_type (e , fullname )
3308
3323
if t :
3309
3324
return t
3310
3325
@@ -3323,10 +3338,10 @@ def check_lst_expr(self, items: List[Expression], fullname: str,
3323
3338
variables = [tv ])
3324
3339
out = self .check_call (constructor ,
3325
3340
[(i .expr if isinstance (i , StarExpr ) else i )
3326
- for i in items ],
3341
+ for i in e . items ],
3327
3342
[(nodes .ARG_STAR if isinstance (i , StarExpr ) else nodes .ARG_POS )
3328
- for i in items ],
3329
- context )[0 ]
3343
+ for i in e . items ],
3344
+ e )[0 ]
3330
3345
return remove_instance_last_known_values (out )
3331
3346
3332
3347
def visit_tuple_expr (self , e : TupleExpr ) -> Type :
@@ -3376,7 +3391,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type:
3376
3391
else :
3377
3392
# A star expression that's not a Tuple.
3378
3393
# Treat the whole thing as a variable-length tuple.
3379
- return self .check_lst_expr (e . items , 'builtins.tuple' , '<tuple>' , e )
3394
+ return self .check_lst_expr (e , 'builtins.tuple' , '<tuple>' )
3380
3395
else :
3381
3396
if not type_context_items or j >= len (type_context_items ):
3382
3397
tt = self .accept (item )
@@ -3402,6 +3417,9 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]:
3402
3417
ctx = self .type_context [- 1 ]
3403
3418
if ctx :
3404
3419
return None
3420
+ rt = self .resolved_type .get (e , None )
3421
+ if rt is not None :
3422
+ return rt if isinstance (rt , Instance ) else None
3405
3423
keys : List [Type ] = []
3406
3424
values : List [Type ] = []
3407
3425
stargs : Optional [Tuple [Type , Type ]] = None
@@ -3415,17 +3433,22 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]:
3415
3433
):
3416
3434
stargs = (st .args [0 ], st .args [1 ])
3417
3435
else :
3436
+ self .resolved_type [e ] = NoneType ()
3418
3437
return None
3419
3438
else :
3420
3439
keys .append (self .accept (key ))
3421
3440
values .append (self .accept (value ))
3422
3441
kt = join .join_type_list (keys )
3423
3442
vt = join .join_type_list (values )
3424
3443
if not (allow_fast_container_literal (kt ) and allow_fast_container_literal (vt )):
3444
+ self .resolved_type [e ] = NoneType ()
3425
3445
return None
3426
3446
if stargs and (stargs [0 ] != kt or stargs [1 ] != vt ):
3447
+ self .resolved_type [e ] = NoneType ()
3427
3448
return None
3428
- return self .chk .named_generic_type ('builtins.dict' , [kt , vt ])
3449
+ dt = self .chk .named_generic_type ('builtins.dict' , [kt , vt ])
3450
+ self .resolved_type [e ] = dt
3451
+ return dt
3429
3452
3430
3453
def visit_dict_expr (self , e : DictExpr ) -> Type :
3431
3454
"""Type check a dict expression.
0 commit comments