Skip to content

Commit 7ba2303

Browse files
committed
Introduce tracked class parameters
For a tracked class parameter we add a refinement in the constructor type that the class member is the same as the parameter. E.g. ```scala class C { type T } class D(tracked val x: C) { type T = x.T } ``` This will generate the constructor type: ```scala (x1: C): D { val x: x1.type } ``` Without `tracked` the refinement would not be added. This can solve several problems with dependent class types where previously we lost track of type dependencies.
1 parent 8acb696 commit 7ba2303

39 files changed

+736
-58
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+24-11
Original file line numberDiff line numberDiff line change
@@ -443,13 +443,13 @@ object desugar {
443443
private def toDefParam(tparam: TypeDef, keepAnnotations: Boolean): TypeDef = {
444444
var mods = tparam.rawMods
445445
if (!keepAnnotations) mods = mods.withAnnotations(Nil)
446-
tparam.withMods(mods & (EmptyFlags | Sealed) | Param)
446+
tparam.withMods(mods & EmptyFlags | Param)
447447
}
448448
private def toDefParam(vparam: ValDef, keepAnnotations: Boolean, keepDefault: Boolean): ValDef = {
449449
var mods = vparam.rawMods
450450
if (!keepAnnotations) mods = mods.withAnnotations(Nil)
451451
val hasDefault = if keepDefault then HasDefault else EmptyFlags
452-
vparam.withMods(mods & (GivenOrImplicit | Erased | hasDefault) | Param)
452+
vparam.withMods(mods & (GivenOrImplicit | Erased | hasDefault | Tracked) | Param)
453453
}
454454

455455
def mkApply(fn: Tree, paramss: List[ParamClause])(using Context): Tree =
@@ -535,7 +535,7 @@ object desugar {
535535
// but not on the constructor parameters. The reverse is true for
536536
// annotations on class _value_ parameters.
537537
val constrTparams = impliedTparams.map(toDefParam(_, keepAnnotations = false))
538-
val constrVparamss =
538+
def defVparamss =
539539
if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
540540
if (isCaseClass)
541541
report.error(CaseClassMissingParamList(cdef), namePos)
@@ -546,6 +546,10 @@ object desugar {
546546
ListOfNil
547547
}
548548
else originalVparamss.nestedMap(toDefParam(_, keepAnnotations = true, keepDefault = true))
549+
val constrVparamss = defVparamss
550+
// defVparamss also needed as separate tree nodes in implicitWrappers below.
551+
// Need to be separate because they are `watch`ed in addParamRefinements.
552+
// See parsercombinators-givens.scala for a test case.
549553
val derivedTparams =
550554
constrTparams.zipWithConserve(impliedTparams)((tparam, impliedParam) =>
551555
derivedTypeParam(tparam).withAnnotations(impliedParam.mods.annotations))
@@ -623,6 +627,11 @@ object desugar {
623627
case _ => false
624628
}
625629

630+
def isRepeated(tree: Tree): Boolean = stripByNameType(tree) match {
631+
case PostfixOp(_, Ident(tpnme.raw.STAR)) => true
632+
case _ => false
633+
}
634+
626635
def appliedRef(tycon: Tree, tparams: List[TypeDef] = constrTparams, widenHK: Boolean = false) = {
627636
val targs = for (tparam <- tparams) yield {
628637
val targ = refOfDef(tparam)
@@ -639,10 +648,13 @@ object desugar {
639648
appliedTypeTree(tycon, targs)
640649
}
641650

642-
def isRepeated(tree: Tree): Boolean = stripByNameType(tree) match {
643-
case PostfixOp(_, Ident(tpnme.raw.STAR)) => true
644-
case _ => false
645-
}
651+
def addParamRefinements(core: Tree, paramss: List[List[ValDef]]): Tree =
652+
val refinements =
653+
for params <- paramss; param <- params; if param.mods.is(Tracked) yield
654+
ValDef(param.name, SingletonTypeTree(TermRefTree().watching(param)), EmptyTree)
655+
.withSpan(param.span)
656+
if refinements.isEmpty then core
657+
else RefinedTypeTree(core, refinements).showing(i"refined result: $result", Printers.desugar)
646658

647659
// a reference to the class type bound by `cdef`, with type parameters coming from the constructor
648660
val classTypeRef = appliedRef(classTycon)
@@ -863,18 +875,17 @@ object desugar {
863875
Nil
864876
}
865877
else {
866-
val defParamss = constrVparamss match {
878+
val defParamss = defVparamss match
867879
case Nil :: paramss =>
868880
paramss // drop leading () that got inserted by class
869881
// TODO: drop this once we do not silently insert empty class parameters anymore
870882
case paramss => paramss
871-
}
872883
val finalFlag = if ctx.settings.YcompileScala2Library.value then EmptyFlags else Final
873884
// implicit wrapper is typechecked in same scope as constructor, so
874885
// we can reuse the constructor parameters; no derived params are needed.
875886
DefDef(
876887
className.toTermName, joinParams(constrTparams, defParamss),
877-
classTypeRef, creatorExpr)
888+
addParamRefinements(classTypeRef, defParamss), creatorExpr)
878889
.withMods(companionMods | mods.flags.toTermFlags & (GivenOrImplicit | Inline) | finalFlag)
879890
.withSpan(cdef.span) :: Nil
880891
}
@@ -903,7 +914,9 @@ object desugar {
903914
}
904915
if mods.isAllOf(Given | Inline | Transparent) then
905916
report.error("inline given instances cannot be trasparent", cdef)
906-
val classMods = if mods.is(Given) then mods &~ (Inline | Transparent) | Synthetic else mods
917+
var classMods = if mods.is(Given) then mods &~ (Inline | Transparent) | Synthetic else mods
918+
if vparamAccessors.exists(_.mods.is(Tracked)) then
919+
classMods |= Dependent
907920
cpy.TypeDef(cdef: TypeDef)(
908921
name = className,
909922
rhs = cpy.Template(impl)(constr, parents1, clsDerived, self1,

compiler/src/dotty/tools/dotc/ast/untpd.scala

+2
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
231231

232232
case class Infix()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Infix)
233233

234+
case class Tracked()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Tracked)
235+
234236
/** Used under pureFunctions to mark impure function types `A => B` in `FunctionWithMods` */
235237
case class Impure()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Impure)
236238
}

compiler/src/dotty/tools/dotc/core/Flags.scala

+7-4
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ object Flags {
242242
val (AccessorOrSealed @ _, Accessor @ _, Sealed @ _) = newFlags(11, "<accessor>", "sealed")
243243

244244
/** A mutable var, an open class */
245-
val (MutableOrOpen @ __, Mutable @ _, Open @ _) = newFlags(12, "mutable", "open")
245+
val (MutableOrOpen @ _, Mutable @ _, Open @ _) = newFlags(12, "mutable", "open")
246246

247247
/** Symbol is local to current class (i.e. private[this] or protected[this]
248248
* pre: Private or Protected are also set
@@ -377,6 +377,9 @@ object Flags {
377377
/** Symbol cannot be found as a member during typer */
378378
val (Invisible @ _, _, _) = newFlags(45, "<invisible>")
379379

380+
/** Tracked modifier for class parameter / a class with some tracked parameters */
381+
val (Tracked @ _, _, Dependent @ _) = newFlags(46, "tracked")
382+
380383
// ------------ Flags following this one are not pickled ----------------------------------
381384

382385
/** Symbol is not a member of its owner */
@@ -452,7 +455,7 @@ object Flags {
452455
CommonSourceModifierFlags.toTypeFlags | Abstract | Sealed | Opaque | Open
453456

454457
val TermSourceModifierFlags: FlagSet =
455-
CommonSourceModifierFlags.toTermFlags | Inline | AbsOverride | Lazy
458+
CommonSourceModifierFlags.toTermFlags | Inline | AbsOverride | Lazy | Tracked
456459

457460
/** Flags representing modifiers that can appear in trees */
458461
val ModifierFlags: FlagSet =
@@ -466,7 +469,7 @@ object Flags {
466469
val FromStartFlags: FlagSet = commonFlags(
467470
Module, Package, Deferred, Method, Case, Enum, Param, ParamAccessor,
468471
Scala2SpecialFlags, MutableOrOpen, Opaque, Touched, JavaStatic,
469-
OuterOrCovariant, LabelOrContravariant, CaseAccessor,
472+
OuterOrCovariant, LabelOrContravariant, CaseAccessor, Tracked,
470473
Extension, NonMember, Implicit, Given, Permanent, Synthetic, Exported,
471474
SuperParamAliasOrScala2x, Inline, Macro, ConstructorProxy, Invisible)
472475

@@ -477,7 +480,7 @@ object Flags {
477480
*/
478481
val AfterLoadFlags: FlagSet = commonFlags(
479482
FromStartFlags, AccessFlags, Final, AccessorOrSealed,
480-
Abstract, LazyOrTrait, SelfName, JavaDefined, JavaAnnotation, Transparent)
483+
Abstract, LazyOrTrait, SelfName, JavaDefined, JavaAnnotation, Transparent, Tracked)
481484

482485
/** A value that's unstable unless complemented with a Stable flag */
483486
val UnstableValueFlags: FlagSet = Mutable | Method

compiler/src/dotty/tools/dotc/core/NamerOps.scala

+8-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,14 @@ object NamerOps:
1515
* @param ctor the constructor
1616
*/
1717
def effectiveResultType(ctor: Symbol, paramss: List[List[Symbol]])(using Context): Type =
18-
paramss match
19-
case TypeSymbols(tparams) :: _ => ctor.owner.typeRef.appliedTo(tparams.map(_.typeRef))
20-
case _ => ctor.owner.typeRef
18+
val (resType, termParamss) = paramss match
19+
case TypeSymbols(tparams) :: rest =>
20+
(ctor.owner.typeRef.appliedTo(tparams.map(_.typeRef)), rest)
21+
case _ =>
22+
(ctor.owner.typeRef, paramss)
23+
termParamss.flatten.foldLeft(resType): (rt, param) =>
24+
if param.is(Tracked) then RefinedType(rt, param.name, param.termRef)
25+
else rt
2126

2227
/** Split dependent class refinements off parent type. Add them to `refinements`,
2328
* unless it is null.

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

+2-7
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,6 @@ trait PatternTypeConstrainer { self: TypeComparer =>
8888
}
8989
}
9090

91-
def stripRefinement(tp: Type): Type = tp match {
92-
case tp: RefinedOrRecType => stripRefinement(tp.parent)
93-
case tp => tp
94-
}
95-
9691
def tryConstrainSimplePatternType(pat: Type, scrut: Type) = {
9792
val patCls = pat.classSymbol
9893
val scrCls = scrut.classSymbol
@@ -181,14 +176,14 @@ trait PatternTypeConstrainer { self: TypeComparer =>
181176
case AndType(scrut1, scrut2) =>
182177
constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2)
183178
case scrut: RefinedOrRecType =>
184-
constrainPatternType(pat, stripRefinement(scrut))
179+
constrainPatternType(pat, scrut.stripRefinement)
185180
case scrut => dealiasDropNonmoduleRefs(pat) match {
186181
case OrType(pat1, pat2) =>
187182
either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut))
188183
case AndType(pat1, pat2) =>
189184
constrainPatternType(pat1, scrut) && constrainPatternType(pat2, scrut)
190185
case pat: RefinedOrRecType =>
191-
constrainPatternType(stripRefinement(pat), scrut)
186+
constrainPatternType(pat.stripRefinement, scrut)
192187
case pat =>
193188
tryConstrainSimplePatternType(pat, scrut)
194189
|| classesMayBeCompatible && constrainUpcasted(scrut)

compiler/src/dotty/tools/dotc/core/StdNames.scala

+1
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@ object StdNames {
623623
val toString_ : N = "toString"
624624
val toTypeConstructor: N = "toTypeConstructor"
625625
val tpe : N = "tpe"
626+
val tracked: N = "tracked"
626627
val transparent : N = "transparent"
627628
val tree : N = "tree"
628629
val true_ : N = "true"

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

+6-2
Original file line numberDiff line numberDiff line change
@@ -1185,13 +1185,17 @@ object SymDenotations {
11851185
final def isExtensibleClass(using Context): Boolean =
11861186
isClass && !isOneOf(FinalOrModuleClass) && !isAnonymousClass
11871187

1188-
/** A symbol is effectively final if it cannot be overridden in a subclass */
1188+
/** A symbol is effectively final if it cannot be overridden */
11891189
final def isEffectivelyFinal(using Context): Boolean =
11901190
isOneOf(EffectivelyFinalFlags)
11911191
|| is(Inline, butNot = Deferred)
11921192
|| is(JavaDefinedVal, butNot = Method)
11931193
|| isConstructor
1194-
|| !owner.isExtensibleClass
1194+
|| !owner.isExtensibleClass && !is(Deferred)
1195+
// Deferred symbols can arise through parent refinements.
1196+
// For them, the overriding relationship reverses anyway, so
1197+
// being in a final class does not mean the symbol cannot be
1198+
// implemented concretely in a superclass.
11951199

11961200
/** A class is effectively sealed if has the `final` or `sealed` modifier, or it
11971201
* is defined in Scala 3 and is neither abstract nor open.

compiler/src/dotty/tools/dotc/core/TypeUtils.scala

+10-4
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ import TypeErasure.ErasedValueType
66
import Types.*, Contexts.*, Symbols.*, Flags.*, Decorators.*
77
import Names.Name
88

9-
class TypeUtils {
9+
class TypeUtils:
1010
/** A decorator that provides methods on types
1111
* that are needed in the transformer pipeline.
1212
*/
13-
extension (self: Type) {
13+
extension (self: Type)
1414

1515
def isErasedValueType(using Context): Boolean =
1616
self.isInstanceOf[ErasedValueType]
@@ -150,5 +150,11 @@ class TypeUtils {
150150
case _ =>
151151
val cls = self.underlyingClassRef(refinementOK = false).typeSymbol
152152
cls.isTransparentClass && (!traitOnly || cls.is(Trait))
153-
}
154-
}
153+
154+
/** Strip all outer refinements off this type */
155+
def stripRefinement: Type = self match
156+
case self: RefinedOrRecType => self.parent.stripRefinement
157+
case seld => self
158+
159+
end TypeUtils
160+

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

+1
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,7 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
798798
if (flags.is(Exported)) writeModTag(EXPORTED)
799799
if (flags.is(Given)) writeModTag(GIVEN)
800800
if (flags.is(Implicit)) writeModTag(IMPLICIT)
801+
if (flags.is(Tracked)) writeModTag(TRACKED)
801802
if (isTerm) {
802803
if (flags.is(Lazy, butNot = Module)) writeModTag(LAZY)
803804
if (flags.is(AbsOverride)) { writeModTag(ABSTRACT); writeModTag(OVERRIDE) }

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

+1
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,7 @@ class TreeUnpickler(reader: TastyReader,
737737
case INVISIBLE => addFlag(Invisible)
738738
case TRANSPARENT => addFlag(Transparent)
739739
case INFIX => addFlag(Infix)
740+
case TRACKED => addFlag(Tracked)
740741
case PRIVATEqualified =>
741742
readByte()
742743
privateWithin = readWithin

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+10-5
Original file line numberDiff line numberDiff line change
@@ -3104,6 +3104,7 @@ object Parsers {
31043104
case nme.open => Mod.Open()
31053105
case nme.transparent => Mod.Transparent()
31063106
case nme.infix => Mod.Infix()
3107+
case nme.tracked => Mod.Tracked()
31073108
}
31083109
}
31093110

@@ -3169,7 +3170,8 @@ object Parsers {
31693170
* | AccessModifier
31703171
* | override
31713172
* | opaque
3172-
* LocalModifier ::= abstract | final | sealed | open | implicit | lazy | erased | inline | transparent
3173+
* LocalModifier ::= abstract | final | sealed | open | implicit | lazy | erased |
3174+
* inline | transparent
31733175
*/
31743176
def modifiers(allowed: BitSet = modifierTokens, start: Modifiers = Modifiers()): Modifiers = {
31753177
@tailrec
@@ -3323,7 +3325,7 @@ object Parsers {
33233325
* UsingClsTermParamClause::= ‘(’ ‘using’ [‘erased’] (ClsParams | ContextTypes) ‘)’
33243326
* ClsParams ::= ClsParam {‘,’ ClsParam}
33253327
* ClsParam ::= {Annotation}
3326-
*
3328+
* [{Modifier | ‘tracked’} (‘val’ | ‘var’) | ‘inline’] Param
33273329
* TypelessClause ::= DefTermParamClause
33283330
* | UsingParamClause
33293331
*
@@ -3359,6 +3361,8 @@ object Parsers {
33593361
if isErasedKw then
33603362
mods = addModifier(mods)
33613363
if paramOwner.isClass then
3364+
if isIdent(nme.tracked) && in.featureEnabled(Feature.modularity) && !in.lookahead.isColon then
3365+
mods = addModifier(mods)
33623366
mods = addFlag(modifiers(start = mods), ParamAccessor)
33633367
mods =
33643368
if in.token == VAL then
@@ -3430,7 +3434,8 @@ object Parsers {
34303434
val isParams =
34313435
!impliedMods.is(Given)
34323436
|| startParamTokens.contains(in.token)
3433-
|| isIdent && (in.name == nme.inline || in.lookahead.isColon)
3437+
|| isIdent
3438+
&& (in.name == nme.inline || in.name == nme.tracked || in.lookahead.isColon)
34343439
(mods, isParams)
34353440
(if isParams then commaSeparated(() => param())
34363441
else contextTypes(paramOwner, numLeadParams, impliedMods)) match {
@@ -4010,8 +4015,8 @@ object Parsers {
40104015
def adjustDefParams(paramss: List[ParamClause]): List[ParamClause] =
40114016
paramss.nestedMap: param =>
40124017
if !param.mods.isAllOf(PrivateLocal) then
4013-
syntaxError(em"method parameter ${param.name} may not be `a val`", param.span)
4014-
param.withMods(param.mods &~ (AccessFlags | ParamAccessor | Mutable) | Param)
4018+
syntaxError(em"method parameter ${param.name} may not be a `val`", param.span)
4019+
param.withMods(param.mods &~ (AccessFlags | ParamAccessor | Tracked | Mutable) | Param)
40154020
.asInstanceOf[List[ParamClause]]
40164021

40174022
val gdef =

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
111111
protected def refinementNameString(tp: RefinedType): String = nameString(tp.refinedName)
112112

113113
/** String representation of a refinement */
114-
protected def toTextRefinement(rt: RefinedType): Text =
114+
def toTextRefinement(rt: RefinedType): Text =
115115
val keyword = rt.refinedInfo match {
116116
case _: ExprType | _: MethodOrPoly => "def "
117117
case _: TypeBounds => "type "

compiler/src/dotty/tools/dotc/printing/Printer.scala

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ package printing
44

55
import core.*
66
import Texts.*, ast.Trees.*
7-
import Types.{Type, SingletonType, LambdaParam, NamedType},
7+
import Types.{Type, SingletonType, LambdaParam, NamedType, RefinedType},
88
Symbols.Symbol, Scopes.Scope, Constants.Constant,
99
Names.Name, Denotations._, Annotations.Annotation, Contexts.Context
1010
import typer.Implicits.*
@@ -104,6 +104,9 @@ abstract class Printer {
104104
/** Textual representation of a prefix of some reference, ending in `.` or `#` */
105105
def toTextPrefixOf(tp: NamedType): Text
106106

107+
/** textual representation of a refinement, with no enclosing {...} */
108+
def toTextRefinement(rt: RefinedType): Text
109+
107110
/** Textual representation of a reference in a capture set */
108111
def toTextCaptureRef(tp: Type): Text
109112

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

+12-4
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,15 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
335335
case Select(nu: New, nme.CONSTRUCTOR) if isCheckable(nu) =>
336336
// need to check instantiability here, because the type of the New itself
337337
// might be a type constructor.
338-
ctx.typer.checkClassType(tree.tpe, tree.srcPos, traitReq = false, stablePrefixReq = true)
338+
def checkClassType(tpe: Type, stablePrefixReq: Boolean) =
339+
ctx.typer.checkClassType(tpe, tree.srcPos,
340+
traitReq = false, stablePrefixReq = stablePrefixReq,
341+
refinementOK = Feature.enabled(Feature.modularity))
342+
checkClassType(tree.tpe, true)
339343
if !nu.tpe.isLambdaSub then
340344
// Check the constructor type as well; it could be an illegal singleton type
341345
// which would not be reflected as `tree.tpe`
342-
ctx.typer.checkClassType(nu.tpe, tree.srcPos, traitReq = false, stablePrefixReq = false)
346+
checkClassType(nu.tpe, false)
343347
Checking.checkInstantiable(tree.tpe, nu.tpe, nu.srcPos)
344348
withNoCheckNews(nu :: Nil)(app1)
345349
case _ =>
@@ -415,8 +419,12 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
415419
// Constructor parameters are in scope when typing a parent.
416420
// While they can safely appear in a parent tree, to preserve
417421
// soundness we need to ensure they don't appear in a parent
418-
// type (#16270).
419-
val illegalRefs = parent.tpe.namedPartsWith(p => p.symbol.is(ParamAccessor) && (p.symbol.owner eq sym))
422+
// type (#16270). We can strip any refinement of a parent type since
423+
// these refinements are split off from the parent type constructor
424+
// application `parent` in Namer and don't show up as parent types
425+
// of the class.
426+
val illegalRefs = parent.tpe.stripRefinement.namedPartsWith:
427+
p => p.symbol.is(ParamAccessor) && (p.symbol.owner eq sym)
420428
if illegalRefs.nonEmpty then
421429
report.error(
422430
em"The type of a class parent cannot refer to constructor parameters, but ${parent.tpe} refers to ${illegalRefs.map(_.name.show).mkString(",")}", parent.srcPos)

0 commit comments

Comments
 (0)