Skip to content

Commit c00bb5d

Browse files
committed
fix #9179: generate dollarvalues for parameterized enums, use serialization proxy
1 parent 0668611 commit c00bb5d

File tree

4 files changed

+57
-20
lines changed

4 files changed

+57
-20
lines changed

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ object DesugarEnums {
7878
Select(Ident(nme.DOLLAR_VALUES), name.toTermName)
7979

8080
private def registerCall(using Context): List[Tree] =
81-
if (enumClass.typeParams.nonEmpty) Nil
82-
else Apply(valuesDot("register"), This(EmptyTypeIdent) :: Nil) :: Nil
81+
Apply(valuesDot("register"), This(EmptyTypeIdent) :: Nil) :: Nil
8382

8483
/** The following lists of definitions for an enum type E:
8584
*
@@ -93,12 +92,13 @@ object DesugarEnums {
9392
* }
9493
*/
9594
private def enumScaffolding(using Context): List[Tree] = {
95+
import dotty.tools.dotc.transform.SymUtils._
9696
val valuesDef =
97-
DefDef(nme.values, Nil, Nil, TypeTree(defn.ArrayOf(enumClass.typeRef)), Select(valuesDot(nme.values), nme.toArray))
97+
DefDef(nme.values, Nil, Nil, TypeTree(defn.ArrayOf(enumClass.typeRef_*)), Select(valuesDot(nme.values), nme.toArray))
9898
.withFlags(Synthetic)
9999
val privateValuesDef =
100100
ValDef(nme.DOLLAR_VALUES, TypeTree(),
101-
New(TypeTree(defn.EnumValuesClass.typeRef.appliedTo(enumClass.typeRef :: Nil)), ListOfNil))
101+
New(TypeTree(defn.EnumValuesClass.typeRef.appliedTo(enumClass.typeRef_* :: Nil)), ListOfNil))
102102
.withFlags(Private | Synthetic)
103103

104104
val valuesOfExnMessage = Apply(
@@ -138,7 +138,7 @@ object DesugarEnums {
138138
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
139139
derived = Nil,
140140
self = EmptyValDef,
141-
body = List(ordinalDef, toStringDef) ++ registerCall
141+
body = ordinalDef :: toStringDef :: registerCall
142142
).withAttachment(ExtendsSingletonMirror, ()))
143143
DefDef(nme.DOLLAR_NEW, Nil,
144144
List(List(param(nme.ordinalDollar_, defn.IntType), param(nme.nameDollar, defn.StringType))),
@@ -254,7 +254,7 @@ object DesugarEnums {
254254
val minKind = if (kind < seenKind) kind else seenKind
255255
ctx.tree.pushAttachment(EnumCaseCount, (count + 1, minKind))
256256
val scaffolding =
257-
if (enumClass.typeParams.nonEmpty || kind >= seenKind) Nil
257+
if (kind >= seenKind) Nil
258258
else if (kind == CaseKind.Object) enumScaffolding
259259
else if (seenKind == CaseKind.Object) enumValueCreator :: Nil
260260
else enumScaffolding :+ enumValueCreator
@@ -288,8 +288,8 @@ object DesugarEnums {
288288
val toStringDef = toStringMethLit(name.toString)
289289
val impl1 = cpy.Template(impl)(
290290
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
291-
body = List(ordinalDef, toStringDef) ++ registerCall)
292-
.withAttachment(ExtendsSingletonMirror, ())
291+
body = ordinalDef :: toStringDef :: registerCall
292+
).withAttachment(ExtendsSingletonMirror, ())
293293
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
294294
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)
295295
}

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,12 @@ class Definitions {
651651
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)
652652

653653
@tu lazy val EnumValuesClass: ClassSymbol = requiredClass("scala.runtime.EnumValues")
654+
@tu lazy val EnumValueClass: ClassSymbol = requiredClass("scala.runtime.EnumValue")
655+
656+
@tu lazy val EnumValueSerializationProxyClass: ClassSymbol = requiredClass("scala.runtime.EnumValueSerializationProxy")
657+
@tu lazy val EnumValueSerializationProxyConstructor: TermSymbol =
658+
EnumValueSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(TypeBounds.empty)))
659+
654660
@tu lazy val ProductClass: ClassSymbol = requiredClass("scala.Product")
655661
@tu lazy val Product_canEqual : Symbol = ProductClass.requiredMethod(nme.canEqual_)
656662
@tu lazy val Product_productArity : Symbol = ProductClass.requiredMethod(nme.productArity)

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ object SymUtils {
201201
def rawTypeRef(using Context) =
202202
self.typeRef.appliedTo(self.typeParams.map(_ => TypeBounds.emptyPolyKind))
203203

204+
def typeRef_*(using Context) =
205+
self.typeRef.appliedTo(self.typeParams.map(_ => TypeBounds.empty))
206+
204207
/** Is symbol a quote operation? */
205208
def isQuote(using Context): Boolean =
206209
self == defn.InternalQuoted_exprQuote || self == defn.QuotedTypeModule_apply

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -355,31 +355,59 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
355355
symbolsToSynthesize.flatMap(syntheticDefIfMissing)
356356
}
357357

