@@ -203,8 +203,9 @@ def visit_type_var(self, left: TypeVarType) -> bool:
203
203
def visit_callable_type (self , left : CallableType ) -> bool :
204
204
right = self .right
205
205
if isinstance (right , CallableType ):
206
- return is_callable_subtype (
206
+ return is_callable_compatible (
207
207
left , right ,
208
+ is_compat = is_subtype ,
208
209
ignore_pos_arg_names = self .ignore_pos_arg_names )
209
210
elif isinstance (right , Overloaded ):
210
211
return all (is_subtype (left , item , self .check_type_parameter ,
@@ -310,10 +311,12 @@ def visit_overloaded(self, left: Overloaded) -> bool:
310
311
else :
311
312
# If this one overlaps with the supertype in any way, but it wasn't
312
313
# an exact match, then it's a potential error.
313
- if (is_callable_subtype (left_item , right_item , ignore_return = True ,
314
- ignore_pos_arg_names = self .ignore_pos_arg_names ) or
315
- is_callable_subtype (right_item , left_item , ignore_return = True ,
316
- ignore_pos_arg_names = self .ignore_pos_arg_names )):
314
+ if (is_callable_compatible (left_item , right_item ,
315
+ is_compat = is_subtype , ignore_return = True ,
316
+ ignore_pos_arg_names = self .ignore_pos_arg_names ) or
317
+ is_callable_compatible (right_item , left_item ,
318
+ is_compat = is_subtype , ignore_return = True ,
319
+ ignore_pos_arg_names = self .ignore_pos_arg_names )):
317
320
# If this is an overload that's already been matched, there's no
318
321
# problem.
319
322
if left_item not in matched_overloads :
@@ -568,16 +571,54 @@ def non_method_protocol_members(tp: TypeInfo) -> List[str]:
568
571
return result
569
572
570
573
571
- def is_callable_subtype (left : CallableType , right : CallableType ,
572
- ignore_return : bool = False ,
573
- ignore_pos_arg_names : bool = False ,
574
- use_proper_subtype : bool = False ) -> bool :
575
- """Is left a subtype of right?"""
574
+ def is_callable_compatible (left : CallableType , right : CallableType ,
575
+ * ,
576
+ is_compat : Callable [[Type , Type ], bool ],
577
+ is_compat_return : Optional [Callable [[Type , Type ], bool ]] = None ,
578
+ ignore_return : bool = False ,
579
+ ignore_pos_arg_names : bool = False ,
580
+ check_args_covariantly : bool = False ) -> bool :
581
+ """Is the left compatible with the right, using the provided compatibility check?
576
582
577
- if use_proper_subtype :
578
- is_compat = is_proper_subtype
579
- else :
580
- is_compat = is_subtype
583
+ is_compat:
584
+ The check we want to run against the parameters.
585
+
586
+ is_compat_return:
587
+ The check we want to run against the return type.
588
+ If None, use the 'is_compat' check.
589
+
590
+ check_args_covariantly:
591
+ If true, check if the left's args is compatible with the right's
592
+ instead of the other way around (contravariantly).
593
+
594
+ This function is mostly used to check if the left is a subtype of the right which
595
+ is why the default is to check the args contravariantly. However, it's occasionally
596
+ useful to check the args using some other check, so we leave the variance
597
+ configurable.
598
+
599
+ For example, when checking the validity of overloads, it's useful to see if
600
+ the first overload alternative has more precise arguments then the second.
601
+ We would want to check the arguments covariantly in that case.
602
+
603
+ Note! The following two function calls are NOT equivalent:
604
+
605
+ is_callable_compatible(f, g, is_compat=is_subtype, check_args_covariantly=False)
606
+ is_callable_compatible(g, f, is_compat=is_subtype, check_args_covariantly=True)
607
+
608
+ The two calls are similar in that they both check the function arguments in
609
+ the same direction: they both run `is_subtype(argument_from_g, argument_from_f)`.
610
+
611
+ However, the two calls differ in which direction they check things likee
612
+ keyword arguments. For example, suppose f and g are defined like so:
613
+
614
+ def f(x: int, *y: int) -> int: ...
615
+ def g(x: int) -> int: ...
616
+
617
+ In this case, the first call will succeed and the second will fail: f is a
618
+ valid stand-in for g but not vice-versa.
619
+ """
620
+ if is_compat_return is None :
621
+ is_compat_return = is_compat
581
622
582
623
# If either function is implicitly typed, ignore positional arg names too
583
624
if left .implicit or right .implicit :
@@ -607,9 +648,12 @@ def is_callable_subtype(left: CallableType, right: CallableType,
607
648
left = unified
608
649
609
650
# Check return types.
610
- if not ignore_return and not is_compat (left .ret_type , right .ret_type ):
651
+ if not ignore_return and not is_compat_return (left .ret_type , right .ret_type ):
611
652
return False
612
653
654
+ if check_args_covariantly :
655
+ is_compat = flip_compat_check (is_compat )
656
+
613
657
if right .is_ellipsis_args :
614
658
return True
615
659
@@ -652,7 +696,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
652
696
# Right has an infinite series of optional positional arguments
653
697
# here. Get all further positional arguments of left, and make sure
654
698
# they're more general than their corresponding member in this
655
- # series. Also make sure left has its own inifite series of
699
+ # series. Also make sure left has its own infinite series of
656
700
# optional positional arguments.
657
701
if not left .is_var_arg :
658
702
return False
@@ -664,7 +708,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
664
708
right_by_position = right .argument_by_position (j )
665
709
assert right_by_position is not None
666
710
if not are_args_compatible (left_by_position , right_by_position ,
667
- ignore_pos_arg_names , use_proper_subtype ):
711
+ ignore_pos_arg_names , is_compat ):
668
712
return False
669
713
j += 1
670
714
continue
@@ -687,7 +731,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
687
731
right_by_name = right .argument_by_name (name )
688
732
assert right_by_name is not None
689
733
if not are_args_compatible (left_by_name , right_by_name ,
690
- ignore_pos_arg_names , use_proper_subtype ):
734
+ ignore_pos_arg_names , is_compat ):
691
735
return False
692
736
continue
693
737
@@ -696,7 +740,8 @@ def is_callable_subtype(left: CallableType, right: CallableType,
696
740
if left_arg is None :
697
741
return False
698
742
699
- if not are_args_compatible (left_arg , right_arg , ignore_pos_arg_names , use_proper_subtype ):
743
+ if not are_args_compatible (left_arg , right_arg ,
744
+ ignore_pos_arg_names , is_compat ):
700
745
return False
701
746
702
747
done_with_positional = False
@@ -748,7 +793,7 @@ def are_args_compatible(
748
793
left : FormalArgument ,
749
794
right : FormalArgument ,
750
795
ignore_pos_arg_names : bool ,
751
- use_proper_subtype : bool ) -> bool :
796
+ is_compat : Callable [[ Type , Type ], bool ] ) -> bool :
752
797
# If right has a specific name it wants this argument to be, left must
753
798
# have the same.
754
799
if right .name is not None and left .name != right .name :
@@ -759,18 +804,20 @@ def are_args_compatible(
759
804
if right .pos is not None and left .pos != right .pos :
760
805
return False
761
806
# Left must have a more general type
762
- if use_proper_subtype :
763
- if not is_proper_subtype (right .typ , left .typ ):
764
- return False
765
- else :
766
- if not is_subtype (right .typ , left .typ ):
767
- return False
807
+ if not is_compat (right .typ , left .typ ):
808
+ return False
768
809
# If right's argument is optional, left's must also be.
769
810
if not right .required and left .required :
770
811
return False
771
812
return True
772
813
773
814
815
+ def flip_compat_check (is_compat : Callable [[Type , Type ], bool ]) -> Callable [[Type , Type ], bool ]:
816
+ def new_is_compat (left : Type , right : Type ) -> bool :
817
+ return is_compat (right , left )
818
+ return new_is_compat
819
+
820
+
774
821
def unify_generic_callable (type : CallableType , target : CallableType ,
775
822
ignore_return : bool ) -> Optional [CallableType ]:
776
823
"""Try to unify a generic callable type with another callable type.
@@ -913,10 +960,7 @@ def visit_type_var(self, left: TypeVarType) -> bool:
913
960
def visit_callable_type (self , left : CallableType ) -> bool :
914
961
right = self .right
915
962
if isinstance (right , CallableType ):
916
- return is_callable_subtype (
917
- left , right ,
918
- ignore_pos_arg_names = False ,
919
- use_proper_subtype = True )
963
+ return is_callable_compatible (left , right , is_compat = is_proper_subtype )
920
964
elif isinstance (right , Overloaded ):
921
965
return all (is_proper_subtype (left , item )
922
966
for item in right .items ())
0 commit comments