Skip to content

implement readResolve in terms of fromOrdinalDollar method #9612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,8 @@ object desugar {
val (enumCases, enumStats) = stats.partition(DesugarEnums.isEnumCase)
if (enumCases.isEmpty)
report.error(EnumerationsShouldNotBeEmpty(cdef), namePos)
else
enumCases.last.pushAttachment(DesugarEnums.DefinesEnumLookupMethods, ())
val enumCompanionRef = TermRefTree()
val enumImport =
Import(enumCompanionRef, enumCases.flatMap(caseIds).map(ImportSelector(_)))
Expand Down Expand Up @@ -568,7 +570,7 @@ object desugar {
// Note: copy default parameters need @uncheckedVariance; see
// neg/t1843-variances.scala for a test case. The test would give
// two errors without @uncheckedVariance, one of them spurious.
val caseClassMeths = {
val (caseClassMeths, enumScaffolding) = {
def syntheticProperty(name: TermName, tpt: Tree, rhs: Tree) =
DefDef(name, Nil, Nil, tpt, rhs).withMods(synthetic)

Expand All @@ -586,9 +588,11 @@ object desugar {
yield syntheticProperty(selName, caseParams(i).tpt,
Select(This(EmptyTypeIdent), caseParams(i).name))

def enumMeths =
if (isEnumCase) ordinalMethLit(nextOrdinal(CaseKind.Class)._1) :: enumLabelLit(className.toString) :: Nil
else Nil
def enumCaseMeths =
if isEnumCase then
val (ordinal, scaffolding) = nextOrdinal(className, CaseKind.Class, definesEnumLookupMethods(cdef))
(ordinalMethLit(ordinal) :: enumLabelLit(className.toString) :: Nil, scaffolding)
else (Nil, Nil)
def copyMeths = {
val hasRepeatedParam = constrVparamss.exists(_.exists {
case ValDef(_, tpt, _) => isRepeated(tpt)
Expand All @@ -607,8 +611,9 @@ object desugar {
}

if (isCaseClass)
copyMeths ::: enumMeths ::: productElemMeths
else Nil
val (enumMeths, enumScaffolding) = enumCaseMeths
(copyMeths ::: enumMeths ::: productElemMeths, enumScaffolding)
else (Nil, Nil)
}

var parents1 = parents
Expand Down Expand Up @@ -809,7 +814,7 @@ object desugar {
case _ =>
}

flatTree(cdef1 :: companions ::: implicitWrappers)
flatTree(cdef1 :: companions ::: implicitWrappers ::: enumScaffolding)
}.reporting(i"desugared: $result", Printers.desugar)

/** Expand
Expand Down Expand Up @@ -862,7 +867,7 @@ object desugar {
else if (isEnumCase) {
typeParamIsReferenced(enumClass.typeParams, Nil, Nil, impl.parents)
// used to check there are no illegal references to enum's type parameters in parents
expandEnumModule(moduleName, impl, mods, mdef.span)
expandEnumModule(moduleName, impl, mods, definesEnumLookupMethods(mdef), mdef.span)
}
else {
val clsName = moduleName.moduleClassName
Expand Down Expand Up @@ -990,6 +995,12 @@ object desugar {

private def inventTypeName(tree: Tree)(using Context): String = typeNameExtractor("", tree)

/**This will check if this def tree is marked to define enum lookup methods,
* this is not recommended to call more than once per tree
*/
private def definesEnumLookupMethods(ddef: DefTree): Boolean =
ddef.removeAttachment(DefinesEnumLookupMethods).isDefined

/** val p1, ..., pN: T = E
* ==>
* makePatDef[[val p1: T1 = E]]; ...; makePatDef[[val pN: TN = E]]
Expand All @@ -1001,11 +1012,15 @@ object desugar {
def patDef(pdef: PatDef)(using Context): Tree = flatTree {
val PatDef(mods, pats, tpt, rhs) = pdef
if (mods.isEnumCase)
pats map {
case id: Ident =>
expandSimpleEnumCase(id.name.asTermName, mods,
def expand(id: Ident, definesLookups: Boolean) =
expandSimpleEnumCase(id.name.asTermName, mods, definesLookups,
Span(id.span.start, id.span.end, id.span.start))
}

val ids = pats.asInstanceOf[List[Ident]]
if definesEnumLookupMethods(pdef) then
ids.init.map(expand(_, false)) ::: expand(ids.last, true) :: Nil
else
ids.map(expand(_, false))
else {
val pats1 = if (tpt.isEmpty) pats else pats map (Typed(_, tpt))
pats1 map (makePatDef(pdef, mods, _, rhs))
Expand Down
56 changes: 42 additions & 14 deletions compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@ object DesugarEnums {
val Simple, Object, Class: Value = Value
}

/** Attachment containing the number of enum cases and the smallest kind that was seen so far. */
val EnumCaseCount: Property.Key[(Int, DesugarEnums.CaseKind.Value)] = Property.Key()
/** Attachment containing the number of enum cases, the smallest kind that was seen so far,
* and a list of all the value cases with their ordinals.
*/
val EnumCaseCount: Property.Key[(Int, CaseKind.Value, List[(Int, TermName)])] = Property.Key()

/** Attachment signalling that when this definition is desugared, it should add any additional
* lookup methods for enums.
*/
val DefinesEnumLookupMethods: Property.Key[Unit] = Property.Key()

/** The enumeration class that belongs to an enum case. This works no matter
* whether the case is still in the enum class or it has been transferred to the
Expand Down Expand Up @@ -122,6 +129,21 @@ object DesugarEnums {
valueOfDef :: Nil
}

private def enumLookupMethods(cases: List[(Int, TermName)])(using Context): List[Tree] =
if isJavaEnum || cases.isEmpty then Nil
else
val defaultCase =
val ord = Ident(nme.ordinal)
val err = Throw(New(TypeTree(defn.IndexOutOfBoundsException.typeRef), List(Select(ord, nme.toString_) :: Nil)))
CaseDef(ord, EmptyTree, err)
val valueCases = cases.map((i, name) =>
CaseDef(Literal(Constant(i)), EmptyTree, Ident(name))
) ::: defaultCase :: Nil
val fromOrdinalDef = DefDef(nme.fromOrdinalDollar, Nil, List(param(nme.ordinalDollar_, defn.IntType) :: Nil),
rawRef(enumClass.typeRef), Match(Ident(nme.ordinalDollar_), valueCases))
.withFlags(Synthetic | Private)
fromOrdinalDef :: Nil

/** A creation method for a value of enum type `E`, which is defined as follows:
*
* private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue {
Expand Down Expand Up @@ -256,16 +278,22 @@ object DesugarEnums {
* - scaffolding containing the necessary definitions for singleton enum cases
* unless that scaffolding was already generated by a previous call to `nextEnumKind`.
*/
def nextOrdinal(kind: CaseKind.Value)(using Context): (Int, List[Tree]) = {
val (count, seenKind) = ctx.tree.removeAttachment(EnumCaseCount).getOrElse((0, CaseKind.Class))
val minKind = if (kind < seenKind) kind else seenKind
ctx.tree.pushAttachment(EnumCaseCount, (count + 1, minKind))
val scaffolding =
def nextOrdinal(name: Name, kind: CaseKind.Value, definesLookups: Boolean)(using Context): (Int, List[Tree]) = {
val (ordinal, seenKind, seenCases) = ctx.tree.removeAttachment(EnumCaseCount).getOrElse((0, CaseKind.Class, Nil))
val minKind = if kind < seenKind then kind else seenKind
val cases = name match
case name: TermName => (ordinal, name) :: seenCases
case _ => seenCases
ctx.tree.pushAttachment(EnumCaseCount, (ordinal + 1, minKind, cases))
val scaffolding0 =
if (kind >= seenKind) Nil
else if (kind == CaseKind.Object) enumScaffolding
else if (seenKind == CaseKind.Object) enumValueCreator :: Nil
else enumScaffolding :+ enumValueCreator
(count, scaffolding)
val scaffolding =
if definesLookups then scaffolding0 ::: enumLookupMethods(cases.reverse)
else scaffolding0
(ordinal, scaffolding)
}

def param(name: TermName, typ: Type)(using Context) =
Expand All @@ -286,13 +314,13 @@ object DesugarEnums {
enumLabelMeth(Literal(Constant(name)))

/** Expand a module definition representing a parameterless enum case */
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, span: Span)(using Context): Tree = {
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, definesLookups: Boolean, span: Span)(using Context): Tree = {
assert(impl.body.isEmpty)
if (!enumClass.exists) EmptyTree
else if (impl.parents.isEmpty)
expandSimpleEnumCase(name, mods, span)
expandSimpleEnumCase(name, mods, definesLookups, span)
else {
val (tag, scaffolding) = nextOrdinal(CaseKind.Object)
val (tag, scaffolding) = nextOrdinal(name, CaseKind.Object, definesLookups)
val ordinalDef = if isJavaEnum then Nil else ordinalMethLit(tag) :: Nil
val enumLabelDef = enumLabelLit(name.toString)
val impl1 = cpy.Template(impl)(
Expand All @@ -305,15 +333,15 @@ object DesugarEnums {
}

/** Expand a simple enum case */
def expandSimpleEnumCase(name: TermName, mods: Modifiers, span: Span)(using Context): Tree =
def expandSimpleEnumCase(name: TermName, mods: Modifiers, definesLookups: Boolean, span: Span)(using Context): Tree =
if (!enumClass.exists) EmptyTree
else if (enumClass.typeParams.nonEmpty) {
val parent = interpolatedEnumParent(span)
val impl = Template(emptyConstructor, parent :: Nil, Nil, EmptyValDef, Nil)
expandEnumModule(name, impl, mods, span)
expandEnumModule(name, impl, mods, definesLookups, span)
}
else {
val (tag, scaffolding) = nextOrdinal(CaseKind.Simple)
val (tag, scaffolding) = nextOrdinal(name, CaseKind.Simple, definesLookups)
val creator = Apply(Ident(nme.DOLLAR_NEW), List(Literal(Constant(tag)), Literal(Constant(name.toString))))
val vdef = ValDef(name, enumClassRef, creator).withMods(mods.withAddedFlags(EnumValue, span))
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ object StdNames {
val using: N = "using"
val value: N = "value"
val valueOf : N = "valueOf"
val fromOrdinalDollar: N = "$fromOrdinal"
val values: N = "values"
val view_ : N = "view"
val wait_ : N = "wait"
Expand All @@ -623,6 +624,7 @@ object StdNames {
val WorksheetWrapper: N = "WorksheetWrapper"
val wrap: N = "wrap"
val writeReplace: N = "writeReplace"
val readResolve: N = "readResolve"
val zero: N = "zero"
val zip: N = "zip"
val nothingRuntimeClass: N = "scala.runtime.Nothing$"
Expand Down
23 changes: 16 additions & 7 deletions compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,19 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
.exists

private def hasReadResolve(clazz: ClassSymbol)(using Context): Boolean =
clazz.membersNamed(nme.readResolve)
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
.exists

private def writeReplaceDef(clazz: ClassSymbol)(using Context): TermSymbol =
newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm

private def readResolveDef(clazz: ClassSymbol)(using Context): TermSymbol =
newSymbol(clazz, nme.readResolve, Method | Private | Synthetic,
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm

/** If this is a static object `Foo`, add the method:
*
* private def writeReplace(): AnyRef =
Expand Down Expand Up @@ -405,22 +414,22 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
/** If this is the class backing a serializable singleton enum value with base class `MyEnum`,
* and not deriving from `java.lang.Enum` add the method:
*
* private def writeReplace(): AnyRef =
* new scala.runtime.EnumValueSerializationProxy(classOf[MyEnum], this.ordinal)
* private def readResolve(): AnyRef =
* MyEnum.$fromOrdinal(this.ordinal)
*
* unless an implementation already exists, otherwise do nothing.
*/
def serializableEnumValueMethod(clazz: ClassSymbol)(using Context): List[Tree] =
if clazz.isEnumValueImplementation
&& !clazz.derivesFrom(defn.JavaEnumClass)
&& clazz.isSerializable
&& !hasWriteReplace(clazz)
&& !hasReadResolve(clazz)
then
List(
DefDef(writeReplaceDef(clazz),
_ => New(defn.EnumValueSerializationProxyClass.typeRef,
defn.EnumValueSerializationProxyConstructor,
List(Literal(Constant(clazz.classParents.head)), This(clazz).select(nme.ordinal).ensureApplied)))
DefDef(readResolveDef(clazz),
_ => ref(clazz.owner.owner.sourceModule)
.select(nme.fromOrdinalDollar)
.appliedTo(This(clazz).select(nme.ordinal).ensureApplied))
.withSpan(ctx.owner.span.focus))
else
Nil
Expand Down
43 changes: 35 additions & 8 deletions tests/run/enums-serialization-compat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,49 @@ import java.io._
import scala.util.Using

enum JColor extends java.lang.Enum[JColor]:
case Red
case Red // java enum has magic JVM support

enum SColor:
case Green
case Green // simple case last

enum SColorTagged[T]:
case Blue extends SColorTagged[Unit]
case Blue extends SColorTagged[Unit]
case Rgb(r: Byte, g: Byte, b: Byte) extends SColorTagged[(Byte, Byte, Byte)] // mixing pattern kinds
case Indigo extends SColorTagged[Unit]
case Cmyk(c: Byte, m: Byte, y: Byte, k: Byte) extends SColorTagged[(Byte, Byte, Byte, Byte)] // class case last

enum Nucleobase:
case A,C,G,T // patdef last

enum MyClassTag[T](wrapped: Class[?]):
case IntTag extends MyClassTag[Int](classOf[Int])
case UnitTag extends MyClassTag[Unit](classOf[Unit]) // value case last

extension (ref: AnyRef) def aliases(compare: AnyRef) = assert(ref eq compare, compare)

@main def Test = Using.Manager({ use =>
val buf = use(ByteArrayOutputStream())
val out = use(ObjectOutputStream(buf))
Seq(JColor.Red, SColor.Green, SColorTagged.Blue).foreach(out.writeObject)
Seq(JColor.Red, SColor.Green, SColorTagged.Blue, SColorTagged.Indigo).foreach(out.writeObject)
Seq(Nucleobase.A, Nucleobase.C, Nucleobase.G, Nucleobase.T).foreach(out.writeObject)
Seq(MyClassTag.IntTag, MyClassTag.UnitTag).foreach(out.writeObject)
val read = use(ByteArrayInputStream(buf.toByteArray))
val in = use(ObjectInputStream(read))
val Seq(Red @ _, Green @ _, Blue @ _) = (1 to 3).map(_ => in.readObject)
assert(Red eq JColor.Red, JColor.Red)
assert(Green eq SColor.Green, SColor.Green)
assert(Blue eq SColorTagged.Blue, SColorTagged.Blue)

val Seq(Red @ _, Green @ _, Blue @ _, Indigo @ _) = (1 to 4).map(_ => in.readObject)
Red aliases JColor.Red
Green aliases SColor.Green
Blue aliases SColorTagged.Blue
Indigo aliases SColorTagged.Indigo

val Seq(A @ _, C @ _, G @ _, T @ _) = (1 to 4).map(_ => in.readObject)
A aliases Nucleobase.A
C aliases Nucleobase.C
G aliases Nucleobase.G
T aliases Nucleobase.T

val Seq(IntTag @ _, UnitTag @ _) = (1 to 2).map(_ => in.readObject)
IntTag aliases MyClassTag.IntTag
UnitTag aliases MyClassTag.UnitTag

}).get
16 changes: 15 additions & 1 deletion tests/semanticdb/metac.expect
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ Schema => SemanticDB v4
Uri => Enums.scala
Text => empty
Language => Scala
Symbols => 169 entries
Symbols => 183 entries
Occurrences => 203 entries

Symbols:
Expand All @@ -651,6 +651,8 @@ _empty_/Enums.Coin#`<init>`(). => primary ctor <init>
_empty_/Enums.Coin#`<init>`().(value) => param value
_empty_/Enums.Coin#value. => val method value
_empty_/Enums.Coin. => final object Coin
_empty_/Enums.Coin.$fromOrdinal(). => method $fromOrdinal
_empty_/Enums.Coin.$fromOrdinal().(_$ordinal) => param _$ordinal
_empty_/Enums.Coin.$values. => val method $values
_empty_/Enums.Coin.Dime. => case val static enum method Dime
_empty_/Enums.Coin.Dollar. => case val static enum method Dollar
Expand All @@ -663,6 +665,8 @@ _empty_/Enums.Coin.values(). => method values
_empty_/Enums.Colour# => abstract sealed enum class Colour
_empty_/Enums.Colour#`<init>`(). => primary ctor <init>
_empty_/Enums.Colour. => final object Colour
_empty_/Enums.Colour.$fromOrdinal(). => method $fromOrdinal
_empty_/Enums.Colour.$fromOrdinal().(_$ordinal) => param _$ordinal
_empty_/Enums.Colour.$new(). => method $new
_empty_/Enums.Colour.$new().($name) => param $name
_empty_/Enums.Colour.$new().(_$ordinal) => param _$ordinal
Expand All @@ -676,6 +680,8 @@ _empty_/Enums.Colour.values(). => method values
_empty_/Enums.Directions# => abstract sealed enum class Directions
_empty_/Enums.Directions#`<init>`(). => primary ctor <init>
_empty_/Enums.Directions. => final object Directions
_empty_/Enums.Directions.$fromOrdinal(). => method $fromOrdinal
_empty_/Enums.Directions.$fromOrdinal().(_$ordinal) => param _$ordinal
_empty_/Enums.Directions.$new(). => method $new
_empty_/Enums.Directions.$new().($name) => param $name
_empty_/Enums.Directions.$new().(_$ordinal) => param _$ordinal
Expand All @@ -691,6 +697,8 @@ _empty_/Enums.Maybe# => abstract sealed enum class Maybe
_empty_/Enums.Maybe#[A] => covariant typeparam A
_empty_/Enums.Maybe#`<init>`(). => primary ctor <init>
_empty_/Enums.Maybe. => final object Maybe
_empty_/Enums.Maybe.$fromOrdinal(). => method $fromOrdinal
_empty_/Enums.Maybe.$fromOrdinal().(_$ordinal) => param _$ordinal
_empty_/Enums.Maybe.$values. => val method $values
_empty_/Enums.Maybe.Just# => final case enum class Just
_empty_/Enums.Maybe.Just#[A] => typeparam A
Expand Down Expand Up @@ -743,6 +751,8 @@ _empty_/Enums.Planet.values(). => method values
_empty_/Enums.Suits# => abstract sealed enum class Suits
_empty_/Enums.Suits#`<init>`(). => primary ctor <init>
_empty_/Enums.Suits. => final object Suits
_empty_/Enums.Suits.$fromOrdinal(). => method $fromOrdinal
_empty_/Enums.Suits.$fromOrdinal().(_$ordinal) => param _$ordinal
_empty_/Enums.Suits.$new(). => method $new
_empty_/Enums.Suits.$new().($name) => param $name
_empty_/Enums.Suits.$new().(_$ordinal) => param _$ordinal
Expand All @@ -763,6 +773,8 @@ _empty_/Enums.Tag# => abstract sealed enum class Tag
_empty_/Enums.Tag#[A] => typeparam A
_empty_/Enums.Tag#`<init>`(). => primary ctor <init>
_empty_/Enums.Tag. => final object Tag
_empty_/Enums.Tag.$fromOrdinal(). => method $fromOrdinal
_empty_/Enums.Tag.$fromOrdinal().(_$ordinal) => param _$ordinal
_empty_/Enums.Tag.$values. => val method $values
_empty_/Enums.Tag.BooleanTag. => case val static enum method BooleanTag
_empty_/Enums.Tag.IntTag. => case val static enum method IntTag
Expand All @@ -772,6 +784,8 @@ _empty_/Enums.Tag.values(). => method values
_empty_/Enums.WeekDays# => abstract sealed enum class WeekDays
_empty_/Enums.WeekDays#`<init>`(). => primary ctor <init>
_empty_/Enums.WeekDays. => final object WeekDays
_empty_/Enums.WeekDays.$fromOrdinal(). => method $fromOrdinal
_empty_/Enums.WeekDays.$fromOrdinal().(_$ordinal) => param _$ordinal
_empty_/Enums.WeekDays.$new(). => method $new
_empty_/Enums.WeekDays.$new().($name) => param $name
_empty_/Enums.WeekDays.$new().(_$ordinal) => param _$ordinal
Expand Down