358+
private def hasWriteReplace(clazz: ClassSymbol)(using Context): Boolean =
359+
clazz.membersNamed(nme.writeReplace)
360+
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
361+
.exists
362+
363+
private def writeReplaceDef(clazz: ClassSymbol)(using Context): TermSymbol =
364+
newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
365+
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
366+
358367
/** If this is a serializable static object `Foo`, add the method:
359368
*
360369
* private def writeReplace(): AnyRef =
361370
* new scala.runtime.ModuleSerializationProxy(classOf[Foo.type])
362371
*
363372
* unless an implementation already exists, otherwise do nothing.
364373
*/
365-
def serializableObjectMethod(clazz: ClassSymbol)(using Context): List[Tree] = {
366-
def hasWriteReplace: Boolean =
367-
clazz.membersNamed(nme.writeReplace)
368-
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
369-
.exists
370-
if (clazz.is(Module) && clazz.isStatic && clazz.isSerializable && !hasWriteReplace) {
371-
val writeReplace = newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
372-
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
374+
def serializableObjectMethod(clazz: ClassSymbol)(using Context): List[Tree] =
375+
if clazz.is(Module) && clazz.isStatic && clazz.isSerializable && !hasWriteReplace(clazz) then
373376
List(
374-
DefDef(writeReplace,
377+
DefDef(writeReplaceDef(clazz),
375378
_ => New(defn.ModuleSerializationProxyClass.typeRef,
376379
defn.ModuleSerializationProxyConstructor,
377380
List(Literal(Constant(clazz.sourceModule.termRef)))))
378381
.withSpan(ctx.owner.span.focus))
379-
}
380382
else
381383
Nil
382-
}
384+
385+
/** does this class extend `scala.runtime.EnumValue` and derive from an enum definition? */
386+
extension (sym: ClassSymbol) def isEnumValueImplementation(using Context): Boolean =
387+
derivesFrom(defn.EnumValueClass) && classParents.exists(_.typeSymbol.flags.is(Enum))
388+
389+
/** If this the class backing a serializable singleton enum value with base class `MyEnum`,
390+
* and not deriving from `java.lang.Enum` add the method:
391+
*
392+
* private def writeReplace(): AnyRef =
393+
* new scala.runtime.EnumValueSerializationProxy(classOf[MyEnum], this.ordinal)
394+
*
395+
* unless an implementation already exists, otherwise do nothing.
396+
*/
397+
def serializableEnumValueMethod(clazz: ClassSymbol)(using Context): List[Tree] =
398+
if clazz.isEnumValueImplementation
399+
&& !clazz.derivesFrom(defn.JavaEnumClass)
400+
&& clazz.isSerializable
401+
&& !hasWriteReplace(clazz)
402+
then
403+
List(
404+
DefDef(writeReplaceDef(clazz),
405+
_ => New(defn.EnumValueSerializationProxyClass.typeRef,
406+
defn.EnumValueSerializationProxyConstructor,
407+
List(Literal(Constant(clazz.classParents.head)), This(clazz).select(nme.ordinal).ensureApplied)))
408+
.withSpan(ctx.owner.span.focus))
409+
else
410+
Nil
383411

384412
/** The class
385413
*
@@ -528,6 +556,6 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
528556
def addSyntheticMembers(impl: Template)(using Context): Template = {
529557
val clazz = ctx.owner.asClass
530558
addMirrorSupport(
531-
cpy.Template(impl)(body = serializableObjectMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body))
559+
cpy.Template(impl)(body = serializableObjectMethod(clazz) ::: serializableEnumValueMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body))
532560
}
533561
}

0 commit comments

Comments
 (0)