Skip to content

Commit 8deeaf3

Browse files
authored
Enable generic NamedTuples (#13396)
Fixes #685 This builds on top of some infra I added for recursive types (Ref #13297). Implementation is based on the idea in #13297 (comment). Generally it works well, but there are actually some problems for named tuples that are recursive. Special-casing them in `maptype.py` is a bit ugly, but I think this is best we can get at the moment.
1 parent fd7040e commit 8deeaf3

17 files changed

+399
-38
lines changed

mypy/checkexpr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3667,6 +3667,9 @@ def visit_type_application(self, tapp: TypeApplication) -> Type:
36673667
if isinstance(item, Instance):
36683668
tp = type_object_type(item.type, self.named_type)
36693669
return self.apply_type_arguments_to_callable(tp, item.args, tapp)
3670+
elif isinstance(item, TupleType) and item.partial_fallback.type.is_named_tuple:
3671+
tp = type_object_type(item.partial_fallback.type, self.named_type)
3672+
return self.apply_type_arguments_to_callable(tp, item.partial_fallback.args, tapp)
36703673
else:
36713674
self.chk.fail(message_registry.ONLY_CLASS_APPLICATION, tapp)
36723675
return AnyType(TypeOfAny.from_error)

mypy/constraints.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,14 @@ def visit_tuple_type(self, template: TupleType) -> List[Constraint]:
882882
]
883883

884884
if isinstance(actual, TupleType) and len(actual.items) == len(template.items):
885+
if (
886+
actual.partial_fallback.type.is_named_tuple
887+
and template.partial_fallback.type.is_named_tuple
888+
):
889+
# For named tuples using just the fallbacks usually gives better results.
890+
return infer_constraints(
891+
template.partial_fallback, actual.partial_fallback, self.direction
892+
)
885893
res: List[Constraint] = []
886894
for i in range(len(template.items)):
887895
res.extend(infer_constraints(template.items[i], actual.items[i], self.direction))

