Skip to content

Commit f4dfc7d

Browse files
authored
Merge pull request #9532 from dotty-staging/topic/enum-serialization
fix #9179: ensure enum values are singleton with serialisation
2 parents d5efc05 + 0872aaf commit f4dfc7d

File tree

16 files changed

+270
-59
lines changed

16 files changed

+270
-59
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ object desugar {
505505
val targ = refOfDef(tparam)
506506
def fullyApplied(tparam: Tree): Tree = tparam match {
507507
case TypeDef(_, LambdaTypeTree(tparams, body)) =>
508-
AppliedTypeTree(targ, tparams.map(_ => TypeBoundsTree(EmptyTree, EmptyTree)))
508+
AppliedTypeTree(targ, tparams.map(_ => WildcardTypeBoundsTree()))
509509
case TypeDef(_, rhs: DerivedTypeTree) =>
510510
fullyApplied(rhs.watched)
511511
case _ =>

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

+10-10
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,8 @@ object DesugarEnums {
7777
private def valuesDot(name: PreName)(implicit src: SourceFile) =
7878
Select(Ident(nme.DOLLAR_VALUES), name.toTermName)
7979

80-
private def registerCall(using Context): List[Tree] =
81-
if (enumClass.typeParams.nonEmpty) Nil
82-
else Apply(valuesDot("register"), This(EmptyTypeIdent) :: Nil) :: Nil
80+
private def registerCall(using Context): Tree =
81+
Apply(valuesDot("register"), This(EmptyTypeIdent) :: Nil)
8382

8483
/** The following lists of definitions for an enum type E:
8584
*
@@ -93,12 +92,13 @@ object DesugarEnums {
9392
* }
9493
*/
9594
private def enumScaffolding(using Context): List[Tree] = {
95+
val rawEnumClassRef = rawRef(enumClass.typeRef)
96+
extension (tpe: NamedType) def ofRawEnum = AppliedTypeTree(ref(tpe), rawEnumClassRef)
9697
val valuesDef =
97-
DefDef(nme.values, Nil, Nil, TypeTree(defn.ArrayOf(enumClass.typeRef)), Select(valuesDot(nme.values), nme.toArray))
98+
DefDef(nme.values, Nil, Nil, defn.ArrayType.ofRawEnum, Select(valuesDot(nme.values), nme.toArray))
9899
.withFlags(Synthetic)
99100
val privateValuesDef =
100-
ValDef(nme.DOLLAR_VALUES, TypeTree(),
101-
New(TypeTree(defn.EnumValuesClass.typeRef.appliedTo(enumClass.typeRef :: Nil)), ListOfNil))
101+
ValDef(nme.DOLLAR_VALUES, TypeTree(), New(defn.EnumValuesClass.typeRef.ofRawEnum, ListOfNil))
102102
.withFlags(Private | Synthetic)
103103

104104
val valuesOfExnMessage = Apply(
@@ -138,7 +138,7 @@ object DesugarEnums {
138138
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
139139
derived = Nil,
140140
self = EmptyValDef,
141-
body = List(ordinalDef, toStringDef) ++ registerCall
141+
body = ordinalDef :: toStringDef :: registerCall :: Nil
142142
).withAttachment(ExtendsSingletonMirror, ()))
143143
DefDef(nme.DOLLAR_NEW, Nil,
144144
List(List(param(nme.ordinalDollar_, defn.IntType), param(nme.nameDollar, defn.StringType))),
@@ -254,7 +254,7 @@ object DesugarEnums {
254254
val minKind = if (kind < seenKind) kind else seenKind
255255
ctx.tree.pushAttachment(EnumCaseCount, (count + 1, minKind))
256256
val scaffolding =
257-
if (enumClass.typeParams.nonEmpty || kind >= seenKind) Nil
257+
if (kind >= seenKind) Nil
258258
else if (kind == CaseKind.Object) enumScaffolding
259259
else if (seenKind == CaseKind.Object) enumValueCreator :: Nil
260260
else enumScaffolding :+ enumValueCreator
@@ -288,8 +288,8 @@ object DesugarEnums {
288288
val toStringDef = toStringMethLit(name.toString)
289289
val impl1 = cpy.Template(impl)(
290290
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
291-
body = List(ordinalDef, toStringDef) ++ registerCall)
292-
.withAttachment(ExtendsSingletonMirror, ())
291+
body = ordinalDef :: toStringDef :: registerCall :: Nil
292+
).withAttachment(ExtendsSingletonMirror, ())
293293
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
294294
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)
295295
}

compiler/src/dotty/tools/dotc/ast/untpd.scala

+11
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,13 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
150150
override def isEmpty: Boolean = true
151151
}
152152

153+
def WildcardTypeBoundsTree()(using src: SourceFile): TypeBoundsTree = TypeBoundsTree(EmptyTree, EmptyTree, EmptyTree)
154+
object WildcardTypeBoundsTree:
155+
def unapply(tree: untpd.Tree): Boolean = tree match
156+
case TypeBoundsTree(EmptyTree, EmptyTree, _) => true
157+
case _ => false
158+
159+
153160
/** A block generated by the XML parser, only treated specially by
154161
* `Positioned#checkPos` */
155162
class XMLBlock(stats: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends Block(stats, expr)
@@ -453,6 +460,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
453460
def ref(tp: NamedType)(using Context): Tree =
454461
TypedSplice(tpd.ref(tp))
455462

463+
def rawRef(tp: NamedType)(using Context): Tree =
464+
if tp.typeParams.isEmpty then ref(tp)
465+
else AppliedTypeTree(ref(tp), tp.typeParams.map(_ => WildcardTypeBoundsTree()))
466+
456467
def rootDot(name: Name)(implicit src: SourceFile): Select = Select(Ident(nme.ROOTPKG), name)
457468
def scalaDot(name: Name)(implicit src: SourceFile): Select = Select(rootDot(nme.scala), name)
458469
def scalaAnnotationDot(name: Name)(using SourceFile): Select = Select(scalaDot(nme.annotation), name)

compiler/src/dotty/tools/dotc/core/Definitions.scala

+5
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,11 @@ class Definitions {
651651
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)
652652

653653
@tu lazy val EnumValuesClass: ClassSymbol = requiredClass("scala.runtime.EnumValues")
654+
655+
@tu lazy val EnumValueSerializationProxyClass: ClassSymbol = requiredClass("scala.runtime.EnumValueSerializationProxy")
656+
@tu lazy val EnumValueSerializationProxyConstructor: TermSymbol =
657+
EnumValueSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(TypeBounds.empty), IntType))
658+
654659
@tu lazy val ProductClass: ClassSymbol = requiredClass("scala.Product")
655660
@tu lazy val Product_canEqual : Symbol = ProductClass.requiredMethod(nme.canEqual_)
656661
@tu lazy val Product_productArity : Symbol = ProductClass.requiredMethod(nme.productArity)

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1491,7 +1491,7 @@ object Parsers {
14911491
}
14921492

14931493
private def makeKindProjectorTypeDef(name: TypeName): TypeDef =
1494-
TypeDef(name, TypeBoundsTree(EmptyTree, EmptyTree)).withFlags(Param)
1494+
TypeDef(name, WildcardTypeBoundsTree()).withFlags(Param)
14951495

14961496
/** Replaces kind-projector's `*` in a list of types arguments with synthetic names,
14971497
* returning the new argument list and the synthetic type definitions.

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,6 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
294294
cpy.Inlined(tree)(callTrace, transformSub(bindings), transform(expansion)(using inlineContext(call)))
295295
case templ: Template =>
296296
withNoCheckNews(templ.parents.flatMap(newPart)) {
297-
Checking.checkEnumParentOK(templ.symbol.owner)
298297
forwardParamAccessors(templ)
299298
synthMbr.addSyntheticMembers(
300299
superAcc.wrapTemplate(templ)(

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

+40-12
Original file line numberDiff line numberDiff line change
@@ -355,31 +355,59 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
355355
symbolsToSynthesize.flatMap(syntheticDefIfMissing)
356356
}
357357

358+
private def hasWriteReplace(clazz: ClassSymbol)(using Context): Boolean =
359+
clazz.membersNamed(nme.writeReplace)
360+
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
361+
.exists
362+
363+
private def writeReplaceDef(clazz: ClassSymbol)(using Context): TermSymbol =
364+
newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
365+
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
366+
358367
/** If this is a serializable static object `Foo`, add the method:
359368
*
360369
* private def writeReplace(): AnyRef =
361370
* new scala.runtime.ModuleSerializationProxy(classOf[Foo.type])
362371
*
363372
* unless an implementation already exists, otherwise do nothing.
364373
*/
365-
def serializableObjectMethod(clazz: ClassSymbol)(using Context): List[Tree] = {
366-
def hasWriteReplace: Boolean =
367-
clazz.membersNamed(nme.writeReplace)
368-
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
369-
.exists
370-
if (clazz.is(Module) && clazz.isStatic && clazz.isSerializable && !hasWriteReplace) {
371-
val writeReplace = newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
372-
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
374+
def serializableObjectMethod(clazz: ClassSymbol)(using Context): List[Tree] =
375+
if clazz.is(Module) && clazz.isStatic && clazz.isSerializable && !hasWriteReplace(clazz) then
373376
List(
374-
DefDef(writeReplace,
377+
DefDef(writeReplaceDef(clazz),
375378
_ => New(defn.ModuleSerializationProxyClass.typeRef,
376379
defn.ModuleSerializationProxyConstructor,
377380
List(Literal(Constant(clazz.sourceModule.termRef)))))
378381
.withSpan(ctx.owner.span.focus))
379-
}
380382
else
381383
Nil
382-
}
384+
385+
/** Is this an anonymous class deriving from an enum definition? */
386+
extension (cls: ClassSymbol) private def isEnumValueImplementation(using Context): Boolean =
387+
isAnonymousClass && classParents.head.typeSymbol.is(Enum) // asserted in Typer
388+
389+
/** If this is the class backing a serializable singleton enum value with base class `MyEnum`,
390+
* and not deriving from `java.lang.Enum` add the method:
391+
*
392+
* private def writeReplace(): AnyRef =
393+
* new scala.runtime.EnumValueSerializationProxy(classOf[MyEnum], this.ordinal)
394+
*
395+
* unless an implementation already exists, otherwise do nothing.
396+
*/
397+
def serializableEnumValueMethod(clazz: ClassSymbol)(using Context): List[Tree] =
398+
if clazz.isEnumValueImplementation
399+
&& !clazz.derivesFrom(defn.JavaEnumClass)
400+
&& clazz.isSerializable
401+
&& !hasWriteReplace(clazz)
402+
then
403+
List(
404+
DefDef(writeReplaceDef(clazz),
405+
_ => New(defn.EnumValueSerializationProxyClass.typeRef,
406+
defn.EnumValueSerializationProxyConstructor,
407+
List(Literal(Constant(clazz.classParents.head)), This(clazz).select(nme.ordinal).ensureApplied)))
408+
.withSpan(ctx.owner.span.focus))
409+
else
410+
Nil
383411

384412
/** The class
385413
*
@@ -528,6 +556,6 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
528556
def addSyntheticMembers(impl: Template)(using Context): Template = {
529557
val clazz = ctx.owner.asClass
530558
addMirrorSupport(
531-
cpy.Template(impl)(body = serializableObjectMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body))
559+
cpy.Template(impl)(body = serializableObjectMethod(clazz) ::: serializableEnumValueMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body))
532560
}
533561
}

compiler/src/dotty/tools/dotc/typer/Checking.scala

+33-14
Original file line numberDiff line numberDiff line change
@@ -630,17 +630,6 @@ object Checking {
630630
}
631631
}
632632

633-
/** Check that an enum case extends its enum class */
634-
def checkEnumParentOK(cls: Symbol)(using Context): Unit =
635-
val enumCase =
636-
if cls.isAllOf(EnumCase) then cls
637-
else if cls.isAnonymousClass && cls.owner.isAllOf(EnumCase) then cls.owner
638-
else NoSymbol
639-
if enumCase.exists then
640-
val enumCls = enumCase.owner.linkedClass
641-
if !cls.info.parents.exists(_.typeSymbol == enumCls) then
642-
report.error(i"enum case does not extend its enum $enumCls", enumCase.srcPos)
643-
644633
/** Check the inline override methods only use inline parameters if they override an inline parameter. */
645634
def checkInlineOverrideParameters(sym: Symbol)(using Context): Unit =
646635
lazy val params = sym.paramSymss.flatten
@@ -1095,9 +1084,10 @@ trait Checking {
10951084
*/
10961085
def checkEnum(cdef: untpd.TypeDef, cls: Symbol, firstParent: Symbol)(using Context): Unit = {
10971086
def isEnumAnonCls =
1098-
cls.isAnonymousClass &&
1099-
cls.owner.isTerm &&
1100-
(cls.owner.flagsUNSAFE.is(Case) || cls.owner.name == nme.DOLLAR_NEW)
1087+
cls.isAnonymousClass
1088+
&& cls.owner.isTerm
1089+
&& (cls.owner.flagsUNSAFE.isAllOf(EnumCase)
1090+
|| ((cls.owner.name eq nme.DOLLAR_NEW) && cls.owner.flagsUNSAFE.isAllOf(Private | Synthetic)))
11011091
if (!isEnumAnonCls)
11021092
if (cdef.mods.isEnumCase) {
11031093
if (cls.derivesFrom(defn.JavaEnumClass))
@@ -1112,6 +1102,34 @@ trait Checking {
11121102
report.error(ClassCannotExtendEnum(cls, firstParent), cdef.srcPos)
11131103
}
11141104

1105+
/** Check that the firstParent for an enum case derives from the declaring enum class, if not, adds it as a parent
1106+
* after emitting an error.
1107+
*
1108+
* This check will have no effect on simple enum cases as their parents are inferred by the compiler.
1109+
*/
1110+
def checkEnumParent(cls: Symbol, firstParent: Symbol)(using Context): Unit =
1111+
1112+
extension (sym: Symbol) def typeRefApplied(using Context): Type =
1113+
typeRef.appliedTo(typeParams.map(_.info.loBound))
1114+
1115+
def ensureParentDerivesFrom(enumCase: Symbol)(using Context) =
1116+
val enumCls = enumCase.owner.linkedClass
1117+
if !firstParent.derivesFrom(enumCls) then
1118+
report.error(i"enum case does not extend its enum $enumCls", enumCase.srcPos)
1119+
cls.info match
1120+
case info: ClassInfo =>
1121+
cls.info = info.derivedClassInfo(classParents = enumCls.typeRefApplied :: info.classParents)
1122+
case _ =>
1123+
1124+
val enumCase =
1125+
if cls.flagsUNSAFE.isAllOf(EnumCase) then cls
1126+
else if cls.isAnonymousClass && cls.owner.flagsUNSAFE.isAllOf(EnumCase) then cls.owner
1127+
else NoSymbol
1128+
if enumCase.exists then
1129+
ensureParentDerivesFrom(enumCase)
1130+
1131+
end checkEnumParent
1132+
11151133
/** Check that all references coming from enum cases in an enum companion object
11161134
* are legal.
11171135
* @param cdef the enum companion object class
@@ -1205,6 +1223,7 @@ trait Checking {
12051223

12061224
trait ReChecking extends Checking {
12071225
import tpd._
1226+
override def checkEnumParent(cls: Symbol, firstParent: Symbol)(using Context): Unit = ()
12081227
override def checkEnum(cdef: untpd.TypeDef, cls: Symbol, firstParent: Symbol)(using Context): Unit = ()
12091228
override def checkRefsLegal(tree: tpd.Tree, badOwner: Symbol, allowed: (Name, Symbol) => Boolean, where: String)(using Context): Unit = ()
12101229
override def checkFullyAppliedType(tree: Tree)(using Context): Unit = ()

compiler/src/dotty/tools/dotc/typer/Typer.scala

+7-7
Original file line numberDiff line numberDiff line change
@@ -1720,7 +1720,7 @@ class Typer extends Namer
17201720
}
17211721
if (desugaredArg.isType)
17221722
arg match {
1723-
case TypeBoundsTree(EmptyTree, EmptyTree, _)
1723+
case untpd.WildcardTypeBoundsTree()
17241724
if tparam.paramInfo.isLambdaSub &&
17251725
tpt1.tpe.typeParamSymbols.nonEmpty &&
17261726
!ctx.mode.is(Mode.Pattern) =>
@@ -1739,7 +1739,7 @@ class Typer extends Namer
17391739
args.zipWithConserve(tparams)(typedArg(_, _)).asInstanceOf[List[Tree]]
17401740
}
17411741
val paramBounds = tparams.lazyZip(args).map {
1742-
case (tparam, TypeBoundsTree(EmptyTree, EmptyTree, _)) =>
1742+
case (tparam, untpd.WildcardTypeBoundsTree()) =>
17431743
// if type argument is a wildcard, suppress kind checking since
17441744
// there is no real argument.
17451745
NoType
@@ -2102,25 +2102,25 @@ class Typer extends Namer
21022102
val constr1 = typed(constr).asInstanceOf[DefDef]
21032103
val parentsWithClass = ensureFirstTreeIsClass(parents.mapconserve(typedParent).filterConserve(!_.isEmpty), cdef.nameSpan)
21042104
val parents1 = ensureConstrCall(cls, parentsWithClass)(using superCtx)
2105+
val firstParent = parents1.head.tpe.dealias.typeSymbol
2106+
2107+
checkEnumParent(cls, firstParent)
21052108

21062109
val self1 = typed(self)(using ctx.outer).asInstanceOf[ValDef] // outer context where class members are not visible
21072110
if (self1.tpt.tpe.isError || classExistsOnSelf(cls.unforcedDecls, self1))
21082111
// fail fast to avoid typing the body with an error type
21092112
cdef.withType(UnspecifiedErrorType)
21102113
else {
21112114
val dummy = localDummy(cls, impl)
2112-
val body1 = addAccessorDefs(cls,
2113-
typedStats(impl.body, dummy)(using ctx.inClassContext(self1.symbol))._1)
2115+
val body1 = addAccessorDefs(cls, typedStats(impl.body, dummy)(using ctx.inClassContext(self1.symbol))._1)
21142116

21152117
checkNoDoubleDeclaration(cls)
21162118
val impl1 = cpy.Template(impl)(constr1, parents1, Nil, self1, body1)
21172119
.withType(dummy.termRef)
21182120
if (!cls.isOneOf(AbstractOrTrait) && !ctx.isAfterTyper)
21192121
checkRealizableBounds(cls, cdef.sourcePos.withSpan(cdef.nameSpan))
2120-
if (cls.derivesFrom(defn.EnumClass)) {
2121-
val firstParent = parents1.head.tpe.dealias.typeSymbol
2122+
if cls.derivesFrom(defn.EnumClass) then
21222123
checkEnum(cdef, cls, firstParent)
2123-
}
21242124
val cdef1 = assignType(cpy.TypeDef(cdef)(name, impl1), cls)
21252125

21262126
val reportDynamicInheritance =

docs/docs/reference/enums/desugarEnums.md

+9-10
Original file line numberDiff line numberDiff line change
@@ -156,31 +156,30 @@ map into `case class`es or `val`s.
156156
a type parameter of the case, i.e. the parameter name is defined in `<params>`.
157157

158158

159-
### Translation of Enumerations
159+
### Translation of Enums with Singleton Cases
160160

161-
Non-generic enums `E` that define one or more singleton cases
162-
are called _enumerations_. Companion objects of enumerations define
163-
the following additional synthetic members.
161+
An enum `E` (possibly generic) that defines one or more singleton cases
162+
will define the following additional synthetic members in its companion object (where `E'` denotes `E` with
163+
any type parameters replaced by wildcards):
164164

165-
- A method `valueOf(name: String): E`. It returns the singleton case value whose
165+
- A method `valueOf(name: String): E'`. It returns the singleton case value whose
166166
`toString` representation is `name`.
167-
- A method `values` which returns an `Array[E]` of all singleton case
168-
values in `E`, in the order of their definitions.
167+
- A method `values` which returns an `Array[E']` of all singleton case
168+
values defined by `E`, in the order of their definitions.
169169

170-
Companion objects of enumerations that contain at least one simple case define in addition:
170+
If `E` contains at least one simple case, its companion object will define in addition:
171171

172172
- A private method `$new` which defines a new simple case value with given
173173
ordinal number and name. This method can be thought as being defined as
174174
follows.
175175
```scala
176-
private def $new(_$ordinal: Int, $name: String) = new E {
176+
private def $new(_$ordinal: Int, $name: String) = new E with runtime.EnumValue {
177177
def $ordinal = $_ordinal
178178
override def toString = $name
179179
$values.register(this) // register enum value so that `valueOf` and `values` can return it.
180180
}
181181
```
182182

183-
The anonymous class also implements the abstract `Product` methods that it inherits from `Enum`.
184183
The `$ordinal` method above is used to generate the `ordinal` method if the enum does not extend a `java.lang.Enum` (as Scala enums do not extend `java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as `java.lang.Enum` defines it.
185184

186185
### Scopes for Enum Cases

0 commit comments

Comments
 (0)