diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index fdf49447713f..766d1940b400 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -476,6 +476,8 @@ object desugar { val (enumCases, enumStats) = stats.partition(DesugarEnums.isEnumCase) if (enumCases.isEmpty) report.error(EnumerationsShouldNotBeEmpty(cdef), namePos) + else + enumCases.last.pushAttachment(DesugarEnums.DefinesEnumLookupMethods, ()) val enumCompanionRef = TermRefTree() val enumImport = Import(enumCompanionRef, enumCases.flatMap(caseIds).map(ImportSelector(_))) @@ -568,7 +570,7 @@ object desugar { // Note: copy default parameters need @uncheckedVariance; see // neg/t1843-variances.scala for a test case. The test would give // two errors without @uncheckedVariance, one of them spurious. - val caseClassMeths = { + val (caseClassMeths, enumScaffolding) = { def syntheticProperty(name: TermName, tpt: Tree, rhs: Tree) = DefDef(name, Nil, Nil, tpt, rhs).withMods(synthetic) @@ -586,9 +588,11 @@ object desugar { yield syntheticProperty(selName, caseParams(i).tpt, Select(This(EmptyTypeIdent), caseParams(i).name)) - def enumMeths = - if (isEnumCase) ordinalMethLit(nextOrdinal(CaseKind.Class)._1) :: enumLabelLit(className.toString) :: Nil - else Nil + def enumCaseMeths = + if isEnumCase then + val (ordinal, scaffolding) = nextOrdinal(className, CaseKind.Class, definesEnumLookupMethods(cdef)) + (ordinalMethLit(ordinal) :: enumLabelLit(className.toString) :: Nil, scaffolding) + else (Nil, Nil) def copyMeths = { val hasRepeatedParam = constrVparamss.exists(_.exists { case ValDef(_, tpt, _) => isRepeated(tpt) @@ -607,8 +611,9 @@ object desugar { } if (isCaseClass) - copyMeths ::: enumMeths ::: productElemMeths - else Nil + val (enumMeths, enumScaffolding) = enumCaseMeths + (copyMeths ::: enumMeths ::: productElemMeths, enumScaffolding) + else (Nil, Nil) } var parents1 = parents @@ -809,7 +814,7 @@ object desugar { case _ => } - flatTree(cdef1 :: companions ::: implicitWrappers) + flatTree(cdef1 :: companions ::: implicitWrappers ::: enumScaffolding) }.reporting(i"desugared: $result", Printers.desugar) /** Expand @@ -862,7 +867,7 @@ object desugar { else if (isEnumCase) { typeParamIsReferenced(enumClass.typeParams, Nil, Nil, impl.parents) // used to check there are no illegal references to enum's type parameters in parents - expandEnumModule(moduleName, impl, mods, mdef.span) + expandEnumModule(moduleName, impl, mods, definesEnumLookupMethods(mdef), mdef.span) } else { val clsName = moduleName.moduleClassName @@ -990,6 +995,12 @@ object desugar { private def inventTypeName(tree: Tree)(using Context): String = typeNameExtractor("", tree) + /**This will check if this def tree is marked to define enum lookup methods, + * this is not recommended to call more than once per tree + */ + private def definesEnumLookupMethods(ddef: DefTree): Boolean = + ddef.removeAttachment(DefinesEnumLookupMethods).isDefined + /** val p1, ..., pN: T = E * ==> * makePatDef[[val p1: T1 = E]]; ...; makePatDef[[val pN: TN = E]] @@ -1001,11 +1012,15 @@ object desugar { def patDef(pdef: PatDef)(using Context): Tree = flatTree { val PatDef(mods, pats, tpt, rhs) = pdef if (mods.isEnumCase) - pats map { - case id: Ident => - expandSimpleEnumCase(id.name.asTermName, mods, + def expand(id: Ident, definesLookups: Boolean) = + expandSimpleEnumCase(id.name.asTermName, mods, definesLookups, Span(id.span.start, id.span.end, id.span.start)) - } + + val ids = pats.asInstanceOf[List[Ident]] + if definesEnumLookupMethods(pdef) then + ids.init.map(expand(_, false)) ::: expand(ids.last, true) :: Nil + else + ids.map(expand(_, false)) else { val pats1 = if (tpt.isEmpty) pats else pats map (Typed(_, tpt)) pats1 map (makePatDef(pdef, mods, _, rhs)) diff --git a/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala b/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala index fc9c843f0e3e..6a2fa80d6e66 100644 --- a/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala +++ b/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala @@ -20,8 +20,15 @@ object DesugarEnums { val Simple, Object, Class: Value = Value } - /** Attachment containing the number of enum cases and the smallest kind that was seen so far. */ - val EnumCaseCount: Property.Key[(Int, DesugarEnums.CaseKind.Value)] = Property.Key() + /** Attachment containing the number of enum cases, the smallest kind that was seen so far, + * and a list of all the value cases with their ordinals. + */ + val EnumCaseCount: Property.Key[(Int, CaseKind.Value, List[(Int, TermName)])] = Property.Key() + + /** Attachment signalling that when this definition is desugared, it should add any additional + * lookup methods for enums. + */ + val DefinesEnumLookupMethods: Property.Key[Unit] = Property.Key() /** The enumeration class that belongs to an enum case. This works no matter * whether the case is still in the enum class or it has been transferred to the @@ -122,6 +129,21 @@ object DesugarEnums { valueOfDef :: Nil } + private def enumLookupMethods(cases: List[(Int, TermName)])(using Context): List[Tree] = + if isJavaEnum || cases.isEmpty then Nil + else + val defaultCase = + val ord = Ident(nme.ordinal) + val err = Throw(New(TypeTree(defn.IndexOutOfBoundsException.typeRef), List(Select(ord, nme.toString_) :: Nil))) + CaseDef(ord, EmptyTree, err) + val valueCases = cases.map((i, name) => + CaseDef(Literal(Constant(i)), EmptyTree, Ident(name)) + ) ::: defaultCase :: Nil + val fromOrdinalDef = DefDef(nme.fromOrdinalDollar, Nil, List(param(nme.ordinalDollar_, defn.IntType) :: Nil), + rawRef(enumClass.typeRef), Match(Ident(nme.ordinalDollar_), valueCases)) + .withFlags(Synthetic | Private) + fromOrdinalDef :: Nil + /** A creation method for a value of enum type `E`, which is defined as follows: * * private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue { @@ -256,16 +278,22 @@ object DesugarEnums { * - scaffolding containing the necessary definitions for singleton enum cases * unless that scaffolding was already generated by a previous call to `nextEnumKind`. */ - def nextOrdinal(kind: CaseKind.Value)(using Context): (Int, List[Tree]) = { - val (count, seenKind) = ctx.tree.removeAttachment(EnumCaseCount).getOrElse((0, CaseKind.Class)) - val minKind = if (kind < seenKind) kind else seenKind - ctx.tree.pushAttachment(EnumCaseCount, (count + 1, minKind)) - val scaffolding = + def nextOrdinal(name: Name, kind: CaseKind.Value, definesLookups: Boolean)(using Context): (Int, List[Tree]) = { + val (ordinal, seenKind, seenCases) = ctx.tree.removeAttachment(EnumCaseCount).getOrElse((0, CaseKind.Class, Nil)) + val minKind = if kind < seenKind then kind else seenKind + val cases = name match + case name: TermName => (ordinal, name) :: seenCases + case _ => seenCases + ctx.tree.pushAttachment(EnumCaseCount, (ordinal + 1, minKind, cases)) + val scaffolding0 = if (kind >= seenKind) Nil else if (kind == CaseKind.Object) enumScaffolding else if (seenKind == CaseKind.Object) enumValueCreator :: Nil else enumScaffolding :+ enumValueCreator - (count, scaffolding) + val scaffolding = + if definesLookups then scaffolding0 ::: enumLookupMethods(cases.reverse) + else scaffolding0 + (ordinal, scaffolding) } def param(name: TermName, typ: Type)(using Context) = @@ -286,13 +314,13 @@ object DesugarEnums { enumLabelMeth(Literal(Constant(name))) /** Expand a module definition representing a parameterless enum case */ - def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, span: Span)(using Context): Tree = { + def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, definesLookups: Boolean, span: Span)(using Context): Tree = { assert(impl.body.isEmpty) if (!enumClass.exists) EmptyTree else if (impl.parents.isEmpty) - expandSimpleEnumCase(name, mods, span) + expandSimpleEnumCase(name, mods, definesLookups, span) else { - val (tag, scaffolding) = nextOrdinal(CaseKind.Object) + val (tag, scaffolding) = nextOrdinal(name, CaseKind.Object, definesLookups) val ordinalDef = if isJavaEnum then Nil else ordinalMethLit(tag) :: Nil val enumLabelDef = enumLabelLit(name.toString) val impl1 = cpy.Template(impl)( @@ -305,15 +333,15 @@ object DesugarEnums { } /** Expand a simple enum case */ - def expandSimpleEnumCase(name: TermName, mods: Modifiers, span: Span)(using Context): Tree = + def expandSimpleEnumCase(name: TermName, mods: Modifiers, definesLookups: Boolean, span: Span)(using Context): Tree = if (!enumClass.exists) EmptyTree else if (enumClass.typeParams.nonEmpty) { val parent = interpolatedEnumParent(span) val impl = Template(emptyConstructor, parent :: Nil, Nil, EmptyValDef, Nil) - expandEnumModule(name, impl, mods, span) + expandEnumModule(name, impl, mods, definesLookups, span) } else { - val (tag, scaffolding) = nextOrdinal(CaseKind.Simple) + val (tag, scaffolding) = nextOrdinal(name, CaseKind.Simple, definesLookups) val creator = Apply(Ident(nme.DOLLAR_NEW), List(Literal(Constant(tag)), Literal(Constant(name.toString)))) val vdef = ValDef(name, enumClassRef, creator).withMods(mods.withAddedFlags(EnumValue, span)) flatTree(scaffolding ::: vdef :: Nil).withSpan(span) diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 7e69d9488675..4670ee104e3a 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -615,6 +615,7 @@ object StdNames { val using: N = "using" val value: N = "value" val valueOf : N = "valueOf" + val fromOrdinalDollar: N = "$fromOrdinal" val values: N = "values" val view_ : N = "view" val wait_ : N = "wait" @@ -623,6 +624,7 @@ object StdNames { val WorksheetWrapper: N = "WorksheetWrapper" val wrap: N = "wrap" val writeReplace: N = "writeReplace" + val readResolve: N = "readResolve" val zero: N = "zero" val zip: N = "zip" val nothingRuntimeClass: N = "scala.runtime.Nothing$" diff --git a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala index 4910949ead1a..6218818150b0 100644 --- a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala +++ b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala @@ -373,10 +373,19 @@ class SyntheticMembers(thisPhase: DenotTransformer) { .filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false)) .exists + private def hasReadResolve(clazz: ClassSymbol)(using Context): Boolean = + clazz.membersNamed(nme.readResolve) + .filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false)) + .exists + private def writeReplaceDef(clazz: ClassSymbol)(using Context): TermSymbol = newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic, MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm + private def readResolveDef(clazz: ClassSymbol)(using Context): TermSymbol = + newSymbol(clazz, nme.readResolve, Method | Private | Synthetic, + MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm + /** If this is a static object `Foo`, add the method: * * private def writeReplace(): AnyRef = @@ -405,8 +414,8 @@ class SyntheticMembers(thisPhase: DenotTransformer) { /** If this is the class backing a serializable singleton enum value with base class `MyEnum`, * and not deriving from `java.lang.Enum` add the method: * - * private def writeReplace(): AnyRef = - * new scala.runtime.EnumValueSerializationProxy(classOf[MyEnum], this.ordinal) + * private def readResolve(): AnyRef = + * MyEnum.$fromOrdinal(this.ordinal) * * unless an implementation already exists, otherwise do nothing. */ @@ -414,13 +423,13 @@ class SyntheticMembers(thisPhase: DenotTransformer) { if clazz.isEnumValueImplementation && !clazz.derivesFrom(defn.JavaEnumClass) && clazz.isSerializable - && !hasWriteReplace(clazz) + && !hasReadResolve(clazz) then List( - DefDef(writeReplaceDef(clazz), - _ => New(defn.EnumValueSerializationProxyClass.typeRef, - defn.EnumValueSerializationProxyConstructor, - List(Literal(Constant(clazz.classParents.head)), This(clazz).select(nme.ordinal).ensureApplied))) + DefDef(readResolveDef(clazz), + _ => ref(clazz.owner.owner.sourceModule) + .select(nme.fromOrdinalDollar) + .appliedTo(This(clazz).select(nme.ordinal).ensureApplied)) .withSpan(ctx.owner.span.focus)) else Nil diff --git a/library/src/scala/runtime/EnumValueSerializationProxy.java b/library/src-non-bootstrapped/scala/runtime/EnumValueSerializationProxy.java similarity index 100% rename from library/src/scala/runtime/EnumValueSerializationProxy.java rename to library/src-non-bootstrapped/scala/runtime/EnumValueSerializationProxy.java diff --git a/tests/run/enums-serialization-compat.scala b/tests/run/enums-serialization-compat.scala index 7224b67d5f80..940e726c05a0 100644 --- a/tests/run/enums-serialization-compat.scala +++ b/tests/run/enums-serialization-compat.scala @@ -2,22 +2,49 @@ import java.io._ import scala.util.Using enum JColor extends java.lang.Enum[JColor]: - case Red + case Red // java enum has magic JVM support enum SColor: - case Green + case Green // simple case last enum SColorTagged[T]: - case Blue extends SColorTagged[Unit] + case Blue extends SColorTagged[Unit] + case Rgb(r: Byte, g: Byte, b: Byte) extends SColorTagged[(Byte, Byte, Byte)] // mixing pattern kinds + case Indigo extends SColorTagged[Unit] + case Cmyk(c: Byte, m: Byte, y: Byte, k: Byte) extends SColorTagged[(Byte, Byte, Byte, Byte)] // class case last + +enum Nucleobase: + case A,C,G,T // patdef last + +enum MyClassTag[T](wrapped: Class[?]): + case IntTag extends MyClassTag[Int](classOf[Int]) + case UnitTag extends MyClassTag[Unit](classOf[Unit]) // value case last + +extension (ref: AnyRef) def aliases(compare: AnyRef) = assert(ref eq compare, compare) @main def Test = Using.Manager({ use => val buf = use(ByteArrayOutputStream()) val out = use(ObjectOutputStream(buf)) - Seq(JColor.Red, SColor.Green, SColorTagged.Blue).foreach(out.writeObject) + Seq(JColor.Red, SColor.Green, SColorTagged.Blue, SColorTagged.Indigo).foreach(out.writeObject) + Seq(Nucleobase.A, Nucleobase.C, Nucleobase.G, Nucleobase.T).foreach(out.writeObject) + Seq(MyClassTag.IntTag, MyClassTag.UnitTag).foreach(out.writeObject) val read = use(ByteArrayInputStream(buf.toByteArray)) val in = use(ObjectInputStream(read)) - val Seq(Red @ _, Green @ _, Blue @ _) = (1 to 3).map(_ => in.readObject) - assert(Red eq JColor.Red, JColor.Red) - assert(Green eq SColor.Green, SColor.Green) - assert(Blue eq SColorTagged.Blue, SColorTagged.Blue) + + val Seq(Red @ _, Green @ _, Blue @ _, Indigo @ _) = (1 to 4).map(_ => in.readObject) + Red aliases JColor.Red + Green aliases SColor.Green + Blue aliases SColorTagged.Blue + Indigo aliases SColorTagged.Indigo + + val Seq(A @ _, C @ _, G @ _, T @ _) = (1 to 4).map(_ => in.readObject) + A aliases Nucleobase.A + C aliases Nucleobase.C + G aliases Nucleobase.G + T aliases Nucleobase.T + + val Seq(IntTag @ _, UnitTag @ _) = (1 to 2).map(_ => in.readObject) + IntTag aliases MyClassTag.IntTag + UnitTag aliases MyClassTag.UnitTag + }).get diff --git a/tests/semanticdb/metac.expect b/tests/semanticdb/metac.expect index 675a0b1f1288..6bbfad8b94d9 100644 --- a/tests/semanticdb/metac.expect +++ b/tests/semanticdb/metac.expect @@ -641,7 +641,7 @@ Schema => SemanticDB v4 Uri => Enums.scala Text => empty Language => Scala -Symbols => 169 entries +Symbols => 183 entries Occurrences => 203 entries Symbols: @@ -651,6 +651,8 @@ _empty_/Enums.Coin#``(). => primary ctor _empty_/Enums.Coin#``().(value) => param value _empty_/Enums.Coin#value. => val method value _empty_/Enums.Coin. => final object Coin +_empty_/Enums.Coin.$fromOrdinal(). => method $fromOrdinal +_empty_/Enums.Coin.$fromOrdinal().(_$ordinal) => param _$ordinal _empty_/Enums.Coin.$values. => val method $values _empty_/Enums.Coin.Dime. => case val static enum method Dime _empty_/Enums.Coin.Dollar. => case val static enum method Dollar @@ -663,6 +665,8 @@ _empty_/Enums.Coin.values(). => method values _empty_/Enums.Colour# => abstract sealed enum class Colour _empty_/Enums.Colour#``(). => primary ctor _empty_/Enums.Colour. => final object Colour +_empty_/Enums.Colour.$fromOrdinal(). => method $fromOrdinal +_empty_/Enums.Colour.$fromOrdinal().(_$ordinal) => param _$ordinal _empty_/Enums.Colour.$new(). => method $new _empty_/Enums.Colour.$new().($name) => param $name _empty_/Enums.Colour.$new().(_$ordinal) => param _$ordinal @@ -676,6 +680,8 @@ _empty_/Enums.Colour.values(). => method values _empty_/Enums.Directions# => abstract sealed enum class Directions _empty_/Enums.Directions#``(). => primary ctor _empty_/Enums.Directions. => final object Directions +_empty_/Enums.Directions.$fromOrdinal(). => method $fromOrdinal +_empty_/Enums.Directions.$fromOrdinal().(_$ordinal) => param _$ordinal _empty_/Enums.Directions.$new(). => method $new _empty_/Enums.Directions.$new().($name) => param $name _empty_/Enums.Directions.$new().(_$ordinal) => param _$ordinal @@ -691,6 +697,8 @@ _empty_/Enums.Maybe# => abstract sealed enum class Maybe _empty_/Enums.Maybe#[A] => covariant typeparam A _empty_/Enums.Maybe#``(). => primary ctor _empty_/Enums.Maybe. => final object Maybe +_empty_/Enums.Maybe.$fromOrdinal(). => method $fromOrdinal +_empty_/Enums.Maybe.$fromOrdinal().(_$ordinal) => param _$ordinal _empty_/Enums.Maybe.$values. => val method $values _empty_/Enums.Maybe.Just# => final case enum class Just _empty_/Enums.Maybe.Just#[A] => typeparam A @@ -743,6 +751,8 @@ _empty_/Enums.Planet.values(). => method values _empty_/Enums.Suits# => abstract sealed enum class Suits _empty_/Enums.Suits#``(). => primary ctor _empty_/Enums.Suits. => final object Suits +_empty_/Enums.Suits.$fromOrdinal(). => method $fromOrdinal +_empty_/Enums.Suits.$fromOrdinal().(_$ordinal) => param _$ordinal _empty_/Enums.Suits.$new(). => method $new _empty_/Enums.Suits.$new().($name) => param $name _empty_/Enums.Suits.$new().(_$ordinal) => param _$ordinal @@ -763,6 +773,8 @@ _empty_/Enums.Tag# => abstract sealed enum class Tag _empty_/Enums.Tag#[A] => typeparam A _empty_/Enums.Tag#``(). => primary ctor _empty_/Enums.Tag. => final object Tag +_empty_/Enums.Tag.$fromOrdinal(). => method $fromOrdinal +_empty_/Enums.Tag.$fromOrdinal().(_$ordinal) => param _$ordinal _empty_/Enums.Tag.$values. => val method $values _empty_/Enums.Tag.BooleanTag. => case val static enum method BooleanTag _empty_/Enums.Tag.IntTag. => case val static enum method IntTag @@ -772,6 +784,8 @@ _empty_/Enums.Tag.values(). => method values _empty_/Enums.WeekDays# => abstract sealed enum class WeekDays _empty_/Enums.WeekDays#``(). => primary ctor _empty_/Enums.WeekDays. => final object WeekDays +_empty_/Enums.WeekDays.$fromOrdinal(). => method $fromOrdinal +_empty_/Enums.WeekDays.$fromOrdinal().(_$ordinal) => param _$ordinal _empty_/Enums.WeekDays.$new(). => method $new _empty_/Enums.WeekDays.$new().($name) => param $name _empty_/Enums.WeekDays.$new().(_$ordinal) => param _$ordinal