mypy/expandtype.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,11 @@ def expand_types_with_unpack(
298298
def visit_tuple_type(self, t: TupleType) -> Type:
299299
items = self.expand_types_with_unpack(t.items)
300300
if isinstance(items, list):
301-
return t.copy_modified(items=items)
301+
fallback = t.partial_fallback.accept(self)
302+
fallback = get_proper_type(fallback)
303+
if not isinstance(fallback, Instance):
304+
fallback = t.partial_fallback
305+
return t.copy_modified(items=items, fallback=fallback)
302306
else:
303307
return items
304308

mypy/maptype.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,20 @@
22

33
from typing import Dict, List
44

5+
import mypy.typeops
56
from mypy.expandtype import expand_type
67
from mypy.nodes import TypeInfo
7-
from mypy.types import AnyType, Instance, ProperType, Type, TypeOfAny, TypeVarId
8+
from mypy.types import (
9+
AnyType,
10+
Instance,
11+
ProperType,
12+
TupleType,
13+
Type,
14+
TypeOfAny,
15+
TypeVarId,
16+
get_proper_type,
17+
has_type_vars,
18+
)
819

920

1021
def map_instance_to_supertype(instance: Instance, superclass: TypeInfo) -> Instance:
@@ -18,6 +29,20 @@ def map_instance_to_supertype(instance: Instance, superclass: TypeInfo) -> Insta
1829
# Fast path: `instance` already belongs to `superclass`.
1930
return instance
2031

32+
if superclass.fullname == "builtins.tuple" and instance.type.tuple_type:
33+
if has_type_vars(instance.type.tuple_type):
34+
# We special case mapping generic tuple types to tuple base, because for
35+
# such tuples fallback can't be calculated before applying type arguments.
36+
alias = instance.type.special_alias
37+
assert alias is not None
38+
if not alias._is_recursive:
39+
# Unfortunately we can't support this for generic recursive tuples.
40+
# If we skip this special casing we will fall back to tuple[Any, ...].
41+
env = instance_to_type_environment(instance)
42+
tuple_type = get_proper_type(expand_type(instance.type.tuple_type, env))
43+
if isinstance(tuple_type, TupleType):
44+
return mypy.typeops.tuple_fallback(tuple_type)
45+
2146
if not superclass.type_vars:
2247
# Fast path: `superclass` has no type variables to map to.
2348
return Instance(superclass, [])

mypy/nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3294,7 +3294,7 @@ def from_tuple_type(cls, info: TypeInfo) -> TypeAlias:
32943294
"""Generate an alias to the tuple type described by a given TypeInfo."""
32953295
assert info.tuple_type
32963296
return TypeAlias(
3297-
info.tuple_type.copy_modified(fallback=mypy.types.Instance(info, [])),
3297+
info.tuple_type.copy_modified(fallback=mypy.types.Instance(info, info.defn.type_vars)),
32983298
info.fullname,
32993299
info.line,
33003300
info.column,

mypy/semanal.py

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,16 +1392,12 @@ def analyze_class(self, defn: ClassDef) -> None:
13921392
if self.analyze_typeddict_classdef(defn):
13931393
return
13941394

1395-
if self.analyze_namedtuple_classdef(defn):
1395+
if self.analyze_namedtuple_classdef(defn, tvar_defs):
13961396
return
13971397

13981398
# Create TypeInfo for class now that base classes and the MRO can be calculated.
13991399
self.prepare_class_def(defn)
1400-
1401-
defn.type_vars = tvar_defs
1402-
defn.info.type_vars = []
1403-
# we want to make sure any additional logic in add_type_vars gets run
1404-
defn.info.add_type_vars()
1400+
self.setup_type_vars(defn, tvar_defs)
14051401
if base_error:
14061402
defn.info.fallback_to_any = True
14071403

@@ -1414,6 +1410,19 @@ def analyze_class(self, defn: ClassDef) -> None:
14141410
self.analyze_class_decorator(defn, decorator)
14151411
self.analyze_class_body_common(defn)
14161412

1413+
def setup_type_vars(self, defn: ClassDef, tvar_defs: List[TypeVarLikeType]) -> None:
1414+
defn.type_vars = tvar_defs
1415+
defn.info.type_vars = []
1416+
# we want to make sure any additional logic in add_type_vars gets run
1417+
defn.info.add_type_vars()
1418+
1419+
def setup_alias_type_vars(self, defn: ClassDef) -> None:
1420+
assert defn.info.special_alias is not None
1421+
defn.info.special_alias.alias_tvars = list(defn.info.type_vars)
1422+
target = defn.info.special_alias.target
1423+
assert isinstance(target, ProperType) and isinstance(target, TupleType)
1424+
target.partial_fallback.args = tuple(defn.type_vars)
1425+
14171426
def is_core_builtin_class(self, defn: ClassDef) -> bool:
14181427
return self.cur_mod_id == "builtins" and defn.name in CORE_BUILTIN_CLASSES
14191428

@@ -1446,7 +1455,9 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> bool:
14461455
return True
14471456
return False
14481457

1449-
def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:
1458+
def analyze_namedtuple_classdef(
1459+
self, defn: ClassDef, tvar_defs: List[TypeVarLikeType]
1460+
) -> bool:
14501461
"""Check if this class can define a named tuple."""
14511462
if (
14521463
defn.info
@@ -1465,7 +1476,9 @@ def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:
14651476
if info is None:
14661477
self.mark_incomplete(defn.name, defn)
14671478
else:
1468-
self.prepare_class_def(defn, info)
1479+
self.prepare_class_def(defn, info, custom_names=True)
1480+
self.setup_type_vars(defn, tvar_defs)
1481+
self.setup_alias_type_vars(defn)
14691482
with self.scope.class_scope(defn.info):
14701483
with self.named_tuple_analyzer.save_namedtuple_body(info):
14711484
self.analyze_class_body_common(defn)
@@ -1690,7 +1703,31 @@ def get_all_bases_tvars(
16901703
tvars.extend(base_tvars)
16911704
return remove_dups(tvars)
16921705

1693-
def prepare_class_def(self, defn: ClassDef, info: Optional[TypeInfo] = None) -> None:
1706+
def get_and_bind_all_tvars(self, type_exprs: List[Expression]) -> List[TypeVarLikeType]:
1707+
"""Return all type variable references in item type expressions.
1708+
This is a helper for generic TypedDicts and NamedTuples. Essentially it is
1709+
a simplified version of the logic we use for ClassDef bases. We duplicate
1710+
some amount of code, because it is hard to refactor common pieces.
1711+
"""
1712+
tvars = []
1713+
for base_expr in type_exprs:
1714+
try:
1715+
base = self.expr_to_unanalyzed_type(base_expr)
1716+
except TypeTranslationError:
1717+
# This error will be caught later.
1718+
continue
1719+
base_tvars = base.accept(TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope))
1720+
tvars.extend(base_tvars)
1721+
tvars = remove_dups(tvars) # Variables are defined in order of textual appearance.
1722+
tvar_defs = []
1723+
for name, tvar_expr in tvars:
1724+
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
1725+
tvar_defs.append(tvar_def)
1726+
return tvar_defs
1727+
1728+
def prepare_class_def(
1729+
self, defn: ClassDef, info: Optional[TypeInfo] = None, custom_names: bool = False
1730+
) -> None:
16941731
"""Prepare for the analysis of a class definition.
16951732
16961733
Create an empty TypeInfo and store it in a symbol table, or if the 'info'
@@ -1702,10 +1739,13 @@ def prepare_class_def(self, defn: ClassDef, info: Optional[TypeInfo] = None) ->
17021739
info = info or self.make_empty_type_info(defn)
17031740
defn.info = info
17041741
info.defn = defn
1705-
if not self.is_func_scope():
1706-
info._fullname = self.qualified_name(defn.name)
1707-
else:
1708-
info._fullname = info.name
1742+
if not custom_names:
1743+
# Some special classes (in particular NamedTuples) use custom fullname logic.
1744+
# Don't override it here (also see comment below, this needs cleanup).
1745+
if not self.is_func_scope():
1746+
info._fullname = self.qualified_name(defn.name)
1747+
else:
1748+
info._fullname = info.name
17091749
local_name = defn.name
17101750
if "@" in local_name:
17111751
local_name = local_name.split("@")[0]
@@ -1866,6 +1906,7 @@ def configure_tuple_base_class(self, defn: ClassDef, base: TupleType) -> Instanc
18661906
if info.special_alias and has_placeholder(info.special_alias.target):
18671907
self.defer(force_progress=True)
18681908
info.update_tuple_type(base)
1909+
self.setup_alias_type_vars(defn)
18691910

18701911
if base.partial_fallback.type.fullname == "builtins.tuple" and not has_placeholder(base):
18711912
# Fallback can only be safely calculated after semantic analysis, since base
@@ -2658,7 +2699,7 @@ def analyze_namedtuple_assign(self, s: AssignmentStmt) -> bool:
26582699
return False
26592700
lvalue = s.lvalues[0]
26602701
name = lvalue.name
2661-
internal_name, info = self.named_tuple_analyzer.check_namedtuple(
2702+
internal_name, info, tvar_defs = self.named_tuple_analyzer.check_namedtuple(
26622703
s.rvalue, name, self.is_func_scope()
26632704
)
26642705
if internal_name is None:
@@ -2678,6 +2719,9 @@ def analyze_namedtuple_assign(self, s: AssignmentStmt) -> bool:
26782719
# Yes, it's a valid namedtuple, but defer if it is not ready.
26792720
if not info:
26802721
self.mark_incomplete(name, lvalue, becomes_typeinfo=True)
2722+
else:
2723+
self.setup_type_vars(info.defn, tvar_defs)
2724+
self.setup_alias_type_vars(info.defn)
26812725
return True
26822726

26832727
def analyze_typeddict_assign(self, s: AssignmentStmt) -> bool:
@@ -5864,10 +5908,16 @@ def expr_to_analyzed_type(
58645908
self, expr: Expression, report_invalid_types: bool = True, allow_placeholder: bool = False
58655909
) -> Optional[Type]:
58665910
if isinstance(expr, CallExpr):
5911+
# This is a legacy syntax intended mostly for Python 2, we keep it for
5912+
# backwards compatibility, but new features like generic named tuples
5913+
# and recursive named tuples will be not supported.
58675914
expr.accept(self)
5868-
internal_name, info = self.named_tuple_analyzer.check_namedtuple(
5915+
internal_name, info, tvar_defs = self.named_tuple_analyzer.check_namedtuple(
58695916
expr, None, self.is_func_scope()
58705917
)
5918+
if tvar_defs:
5919+
self.fail("Generic named tuples are not supported for legacy class syntax", expr)
5920+
self.note("Use either Python 3 class syntax, or the assignment syntax", expr)
58715921
if internal_name is None:
58725922
# Some form of namedtuple is the only valid type that looks like a call
58735923
# expression. This isn't a valid type.

mypy/semanal_namedtuple.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,10 @@
5858
Type,
5959
TypeOfAny,
6060
TypeType,
61+
TypeVarLikeType,
6162
TypeVarType,
6263
UnboundType,
64+
has_type_vars,
6365
)
6466
from mypy.util import get_unique_redefinition_name
6567

@@ -118,7 +120,6 @@ def analyze_namedtuple_classdef(
118120
info = self.build_namedtuple_typeinfo(
119121
defn.name, items, types, default_items, defn.line, existing_info
120122
)
121-
defn.info = info
122123
defn.analyzed = NamedTupleExpr(info, is_typed=True)
123124
defn.analyzed.line = defn.line
124125
defn.analyzed.column = defn.column
@@ -201,7 +202,7 @@ def check_namedtuple_classdef(
201202

202203
def check_namedtuple(
203204
self, node: Expression, var_name: Optional[str], is_func_scope: bool
204-
) -> Tuple[Optional[str], Optional[TypeInfo]]:
205+
) -> Tuple[Optional[str], Optional[TypeInfo], List[TypeVarLikeType]]:
205206
"""Check if a call defines a namedtuple.
206207
207208
The optional var_name argument is the name of the variable to
@@ -216,21 +217,21 @@ def check_namedtuple(
216217
report errors but return (some) TypeInfo.
217218
"""
218219
if not isinstance(node, CallExpr):
219-
return None, None
220+
return None, None, []
220221
call = node
221222
callee = call.callee
222223
if not isinstance(callee, RefExpr):
223-
return None, None
224+
return None, None, []
224225
fullname = callee.fullname
225226
if fullname == "collections.namedtuple":
226227
is_typed = False
227228
elif fullname in TYPED_NAMEDTUPLE_NAMES:
228229
is_typed = True
229230
else:
230-
return None, None
231+
return None, None, []
231232
result = self.parse_namedtuple_args(call, fullname)
232233
if result:
233-
items, types, defaults, typename, ok = result
234+
items, types, defaults, typename, tvar_defs, ok = result
234235
else:
235236
# Error. Construct dummy return value.
236237
if var_name:
@@ -244,10 +245,10 @@ def check_namedtuple(
244245
if name != var_name or is_func_scope:
245246
# NOTE: we skip local namespaces since they are not serialized.
246247
self.api.add_symbol_skip_local(name, info)
247-
return var_name, info
248+
return var_name, info, []
248249
if not ok:
249250
# This is a valid named tuple but some types are not ready.
250-
return typename, None
251+
return typename, None, []
251252

252253
# We use the variable name as the class name if it exists. If
253254
# it doesn't, we use the name passed as an argument. We prefer
@@ -306,7 +307,7 @@ def check_namedtuple(
306307
if name != var_name or is_func_scope:
307308
# NOTE: we skip local namespaces since they are not serialized.
308309
self.api.add_symbol_skip_local(name, info)
309-
return typename, info
310+
return typename, info, tvar_defs
310311

311312
def store_namedtuple_info(
312313
self, info: TypeInfo, name: str, call: CallExpr, is_typed: bool
@@ -317,7 +318,9 @@ def store_namedtuple_info(
317318

318319
def parse_namedtuple_args(
319320
self, call: CallExpr, fullname: str
320-
) -> Optional[Tuple[List[str], List[Type], List[Expression], str, bool]]:
321+
) -> Optional[
322+
Tuple[List[str], List[Type], List[Expression], str, List[TypeVarLikeType], bool]
323+
]:
321324
"""Parse a namedtuple() call into data needed to construct a type.
322325
323326
Returns a 5-tuple:
@@ -363,6 +366,7 @@ def parse_namedtuple_args(
363366
return None
364367
typename = cast(StrExpr, call.args[0]).value
365368
types: List[Type] = []
369+
tvar_defs = []
366370
if not isinstance(args[1], (ListExpr, TupleExpr)):
367371
if fullname == "collections.namedtuple" and isinstance(args[1], StrExpr):
368372
str_expr = args[1]
@@ -384,14 +388,20 @@ def parse_namedtuple_args(
384388
return None
385389
items = [cast(StrExpr, item).value for item in listexpr.items]
386390
else:
391+
type_exprs = [
392+
t.items[1]
393+
for t in listexpr.items
394+
if isinstance(t, TupleExpr) and len(t.items) == 2
395+
]
396+
tvar_defs = self.api.get_and_bind_all_tvars(type_exprs)
387397
# The fields argument contains (name, type) tuples.
388398
result = self.parse_namedtuple_fields_with_types(listexpr.items, call)
389399
if result is None:
390400
# One of the types is not ready, defer.
391401
return None
392402
items, types, _, ok = result
393403
if not ok:
394-
return [], [], [], typename, False
404+
return [], [], [], typename, [], False
395405
if not types:
396406
types = [AnyType(TypeOfAny.unannotated) for _ in items]
397407
underscore = [item for item in items if item.startswith("_")]
@@ -404,7 +414,7 @@ def parse_namedtuple_args(
404414
if len(defaults) > len(items):
405415
self.fail(f'Too many defaults given in call to "{type_name}()"', call)
406416
defaults = defaults[: len(items)]
407-
return items, types, defaults, typename, True
417+
return items, types, defaults, typename, tvar_defs, True
408418

409419
def parse_namedtuple_fields_with_types(
410420
self, nodes: List[Expression], context: Context
@@ -490,7 +500,7 @@ def build_namedtuple_typeinfo(
490500
# We can't calculate the complete fallback type until after semantic
491501
# analysis, since otherwise base classes might be incomplete. Postpone a
492502
# callback function that patches the fallback.
493-
if not has_placeholder(tuple_base):
503+
if not has_placeholder(tuple_base) and not has_type_vars(tuple_base):
494504
self.api.schedule_patch(
495505
PRIORITY_FALLBACKS, lambda: calculate_tuple_fallback(tuple_base)
496506
)
@@ -525,7 +535,11 @@ def add_field(
525535

526536
assert info.tuple_type is not None # Set by update_tuple_type() above.
527537
tvd = TypeVarType(
528-
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, -1, [], info.tuple_type
538+
SELF_TVAR_NAME,
539+
info.fullname + "." + SELF_TVAR_NAME,
540+
self.api.tvar_scope.new_unique_func_id(),
541+
[],
542+
info.tuple_type,
529543
)
530544
selftype = tvd
531545

0 commit comments

Comments
 (0)