Skip to content

fix #9179: ensure enum values are singleton with serialisation #9532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ object desugar {
val targ = refOfDef(tparam)
def fullyApplied(tparam: Tree): Tree = tparam match {
case TypeDef(_, LambdaTypeTree(tparams, body)) =>
AppliedTypeTree(targ, tparams.map(_ => TypeBoundsTree(EmptyTree, EmptyTree)))
AppliedTypeTree(targ, tparams.map(_ => WildcardTypeBoundsTree()))
case TypeDef(_, rhs: DerivedTypeTree) =>
fullyApplied(rhs.watched)
case _ =>
Expand Down
20 changes: 10 additions & 10 deletions compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ object DesugarEnums {
private def valuesDot(name: PreName)(implicit src: SourceFile) =
Select(Ident(nme.DOLLAR_VALUES), name.toTermName)

private def registerCall(using Context): List[Tree] =
if (enumClass.typeParams.nonEmpty) Nil
else Apply(valuesDot("register"), This(EmptyTypeIdent) :: Nil) :: Nil
private def registerCall(using Context): Tree =
Apply(valuesDot("register"), This(EmptyTypeIdent) :: Nil)

/** The following lists of definitions for an enum type E:
*
Expand All @@ -93,12 +92,13 @@ object DesugarEnums {
* }
*/
private def enumScaffolding(using Context): List[Tree] = {
val rawEnumClassRef = rawRef(enumClass.typeRef)
extension (tpe: NamedType) def ofRawEnum = AppliedTypeTree(ref(tpe), rawEnumClassRef)
val valuesDef =
DefDef(nme.values, Nil, Nil, TypeTree(defn.ArrayOf(enumClass.typeRef)), Select(valuesDot(nme.values), nme.toArray))
DefDef(nme.values, Nil, Nil, defn.ArrayType.ofRawEnum, Select(valuesDot(nme.values), nme.toArray))
.withFlags(Synthetic)
val privateValuesDef =
ValDef(nme.DOLLAR_VALUES, TypeTree(),
New(TypeTree(defn.EnumValuesClass.typeRef.appliedTo(enumClass.typeRef :: Nil)), ListOfNil))
ValDef(nme.DOLLAR_VALUES, TypeTree(), New(defn.EnumValuesClass.typeRef.ofRawEnum, ListOfNil))
.withFlags(Private | Synthetic)

val valuesOfExnMessage = Apply(
Expand Down Expand Up @@ -138,7 +138,7 @@ object DesugarEnums {
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
derived = Nil,
self = EmptyValDef,
body = List(ordinalDef, toStringDef) ++ registerCall
body = ordinalDef :: toStringDef :: registerCall :: Nil
).withAttachment(ExtendsSingletonMirror, ()))
DefDef(nme.DOLLAR_NEW, Nil,
List(List(param(nme.ordinalDollar_, defn.IntType), param(nme.nameDollar, defn.StringType))),
Expand Down Expand Up @@ -254,7 +254,7 @@ object DesugarEnums {
val minKind = if (kind < seenKind) kind else seenKind
ctx.tree.pushAttachment(EnumCaseCount, (count + 1, minKind))
val scaffolding =
if (enumClass.typeParams.nonEmpty || kind >= seenKind) Nil
if (kind >= seenKind) Nil
else if (kind == CaseKind.Object) enumScaffolding
else if (seenKind == CaseKind.Object) enumValueCreator :: Nil
else enumScaffolding :+ enumValueCreator
Expand Down Expand Up @@ -288,8 +288,8 @@ object DesugarEnums {
val toStringDef = toStringMethLit(name.toString)
val impl1 = cpy.Template(impl)(
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
body = List(ordinalDef, toStringDef) ++ registerCall)
.withAttachment(ExtendsSingletonMirror, ())
body = ordinalDef :: toStringDef :: registerCall :: Nil
).withAttachment(ExtendsSingletonMirror, ())
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)
}
Expand Down
11 changes: 11 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
override def isEmpty: Boolean = true
}

def WildcardTypeBoundsTree()(using src: SourceFile): TypeBoundsTree = TypeBoundsTree(EmptyTree, EmptyTree, EmptyTree)
object WildcardTypeBoundsTree:
def unapply(tree: untpd.Tree): Boolean = tree match
case TypeBoundsTree(EmptyTree, EmptyTree, _) => true
case _ => false


/** A block generated by the XML parser, only treated specially by
* `Positioned#checkPos` */
class XMLBlock(stats: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends Block(stats, expr)
Expand Down Expand Up @@ -453,6 +460,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def ref(tp: NamedType)(using Context): Tree =
TypedSplice(tpd.ref(tp))

def rawRef(tp: NamedType)(using Context): Tree =
if tp.typeParams.isEmpty then ref(tp)
else AppliedTypeTree(ref(tp), tp.typeParams.map(_ => WildcardTypeBoundsTree()))

def rootDot(name: Name)(implicit src: SourceFile): Select = Select(Ident(nme.ROOTPKG), name)
def scalaDot(name: Name)(implicit src: SourceFile): Select = Select(rootDot(nme.scala), name)
def scalaAnnotationDot(name: Name)(using SourceFile): Select = Select(scalaDot(nme.annotation), name)
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,11 @@ class Definitions {
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)

@tu lazy val EnumValuesClass: ClassSymbol = requiredClass("scala.runtime.EnumValues")

@tu lazy val EnumValueSerializationProxyClass: ClassSymbol = requiredClass("scala.runtime.EnumValueSerializationProxy")
@tu lazy val EnumValueSerializationProxyConstructor: TermSymbol =
EnumValueSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(TypeBounds.empty), IntType))

@tu lazy val ProductClass: ClassSymbol = requiredClass("scala.Product")
@tu lazy val Product_canEqual : Symbol = ProductClass.requiredMethod(nme.canEqual_)
@tu lazy val Product_productArity : Symbol = ProductClass.requiredMethod(nme.productArity)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1491,7 +1491,7 @@ object Parsers {
}

private def makeKindProjectorTypeDef(name: TypeName): TypeDef =
TypeDef(name, TypeBoundsTree(EmptyTree, EmptyTree)).withFlags(Param)
TypeDef(name, WildcardTypeBoundsTree()).withFlags(Param)

/** Replaces kind-projector's `*` in a list of types arguments with synthetic names,
* returning the new argument list and the synthetic type definitions.
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
cpy.Inlined(tree)(callTrace, transformSub(bindings), transform(expansion)(using inlineContext(call)))
case templ: Template =>
withNoCheckNews(templ.parents.flatMap(newPart)) {
Checking.checkEnumParentOK(templ.symbol.owner)
forwardParamAccessors(templ)
synthMbr.addSyntheticMembers(
superAcc.wrapTemplate(templ)(
Expand Down
52 changes: 40 additions & 12 deletions compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -355,31 +355,59 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
symbolsToSynthesize.flatMap(syntheticDefIfMissing)
}

private def hasWriteReplace(clazz: ClassSymbol)(using Context): Boolean =
clazz.membersNamed(nme.writeReplace)
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
.exists

private def writeReplaceDef(clazz: ClassSymbol)(using Context): TermSymbol =
newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm

/** If this is a serializable static object `Foo`, add the method:
*
* private def writeReplace(): AnyRef =
* new scala.runtime.ModuleSerializationProxy(classOf[Foo.type])
*
* unless an implementation already exists, otherwise do nothing.
*/
def serializableObjectMethod(clazz: ClassSymbol)(using Context): List[Tree] = {
def hasWriteReplace: Boolean =
clazz.membersNamed(nme.writeReplace)
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
.exists
if (clazz.is(Module) && clazz.isStatic && clazz.isSerializable && !hasWriteReplace) {
val writeReplace = newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
def serializableObjectMethod(clazz: ClassSymbol)(using Context): List[Tree] =
if clazz.is(Module) && clazz.isStatic && clazz.isSerializable && !hasWriteReplace(clazz) then
List(
DefDef(writeReplace,
DefDef(writeReplaceDef(clazz),
_ => New(defn.ModuleSerializationProxyClass.typeRef,
defn.ModuleSerializationProxyConstructor,
List(Literal(Constant(clazz.sourceModule.termRef)))))
.withSpan(ctx.owner.span.focus))
}
else
Nil
}

/** Is this an anonymous class deriving from an enum definition? */
extension (cls: ClassSymbol) private def isEnumValueImplementation(using Context): Boolean =
isAnonymousClass && classParents.head.typeSymbol.is(Enum) // asserted in Typer

/** If this is the class backing a serializable singleton enum value with base class `MyEnum`,
* and not deriving from `java.lang.Enum` add the method:
*
* private def writeReplace(): AnyRef =
* new scala.runtime.EnumValueSerializationProxy(classOf[MyEnum], this.ordinal)
*
* unless an implementation already exists, otherwise do nothing.
*/
def serializableEnumValueMethod(clazz: ClassSymbol)(using Context): List[Tree] =
if clazz.isEnumValueImplementation
&& !clazz.derivesFrom(defn.JavaEnumClass)
&& clazz.isSerializable
&& !hasWriteReplace(clazz)
then
List(
DefDef(writeReplaceDef(clazz),
_ => New(defn.EnumValueSerializationProxyClass.typeRef,
defn.EnumValueSerializationProxyConstructor,
List(Literal(Constant(clazz.classParents.head)), This(clazz).select(nme.ordinal).ensureApplied)))
.withSpan(ctx.owner.span.focus))
else
Nil

/** The class
*
Expand Down Expand Up @@ -528,6 +556,6 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
def addSyntheticMembers(impl: Template)(using Context): Template = {
val clazz = ctx.owner.asClass
addMirrorSupport(
cpy.Template(impl)(body = serializableObjectMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body))
cpy.Template(impl)(body = serializableObjectMethod(clazz) ::: serializableEnumValueMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body))
}
}
47 changes: 33 additions & 14 deletions compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -630,17 +630,6 @@ object Checking {
}
}

/** Check that an enum case extends its enum class */
def checkEnumParentOK(cls: Symbol)(using Context): Unit =
val enumCase =
if cls.isAllOf(EnumCase) then cls
else if cls.isAnonymousClass && cls.owner.isAllOf(EnumCase) then cls.owner
else NoSymbol
if enumCase.exists then
val enumCls = enumCase.owner.linkedClass
if !cls.info.parents.exists(_.typeSymbol == enumCls) then
report.error(i"enum case does not extend its enum $enumCls", enumCase.srcPos)

/** Check the inline override methods only use inline parameters if they override an inline parameter. */
def checkInlineOverrideParameters(sym: Symbol)(using Context): Unit =
lazy val params = sym.paramSymss.flatten
Expand Down Expand Up @@ -1095,9 +1084,10 @@ trait Checking {
*/
def checkEnum(cdef: untpd.TypeDef, cls: Symbol, firstParent: Symbol)(using Context): Unit = {
def isEnumAnonCls =
cls.isAnonymousClass &&
cls.owner.isTerm &&
(cls.owner.flagsUNSAFE.is(Case) || cls.owner.name == nme.DOLLAR_NEW)
cls.isAnonymousClass
&& cls.owner.isTerm
&& (cls.owner.flagsUNSAFE.isAllOf(EnumCase)
|| ((cls.owner.name eq nme.DOLLAR_NEW) && cls.owner.flagsUNSAFE.isAllOf(Private | Synthetic)))
if (!isEnumAnonCls)
if (cdef.mods.isEnumCase) {
if (cls.derivesFrom(defn.JavaEnumClass))
Expand All @@ -1112,6 +1102,34 @@ trait Checking {
report.error(ClassCannotExtendEnum(cls, firstParent), cdef.srcPos)
}

/** Check that the firstParent for an enum case derives from the declaring enum class, if not, adds it as a parent
* after emitting an error.
*
* This check will have no effect on simple enum cases as their parents are inferred by the compiler.
*/
def checkEnumParent(cls: Symbol, firstParent: Symbol)(using Context): Unit =

extension (sym: Symbol) def typeRefApplied(using Context): Type =
typeRef.appliedTo(typeParams.map(_.info.loBound))

def ensureParentDerivesFrom(enumCase: Symbol)(using Context) =
val enumCls = enumCase.owner.linkedClass
if !firstParent.derivesFrom(enumCls) then
report.error(i"enum case does not extend its enum $enumCls", enumCase.srcPos)
cls.info match
case info: ClassInfo =>
cls.info = info.derivedClassInfo(classParents = enumCls.typeRefApplied :: info.classParents)
case _ =>

val enumCase =
if cls.flagsUNSAFE.isAllOf(EnumCase) then cls
else if cls.isAnonymousClass && cls.owner.flagsUNSAFE.isAllOf(EnumCase) then cls.owner
else NoSymbol
if enumCase.exists then
ensureParentDerivesFrom(enumCase)

end checkEnumParent

/** Check that all references coming from enum cases in an enum companion object
* are legal.
* @param cdef the enum companion object class
Expand Down Expand Up @@ -1205,6 +1223,7 @@ trait Checking {

trait ReChecking extends Checking {
import tpd._
override def checkEnumParent(cls: Symbol, firstParent: Symbol)(using Context): Unit = ()
override def checkEnum(cdef: untpd.TypeDef, cls: Symbol, firstParent: Symbol)(using Context): Unit = ()
override def checkRefsLegal(tree: tpd.Tree, badOwner: Symbol, allowed: (Name, Symbol) => Boolean, where: String)(using Context): Unit = ()
override def checkFullyAppliedType(tree: Tree)(using Context): Unit = ()
Expand Down
14 changes: 7 additions & 7 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1720,7 +1720,7 @@ class Typer extends Namer
}
if (desugaredArg.isType)
arg match {
case TypeBoundsTree(EmptyTree, EmptyTree, _)
case untpd.WildcardTypeBoundsTree()
if tparam.paramInfo.isLambdaSub &&
tpt1.tpe.typeParamSymbols.nonEmpty &&
!ctx.mode.is(Mode.Pattern) =>
Expand All @@ -1739,7 +1739,7 @@ class Typer extends Namer
args.zipWithConserve(tparams)(typedArg(_, _)).asInstanceOf[List[Tree]]
}
val paramBounds = tparams.lazyZip(args).map {
case (tparam, TypeBoundsTree(EmptyTree, EmptyTree, _)) =>
case (tparam, untpd.WildcardTypeBoundsTree()) =>
// if type argument is a wildcard, suppress kind checking since
// there is no real argument.
NoType
Expand Down Expand Up @@ -2102,25 +2102,25 @@ class Typer extends Namer
val constr1 = typed(constr).asInstanceOf[DefDef]
val parentsWithClass = ensureFirstTreeIsClass(parents.mapconserve(typedParent).filterConserve(!_.isEmpty), cdef.nameSpan)
val parents1 = ensureConstrCall(cls, parentsWithClass)(using superCtx)
val firstParent = parents1.head.tpe.dealias.typeSymbol

checkEnumParent(cls, firstParent)

val self1 = typed(self)(using ctx.outer).asInstanceOf[ValDef] // outer context where class members are not visible
if (self1.tpt.tpe.isError || classExistsOnSelf(cls.unforcedDecls, self1))
// fail fast to avoid typing the body with an error type
cdef.withType(UnspecifiedErrorType)
else {
val dummy = localDummy(cls, impl)
val body1 = addAccessorDefs(cls,
typedStats(impl.body, dummy)(using ctx.inClassContext(self1.symbol))._1)
val body1 = addAccessorDefs(cls, typedStats(impl.body, dummy)(using ctx.inClassContext(self1.symbol))._1)

checkNoDoubleDeclaration(cls)
val impl1 = cpy.Template(impl)(constr1, parents1, Nil, self1, body1)
.withType(dummy.termRef)
if (!cls.isOneOf(AbstractOrTrait) && !ctx.isAfterTyper)
checkRealizableBounds(cls, cdef.sourcePos.withSpan(cdef.nameSpan))
if (cls.derivesFrom(defn.EnumClass)) {
val firstParent = parents1.head.tpe.dealias.typeSymbol
if cls.derivesFrom(defn.EnumClass) then
checkEnum(cdef, cls, firstParent)
}
val cdef1 = assignType(cpy.TypeDef(cdef)(name, impl1), cls)

val reportDynamicInheritance =
Expand Down
19 changes: 9 additions & 10 deletions docs/docs/reference/enums/desugarEnums.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,31 +156,30 @@ map into `case class`es or `val`s.
a type parameter of the case, i.e. the parameter name is defined in `<params>`.


### Translation of Enumerations
### Translation of Enums with Singleton Cases

Non-generic enums `E` that define one or more singleton cases
are called _enumerations_. Companion objects of enumerations define
the following additional synthetic members.
An enum `E` (possibly generic) that defines one or more singleton cases
will define the following additional synthetic members in its companion object (where `E'` denotes `E` with
any type parameters replaced by wildcards):

- A method `valueOf(name: String): E`. It returns the singleton case value whose
- A method `valueOf(name: String): E'`. It returns the singleton case value whose
`toString` representation is `name`.
- A method `values` which returns an `Array[E]` of all singleton case
values in `E`, in the order of their definitions.
- A method `values` which returns an `Array[E']` of all singleton case
values defined by `E`, in the order of their definitions.

Companion objects of enumerations that contain at least one simple case define in addition:
If `E` contains at least one simple case, its companion object will define in addition:

- A private method `$new` which defines a new simple case value with given
ordinal number and name. This method can be thought as being defined as
follows.
```scala
private def $new(_$ordinal: Int, $name: String) = new E {
private def $new(_$ordinal: Int, $name: String) = new E with runtime.EnumValue {
def $ordinal = $_ordinal
override def toString = $name
$values.register(this) // register enum value so that `valueOf` and `values` can return it.
}
```

The anonymous class also implements the abstract `Product` methods that it inherits from `Enum`.
The `$ordinal` method above is used to generate the `ordinal` method if the enum does not extend a `java.lang.Enum` (as Scala enums do not extend `java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as `java.lang.Enum` defines it.

### Scopes for Enum Cases
Expand Down
Loading