Skip to content

Commit 2fa3d79

Browse files
committed
fix #9179: generate dollarvalues for parameterized enums, use serialization proxy
1 parent 0668611 commit 2fa3d79

File tree

8 files changed

+74
-27
lines changed

8 files changed

+74
-27
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ object desugar {
505505
val targ = refOfDef(tparam)
506506
def fullyApplied(tparam: Tree): Tree = tparam match {
507507
case TypeDef(_, LambdaTypeTree(tparams, body)) =>
508-
AppliedTypeTree(targ, tparams.map(_ => TypeBoundsTree(EmptyTree, EmptyTree)))
508+
AppliedTypeTree(targ, tparams.map(_ => WildcardTypeBoundsTree()))
509509
case TypeDef(_, rhs: DerivedTypeTree) =>
510510
fullyApplied(rhs.watched)
511511
case _ =>

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ object DesugarEnums {
7777
private def valuesDot(name: PreName)(implicit src: SourceFile) =
7878
Select(Ident(nme.DOLLAR_VALUES), name.toTermName)
7979

80-
private def registerCall(using Context): List[Tree] =
81-
if (enumClass.typeParams.nonEmpty) Nil
82-
else Apply(valuesDot("register"), This(EmptyTypeIdent) :: Nil) :: Nil
80+
private def registerCall(using Context): Tree =
81+
val asRaw = TypeApply(Select(This(EmptyTypeIdent), nme.asInstanceOf_), rawRef(enumClass.typeRef) :: Nil) // safe to cast due to refchecks
82+
Apply(valuesDot("register"), asRaw :: Nil)
8383

8484
/** The following lists of definitions for an enum type E:
8585
*
@@ -93,12 +93,14 @@ object DesugarEnums {
9393
* }
9494
*/
9595
private def enumScaffolding(using Context): List[Tree] = {
96+
import dotty.tools.dotc.transform.SymUtils._
97+
val rawEnumClassRef = rawRef(enumClass.typeRef)
98+
extension (tpe: NamedType) def ofRawEnum = AppliedTypeTree(ref(tpe), rawEnumClassRef)
9699
val valuesDef =
97-
DefDef(nme.values, Nil, Nil, TypeTree(defn.ArrayOf(enumClass.typeRef)), Select(valuesDot(nme.values), nme.toArray))
100+
DefDef(nme.values, Nil, Nil, defn.ArrayType.ofRawEnum, Select(valuesDot(nme.values), nme.toArray))
98101
.withFlags(Synthetic)
99102
val privateValuesDef =
100-
ValDef(nme.DOLLAR_VALUES, TypeTree(),
101-
New(TypeTree(defn.EnumValuesClass.typeRef.appliedTo(enumClass.typeRef :: Nil)), ListOfNil))
103+
ValDef(nme.DOLLAR_VALUES, TypeTree(), New(defn.EnumValuesClass.typeRef.ofRawEnum, ListOfNil))
102104
.withFlags(Private | Synthetic)
103105

104106
val valuesOfExnMessage = Apply(
@@ -138,7 +140,7 @@ object DesugarEnums {
138140
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
139141
derived = Nil,
140142
self = EmptyValDef,
141-
body = List(ordinalDef, toStringDef) ++ registerCall
143+
body = ordinalDef :: toStringDef :: registerCall :: Nil
142144
).withAttachment(ExtendsSingletonMirror, ()))
143145
DefDef(nme.DOLLAR_NEW, Nil,
144146
List(List(param(nme.ordinalDollar_, defn.IntType), param(nme.nameDollar, defn.StringType))),
@@ -254,7 +256,7 @@ object DesugarEnums {
254256
val minKind = if (kind < seenKind) kind else seenKind
255257
ctx.tree.pushAttachment(EnumCaseCount, (count + 1, minKind))
256258
val scaffolding =
257-
if (enumClass.typeParams.nonEmpty || kind >= seenKind) Nil
259+
if (kind >= seenKind) Nil
258260
else if (kind == CaseKind.Object) enumScaffolding
259261
else if (seenKind == CaseKind.Object) enumValueCreator :: Nil
260262
else enumScaffolding :+ enumValueCreator
@@ -288,8 +290,8 @@ object DesugarEnums {
288290
val toStringDef = toStringMethLit(name.toString)
289291
val impl1 = cpy.Template(impl)(
290292
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
291-
body = List(ordinalDef, toStringDef) ++ registerCall)
292-
.withAttachment(ExtendsSingletonMirror, ())
293+
body = ordinalDef :: toStringDef :: registerCall :: Nil
294+
).withAttachment(ExtendsSingletonMirror, ())
293295
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
294296
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)
295297
}

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,13 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
150150
override def isEmpty: Boolean = true
151151
}
152152

153+
def WildcardTypeBoundsTree()(using src: SourceFile): TypeBoundsTree = TypeBoundsTree(EmptyTree, EmptyTree, EmptyTree)
154+
object WildcardTypeBoundsTree:
155+
def unapply(tree: untpd.Tree): Boolean = tree match
156+
case TypeBoundsTree(EmptyTree, EmptyTree, _) => true
157+
case _ => false
158+
159+
153160
/** A block generated by the XML parser, only treated specially by
154161
* `Positioned#checkPos` */
155162
class XMLBlock(stats: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends Block(stats, expr)
@@ -453,6 +460,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
453460
def ref(tp: NamedType)(using Context): Tree =
454461
TypedSplice(tpd.ref(tp))
455462

463+
def rawRef(tp: NamedType)(using Context): Tree =
464+
if tp.typeParams.isEmpty then ref(tp)
465+
else AppliedTypeTree(ref(tp), tp.typeParams.map(_ => WildcardTypeBoundsTree()))
466+
456467
def rootDot(name: Name)(implicit src: SourceFile): Select = Select(Ident(nme.ROOTPKG), name)
457468
def scalaDot(name: Name)(implicit src: SourceFile): Select = Select(rootDot(nme.scala), name)
458469
def scalaAnnotationDot(name: Name)(using SourceFile): Select = Select(scalaDot(nme.annotation), name)

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,12 @@ class Definitions {
651651
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)
652652

653653
@tu lazy val EnumValuesClass: ClassSymbol = requiredClass("scala.runtime.EnumValues")
654+
@tu lazy val EnumValueClass: ClassSymbol = requiredClass("scala.runtime.EnumValue")
655+
656+
@tu lazy val EnumValueSerializationProxyClass: ClassSymbol = requiredClass("scala.runtime.EnumValueSerializationProxy")
657+
@tu lazy val EnumValueSerializationProxyConstructor: TermSymbol =
658+
EnumValueSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(TypeBounds.empty)))
659+
654660
@tu lazy val ProductClass: ClassSymbol = requiredClass("scala.Product")
655661
@tu lazy val Product_canEqual : Symbol = ProductClass.requiredMethod(nme.canEqual_)
656662
@tu lazy val Product_productArity : Symbol = ProductClass.requiredMethod(nme.productArity)

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1491,7 +1491,7 @@ object Parsers {
14911491
}
14921492

14931493
private def makeKindProjectorTypeDef(name: TypeName): TypeDef =
1494-
TypeDef(name, TypeBoundsTree(EmptyTree, EmptyTree)).withFlags(Param)
1494+
TypeDef(name, WildcardTypeBoundsTree()).withFlags(Param)
14951495

14961496
/** Replaces kind-projector's `*` in a list of types arguments with synthetic names,
14971497
* returning the new argument list and the synthetic type definitions.

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -355,31 +355,59 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
355355
symbolsToSynthesize.flatMap(syntheticDefIfMissing)
356356
}
357357

358+
private def hasWriteReplace(clazz: ClassSymbol)(using Context): Boolean =
359+
clazz.membersNamed(nme.writeReplace)
360+
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
361+
.exists
362+
363+
private def writeReplaceDef(clazz: ClassSymbol)(using Context): TermSymbol =
364+
newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
365+
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
366+
358367
/** If this is a serializable static object `Foo`, add the method:
359368
*
360369
* private def writeReplace(): AnyRef =
361370
* new scala.runtime.ModuleSerializationProxy(classOf[Foo.type])
362371
*
363372
* unless an implementation already exists, otherwise do nothing.
364373
*/
365-
def serializableObjectMethod(clazz: ClassSymbol)(using Context): List[Tree] = {
366-
def hasWriteReplace: Boolean =
367-
clazz.membersNamed(nme.writeReplace)
368-
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
369-
.exists
370-
if (clazz.is(Module) && clazz.isStatic && clazz.isSerializable && !hasWriteReplace) {
371-
val writeReplace = newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
372-
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
374+
def serializableObjectMethod(clazz: ClassSymbol)(using Context): List[Tree] =
375+
if clazz.is(Module) && clazz.isStatic && clazz.isSerializable && !hasWriteReplace(clazz) then
373376
List(
374-
DefDef(writeReplace,
377+
DefDef(writeReplaceDef(clazz),
375378
_ => New(defn.ModuleSerializationProxyClass.typeRef,
376379
defn.ModuleSerializationProxyConstructor,
377380
List(Literal(Constant(clazz.sourceModule.termRef)))))
378381
.withSpan(ctx.owner.span.focus))
379-
}
380382
else
381383
Nil
382-
}
384+
385+
/** does this class extend `scala.runtime.EnumValue` and derive from an enum definition? */
386+
extension (sym: ClassSymbol) private def isEnumValueImplementation(using Context): Boolean =
387+
derivesFrom(defn.EnumValueClass) && classParents.exists(_.typeSymbol.flags.is(Enum))
388+
389+
/** If this the class backing a serializable singleton enum value with base class `MyEnum`,
390+
* and not deriving from `java.lang.Enum` add the method:
391+
*
392+
* private def writeReplace(): AnyRef =
393+
* new scala.runtime.EnumValueSerializationProxy(classOf[MyEnum], this.ordinal)
394+
*
395+
* unless an implementation already exists, otherwise do nothing.
396+
*/
397+
def serializableEnumValueMethod(clazz: ClassSymbol)(using Context): List[Tree] =
398+
if clazz.isEnumValueImplementation
399+
&& !clazz.derivesFrom(defn.JavaEnumClass)
400+
&& clazz.isSerializable
401+
&& !hasWriteReplace(clazz)
402+
then
403+
List(
404+
DefDef(writeReplaceDef(clazz),
405+
_ => New(defn.EnumValueSerializationProxyClass.typeRef,
406+
defn.EnumValueSerializationProxyConstructor,
407+
List(Literal(Constant(clazz.classParents.head)), This(clazz).select(nme.ordinal).ensureApplied)))
408+
.withSpan(ctx.owner.span.focus))
409+
else
410+
Nil
383411

384412
/** The class
385413
*
@@ -528,6 +556,6 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
528556
def addSyntheticMembers(impl: Template)(using Context): Template = {
529557
val clazz = ctx.owner.asClass
530558
addMirrorSupport(
531-
cpy.Template(impl)(body = serializableObjectMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body))
559+
cpy.Template(impl)(body = serializableObjectMethod(clazz) ::: serializableEnumValueMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body))
532560
}
533561
}

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,7 +1688,7 @@ class Typer extends Namer
16881688
}
16891689
if (desugaredArg.isType)
16901690
arg match {
1691-
case TypeBoundsTree(EmptyTree, EmptyTree, _)
1691+
case untpd.WildcardTypeBoundsTree()
16921692
if tparam.paramInfo.isLambdaSub &&
16931693
tpt1.tpe.typeParamSymbols.nonEmpty &&
16941694
!ctx.mode.is(Mode.Pattern) =>
@@ -1707,7 +1707,7 @@ class Typer extends Namer
17071707
args.zipWithConserve(tparams)(typedArg(_, _)).asInstanceOf[List[Tree]]
17081708
}
17091709
val paramBounds = tparams.lazyZip(args).map {
1710-
case (tparam, TypeBoundsTree(EmptyTree, EmptyTree, _)) =>
1710+
case (tparam, untpd.WildcardTypeBoundsTree()) =>
17111711
// if type argument is a wildcard, suppress kind checking since
17121712
// there is no real argument.
17131713
NoType

tests/neg/i6601.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ object GADTs2 {
77
case Lit[G](n: Int) extends Expr[G, Int]
88
// case S[A, G](x:
99
}
10-
}
10+
}

0 commit comments

Comments
 (0)