diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 42990d4d28a3..113076a380b4 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1721,6 +1721,7 @@ def infer_overload_return_type(self, matches: List[CallableType] = [] return_types: List[Type] = [] inferred_types: List[Type] = [] + self_contains_any = has_any_type(object_type) if object_type is not None else False args_contain_any = any(map(has_any_type, arg_types)) for typ in plausible_targets: @@ -1750,7 +1751,7 @@ def infer_overload_return_type(self, if is_match: # Return early if possible; otherwise record info so we can # check for ambiguity due to 'Any' below. - if not args_contain_any: + if not args_contain_any and not self_contains_any: return ret_type, infer_type matches.append(typ) return_types.append(ret_type) @@ -1759,7 +1760,9 @@ def infer_overload_return_type(self, if len(matches) == 0: # No match was found return None - elif any_causes_overload_ambiguity(matches, return_types, arg_types, arg_kinds, arg_names): + elif any_causes_overload_ambiguity( + matches, return_types, arg_types, arg_kinds, arg_names, + self_contains_any=self_contains_any): # An argument of type or containing the type 'Any' caused ambiguity. # We try returning a precise type if we can. If not, we give up and just return 'Any'. if all_same_types(return_types): @@ -4394,7 +4397,9 @@ def any_causes_overload_ambiguity(items: List[CallableType], return_types: List[Type], arg_types: List[Type], arg_kinds: List[ArgKind], - arg_names: Optional[Sequence[Optional[str]]]) -> bool: + arg_names: Optional[Sequence[Optional[str]]], + *, + self_contains_any: bool = False) -> bool: """May an argument containing 'Any' cause ambiguous result type on call to overloaded function? Note that this sometimes returns True even if there is no ambiguity, since a correct @@ -4440,7 +4445,7 @@ def any_causes_overload_ambiguity(items: List[CallableType], if not all_same_types(matching_formals) and not all_same_types(matching_returns): # Any maps to multiple different types, and the return types of these items differ. return True - return False + return self_contains_any def all_same_types(types: List[Type]) -> bool: diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index bf7acdc1cd51..2a066b9b9b32 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5335,7 +5335,68 @@ def register(cls: Type[_T]) -> int: ... def register(cls: Callable[..., _T]) -> str: ... def register(cls: Any) -> Any: return None - x = register(Foo) reveal_type(x) # N: Revealed type is "builtins.int" [builtins fixtures/dict.pyi] + +[case testOverloadSelfArgWithMultipleMatches] +# https://github.com/python/mypy/issues/11347 +from typing import Generic, TypeVar, overload, Any + +T = TypeVar('T') + +class Some(Generic[T]): + @overload + def method(self: Some[int]) -> str: ... + @overload + def method(self: Some[str]) -> float: ... + def method(self): ... + +s1: Some[int] +s2: Some[str] +s3: Some[Any] +reveal_type(s1.method()) # N: Revealed type is "builtins.str" +reveal_type(s2.method()) # N: Revealed type is "builtins.float" +reveal_type(s3.method()) # N: Revealed type is "Any" +[builtins fixtures/dict.pyi] + +[case testOverloadSelfArgWithOtherSameArgAndMultipleMatches] +from typing import Generic, TypeVar, overload, Any + +T = TypeVar('T') + +class Some(Generic[T]): + @overload + def method(self: Some[int], other: int) -> str: ... + @overload + def method(self: Some[str], other: int) -> float: ... + def method(self): ... + +# was ok +s1: Some[int] +s2: Some[str] +s3: Some[Any] +reveal_type(s1.method(1)) # N: Revealed type is "builtins.str" +reveal_type(s2.method(1)) # N: Revealed type is "builtins.float" +reveal_type(s3.method(1)) # N: Revealed type is "Any" +[builtins fixtures/dict.pyi] + +[case testOverloadSelfArgWithOtherDifferentArgAndMultipleMatches] +from typing import Generic, TypeVar, overload, Any + +T = TypeVar('T') + +class Some(Generic[T]): + @overload + def method(self: Some[int], other: int) -> str: ... + @overload + def method(self: Some[str], other: str) -> float: ... + def method(self): ... + +s1: Some[int] +s2: Some[str] +s3: Some[Any] +reveal_type(s1.method(1)) # N: Revealed type is "builtins.str" +reveal_type(s2.method('a')) # N: Revealed type is "builtins.float" +reveal_type(s3.method(1)) # N: Revealed type is "builtins.str" +[builtins fixtures/dict.pyi]