|
5 | 5 | from typing import (
|
6 | 6 | cast, Dict, Set, List, Tuple, Callable, Union, Optional, Sequence, Iterator
|
7 | 7 | )
|
8 |
| -from typing_extensions import ClassVar, Final |
| 8 | +from typing_extensions import ClassVar, Final, overload |
9 | 9 |
|
10 | 10 | from mypy.errors import report_internal_error
|
11 | 11 | from mypy.typeanal import (
|
@@ -284,24 +284,27 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
|
284 | 284 | return self.msg.untyped_function_call(callee_type, e)
|
285 | 285 | # Figure out the full name of the callee for plugin lookup.
|
286 | 286 | object_type = None
|
287 |
| - if not isinstance(e.callee, RefExpr): |
288 |
| - fullname = None |
289 |
| - else: |
| 287 | + member = None |
| 288 | + fullname = None |
| 289 | + if isinstance(e.callee, RefExpr): |
| 290 | + # There are two special cases where plugins might act: |
| 291 | + # * A "static" reference/alias to a class or function; |
| 292 | + # get_function_hook() will be invoked for these. |
290 | 293 | fullname = e.callee.fullname
|
291 | 294 | if (isinstance(e.callee.node, TypeAlias) and
|
292 | 295 | isinstance(e.callee.node.target, Instance)):
|
293 | 296 | fullname = e.callee.node.target.type.fullname()
|
| 297 | + # * Call to a method on object that has a full name (see |
| 298 | + # method_fullname() for details on supported objects); |
| 299 | + # get_method_hook() and get_method_signature_hook() will |
| 300 | + # be invoked for these. |
294 | 301 | if (fullname is None
|
295 | 302 | and isinstance(e.callee, MemberExpr)
|
296 |
| - and e.callee.expr in self.chk.type_map |
297 |
| - and isinstance(callee_type, FunctionLike)): |
298 |
| - # For method calls we include the defining class for the method |
299 |
| - # in the full name (example: 'typing.Mapping.get'). |
300 |
| - callee_expr_type = self.chk.type_map[e.callee.expr] |
301 |
| - fullname = self.method_fullname(callee_expr_type, e.callee.name) |
302 |
| - if fullname is not None: |
303 |
| - object_type = callee_expr_type |
304 |
| - ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname, object_type) |
| 303 | + and e.callee.expr in self.chk.type_map): |
| 304 | + member = e.callee.name |
| 305 | + object_type = self.chk.type_map[e.callee.expr] |
| 306 | + ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname, |
| 307 | + object_type, member) |
305 | 308 | if isinstance(e.callee, RefExpr) and len(e.args) == 2:
|
306 | 309 | if e.callee.fullname in ('builtins.isinstance', 'builtins.issubclass'):
|
307 | 310 | self.check_runtime_protocol_test(e)
|
@@ -632,21 +635,53 @@ def check_call_expr_with_callee_type(self,
|
632 | 635 | callee_type: Type,
|
633 | 636 | e: CallExpr,
|
634 | 637 | callable_name: Optional[str],
|
635 |
| - object_type: Optional[Type]) -> Type: |
| 638 | + object_type: Optional[Type], |
| 639 | + member: Optional[str] = None) -> Type: |
636 | 640 | """Type check call expression.
|
637 | 641 |
|
638 |
| - The given callee type overrides the type of the callee |
639 |
| - expression. |
640 |
| - """ |
641 |
| - # Try to refine the call signature using plugin hooks before checking the call. |
642 |
| - callee_type = self.transform_callee_type( |
643 |
| - callable_name, callee_type, e.args, e.arg_kinds, e, e.arg_names, object_type) |
| 642 | + The callee_type should be used as the type of callee expression. In particular, |
| 643 | + in case of a union type this can be a particular item of the union, so that we can |
| 644 | + apply plugin hooks to each item. |
644 | 645 |
|
| 646 | + The 'member', 'callable_name' and 'object_type' are only used to call plugin hooks. |
| 647 | + If 'callable_name' is None but 'member' is not None (member call), try constructing |
| 648 | + 'callable_name' using 'object_type' (the base type on which the method is called), |
| 649 | + for example 'typing.Mapping.get'. |
| 650 | + """ |
| 651 | + if callable_name is None and member is not None: |
| 652 | + assert object_type is not None |
| 653 | + callable_name = self.method_fullname(object_type, member) |
| 654 | + if callable_name: |
| 655 | + # Try to refine the call signature using plugin hooks before checking the call. |
| 656 | + callee_type = self.transform_callee_type( |
| 657 | + callable_name, callee_type, e.args, e.arg_kinds, e, e.arg_names, object_type) |
| 658 | + # Unions are special-cased to allow plugins to act on each item in the union. |
| 659 | + elif member is not None and isinstance(object_type, UnionType): |
| 660 | + return self.check_union_call_expr(e, object_type, member) |
645 | 661 | return self.check_call(callee_type, e.args, e.arg_kinds, e,
|
646 | 662 | e.arg_names, callable_node=e.callee,
|
647 | 663 | callable_name=callable_name,
|
648 | 664 | object_type=object_type)[0]
|
649 | 665 |
|
| 666 | + def check_union_call_expr(self, e: CallExpr, object_type: UnionType, member: str) -> Type: |
| 667 | + """"Type check calling a member expression where the base type is a union.""" |
| 668 | + res = [] # type: List[Type] |
| 669 | + for typ in object_type.relevant_items(): |
| 670 | + # Member access errors are already reported when visiting the member expression. |
| 671 | + self.msg.disable_errors() |
| 672 | + item = analyze_member_access(member, typ, e, False, False, False, |
| 673 | + self.msg, original_type=object_type, chk=self.chk, |
| 674 | + in_literal_context=self.is_literal_context()) |
| 675 | + self.msg.enable_errors() |
| 676 | + narrowed = self.narrow_type_from_binder(e.callee, item, skip_non_overlapping=True) |
| 677 | + if narrowed is None: |
| 678 | + continue |
| 679 | + callable_name = self.method_fullname(typ, member) |
| 680 | + item_object_type = typ if callable_name else None |
| 681 | + res.append(self.check_call_expr_with_callee_type(narrowed, e, callable_name, |
| 682 | + item_object_type)) |
| 683 | + return UnionType.make_simplified_union(res) |
| 684 | + |
650 | 685 | def check_call(self,
|
651 | 686 | callee: Type,
|
652 | 687 | args: List[Expression],
|
@@ -2018,13 +2053,48 @@ def check_method_call_by_name(self,
|
2018 | 2053 | """
|
2019 | 2054 | local_errors = local_errors or self.msg
|
2020 | 2055 | original_type = original_type or base_type
|
| 2056 | + # Unions are special-cased to allow plugins to act on each element of the union. |
| 2057 | + if isinstance(base_type, UnionType): |
| 2058 | + return self.check_union_method_call_by_name(method, base_type, |
| 2059 | + args, arg_kinds, |
| 2060 | + context, local_errors, original_type) |
| 2061 | + |
2021 | 2062 | method_type = analyze_member_access(method, base_type, context, False, False, True,
|
2022 | 2063 | local_errors, original_type=original_type,
|
2023 | 2064 | chk=self.chk,
|
2024 | 2065 | in_literal_context=self.is_literal_context())
|
2025 | 2066 | return self.check_method_call(
|
2026 | 2067 | method, base_type, method_type, args, arg_kinds, context, local_errors)
|
2027 | 2068 |
|
| 2069 | + def check_union_method_call_by_name(self, |
| 2070 | + method: str, |
| 2071 | + base_type: UnionType, |
| 2072 | + args: List[Expression], |
| 2073 | + arg_kinds: List[int], |
| 2074 | + context: Context, |
| 2075 | + local_errors: MessageBuilder, |
| 2076 | + original_type: Optional[Type] = None |
| 2077 | + ) -> Tuple[Type, Type]: |
| 2078 | + """Type check a call to a named method on an object with union type. |
| 2079 | +
|
| 2080 | + This essentially checks the call using check_method_call_by_name() for each |
| 2081 | + union item and unions the result. We do this to allow plugins to act on |
| 2082 | + individual union items. |
| 2083 | + """ |
| 2084 | + res = [] # type: List[Type] |
| 2085 | + meth_res = [] # type: List[Type] |
| 2086 | + for typ in base_type.relevant_items(): |
| 2087 | + # Format error messages consistently with |
| 2088 | + # mypy.checkmember.analyze_union_member_access(). |
| 2089 | + local_errors.disable_type_names += 1 |
| 2090 | + item, meth_item = self.check_method_call_by_name(method, typ, args, arg_kinds, |
| 2091 | + context, local_errors, |
| 2092 | + original_type) |
| 2093 | + local_errors.disable_type_names -= 1 |
| 2094 | + res.append(item) |
| 2095 | + meth_res.append(meth_item) |
| 2096 | + return UnionType.make_simplified_union(res), UnionType.make_simplified_union(meth_res) |
| 2097 | + |
2028 | 2098 | def check_method_call(self,
|
2029 | 2099 | method_name: str,
|
2030 | 2100 | base_type: Type,
|
@@ -3524,14 +3594,32 @@ def bool_type(self) -> Instance:
|
3524 | 3594 | """Return instance type 'bool'."""
|
3525 | 3595 | return self.named_type('builtins.bool')
|
3526 | 3596 |
|
3527 |
| - def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type: |
| 3597 | + @overload |
| 3598 | + def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type: ... |
| 3599 | + |
| 3600 | + @overload # noqa |
| 3601 | + def narrow_type_from_binder(self, expr: Expression, known_type: Type, |
| 3602 | + skip_non_overlapping: bool) -> Optional[Type]: ... |
| 3603 | + |
| 3604 | + def narrow_type_from_binder(self, expr: Expression, known_type: Type, # noqa |
| 3605 | + skip_non_overlapping: bool = False) -> Optional[Type]: |
| 3606 | + """Narrow down a known type of expression using information in conditional type binder. |
| 3607 | +
|
| 3608 | + If 'skip_non_overlapping' is True, return None if the type and restriction are |
| 3609 | + non-overlapping. |
| 3610 | + """ |
3528 | 3611 | if literal(expr) >= LITERAL_TYPE:
|
3529 | 3612 | restriction = self.chk.binder.get(expr)
|
3530 | 3613 | # If the current node is deferred, some variables may get Any types that they
|
3531 | 3614 | # otherwise wouldn't have. We don't want to narrow down these since it may
|
3532 | 3615 | # produce invalid inferred Optional[Any] types, at least.
|
3533 | 3616 | if restriction and not (isinstance(known_type, AnyType)
|
3534 | 3617 | and self.chk.current_node_deferred):
|
| 3618 | + # Note: this call should match the one in narrow_declared_type(). |
| 3619 | + if (skip_non_overlapping and |
| 3620 | + not is_overlapping_types(known_type, restriction, |
| 3621 | + prohibit_none_typevar_overlap=True)): |
| 3622 | + return None |
3535 | 3623 | ans = narrow_declared_type(known_type, restriction)
|
3536 | 3624 | return ans
|
3537 | 3625 | return known_type
|
|
0 commit comments