Skip to content

Commit 2bcd4f6

Browse files
authored
Merge pull request #9018 from dotty-staging/fix-#9011b
Fix #9011: Make single enum values inherit from Product
2 parents 484e3c6 + 74cc1b1 commit 2bcd4f6

22 files changed

+200
-42
lines changed

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ object DesugarEnums {
124124

125125
/** A creation method for a value of enum type `E`, which is defined as follows:
126126
*
127-
* private def $new(_$ordinal: Int, $name: String) = new E {
127+
* private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue {
128128
* def $ordinal = $tag
129129
* override def toString = $name
130130
* $values.register(this)
@@ -135,7 +135,7 @@ object DesugarEnums {
135135
val toStringDef = toStringMeth(Ident(nme.nameDollar))
136136
val creator = New(Template(
137137
constr = emptyConstructor,
138-
parents = enumClassRef :: Nil,
138+
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
139139
derived = Nil,
140140
self = EmptyValDef,
141141
body = List(ordinalDef, toStringDef) ++ registerCall
@@ -286,7 +286,9 @@ object DesugarEnums {
286286
val (tag, scaffolding) = nextOrdinal(CaseKind.Object)
287287
val ordinalDef = ordinalMethLit(tag)
288288
val toStringDef = toStringMethLit(name.toString)
289-
val impl1 = cpy.Template(impl)(body = List(ordinalDef, toStringDef) ++ registerCall)
289+
val impl1 = cpy.Template(impl)(
290+
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
291+
body = List(ordinalDef, toStringDef) ++ registerCall)
290292
.withAttachment(ExtendsSingletonMirror, ())
291293
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
292294
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
450450
def rootDot(name: Name)(implicit src: SourceFile): Select = Select(Ident(nme.ROOTPKG), name)
451451
def scalaDot(name: Name)(implicit src: SourceFile): Select = Select(rootDot(nme.scala), name)
452452
def scalaAnnotationDot(name: Name)(using SourceFile): Select = Select(scalaDot(nme.annotation), name)
453+
def scalaRuntimeDot(name: Name)(using SourceFile): Select = Select(scalaDot(nme.runtime), name)
453454
def scalaUnit(implicit src: SourceFile): Select = scalaDot(tpnme.Unit)
454455
def scalaAny(implicit src: SourceFile): Select = scalaDot(tpnme.Any)
455456
def javaDotLangDot(name: Name)(implicit src: SourceFile): Select = Select(Select(Ident(nme.java), nme.lang), name)

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ trait ConstraintHandling[AbstractContext] {
300300
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
301301
* 2. If `inst` is a union type, approximate the union type from above by an intersection
302302
* of all common base types, provided the result is a subtype of `bound`.
303+
* 3. (currently not enabled, see #9028) If `inst` is an intersection with some restricted base types, drop
304+
* the restricted base types from the intersection, provided the result is a subtype of `bound`.
303305
*
304306
* Don't do these widenings if `bound` is a subtype of `scala.Singleton`.
305307
* Also, if the result of these widenings is a TypeRef to a module class,
@@ -309,26 +311,48 @@ trait ConstraintHandling[AbstractContext] {
309311
* At this point we also drop the @Repeated annotation to avoid inferring type arguments with it,
310312
* as those could leak the annotation to users (see run/inferred-repeated-result).
311313
*/
312-
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type = {
313-
def widenOr(tp: Type) = {
314+
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type =
315+
316+
def isRestricted(tp: Type) = tp.typeSymbol == defn.EnumValueClass // for now, to be generalized later
317+
318+
def dropRestricted(tp: Type): Type = tp.dealias match
319+
case tpd @ AndType(tp1, tp2) =>
320+
if isRestricted(tp1) then tp2
321+
else if isRestricted(tp2) then tp1
322+
else
323+
val tpw = tpd.derivedAndType(dropRestricted(tp1), dropRestricted(tp2))
324+
if tpw ne tpd then tpw else tp
325+
case _ =>
326+
tp
327+
328+
def widenRestricted(tp: Type) =
329+
val tpw = dropRestricted(tp)
330+
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
331+
332+
def widenOr(tp: Type) =
314333
val tpw = tp.widenUnion
315334
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
316-
}
317-
def widenSingle(tp: Type) = {
335+
336+
def widenSingle(tp: Type) =
318337
val tpw = tp.widenSingletons
319338
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
320-
}
339+
321340
def isSingleton(tp: Type): Boolean = tp match
322341
case WildcardType(optBounds) => optBounds.exists && isSingleton(optBounds.bounds.hi)
323342
case _ => isSubTypeWhenFrozen(tp, defn.SingletonType)
343+
324344
val wideInst =
325-
if isSingleton(bound) then inst else widenOr(widenSingle(inst))
345+
if isSingleton(bound) then inst
346+
else /*widenRestricted*/(widenOr(widenSingle(inst)))
347+
// widenRestricted is currently not called since it's special cased in `dropEnumValue`
348+
// in `Namer`. It's left in here in case we want to generalize the scheme to other
349+
// "protected inheritance" classes.
326350
wideInst match
327351
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
328352
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)
329353
case _ =>
330354
wideInst.dropRepeatedAnnot
331-
}
355+
end widenInferred
332356

333357
/** The instance type of `param` in the current constraint (which contains `param`).
334358
* If `fromBelow` is true, the instance type is the lub of the parameter's

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ class Definitions {
639639
@tu lazy val EnumClass: ClassSymbol = ctx.requiredClass("scala.Enum")
640640
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)
641641

642+
@tu lazy val EnumValueClass: ClassSymbol = ctx.requiredClass("scala.runtime.EnumValue")
642643
@tu lazy val EnumValuesClass: ClassSymbol = ctx.requiredClass("scala.runtime.EnumValues")
643644
@tu lazy val ProductClass: ClassSymbol = ctx.requiredClass("scala.Product")
644645
@tu lazy val Product_canEqual : Symbol = ProductClass.requiredMethod(nme.canEqual_)

compiler/src/dotty/tools/dotc/core/Flags.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ object Flags {
438438
* TODO: Should check that FromStartFlags do not change in completion
439439
*/
440440
val FromStartFlags: FlagSet = commonFlags(
441-
Module, Package, Deferred, Method, Case,
441+
Module, Package, Deferred, Method, Case, Enum,
442442
HigherKinded, Param, ParamAccessor,
443443
Scala2SpecialFlags, MutableOrOpen, Opaque, Touched, JavaStatic,
444444
OuterOrCovariant, LabelOrContravariant, CaseAccessor,

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,15 +357,14 @@ object StdNames {
357357
val CAP: N = "CAP"
358358
val Constant: N = "Constant"
359359
val ConstantType: N = "ConstantType"
360-
val doubleHash: N = "doubleHash"
360+
val EnumValue: N = "EnumValue"
361361
val ExistentialTypeTree: N = "ExistentialTypeTree"
362362
val Flag : N = "Flag"
363363
val floatHash: N = "floatHash"
364364
val Ident: N = "Ident"
365365
val Import: N = "Import"
366366
val Literal: N = "Literal"
367367
val LiteralAnnotArg: N = "LiteralAnnotArg"
368-
val longHash: N = "longHash"
369368
val MatchCase: N = "MatchCase"
370369
val MirroredElemTypes: N = "MirroredElemTypes"
371370
val MirroredElemLabels: N = "MirroredElemLabels"
@@ -443,6 +442,7 @@ object StdNames {
443442
val delayedInitArg: N = "delayedInit$body"
444443
val derived: N = "derived"
445444
val derives: N = "derives"
445+
val doubleHash: N = "doubleHash"
446446
val drop: N = "drop"
447447
val dynamics: N = "dynamics"
448448
val elem: N = "elem"
@@ -505,6 +505,7 @@ object StdNames {
505505
val language: N = "language"
506506
val length: N = "length"
507507
val lengthCompare: N = "lengthCompare"
508+
val longHash: N = "longHash"
508509
val macroThis : N = "_this"
509510
val macroContext : N = "c"
510511
val main: N = "main"

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1439,6 +1439,19 @@ class Namer { typer: Typer =>
14391439
// println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}")
14401440
def isInlineVal = sym.isOneOf(FinalOrInline, butNot = Method | Mutable)
14411441

1442+
def isEnumValue(tp: Type) = tp.typeSymbol == defn.EnumValueClass
1443+
1444+
// Drop EnumValue parents from inferred types of enum constants
1445+
def dropEnumValue(tp: Type): Type = tp.dealias match
1446+
case tpd @ AndType(tp1, tp2) =>
1447+
if isEnumValue(tp1) then tp2
1448+
else if isEnumValue(tp2) then tp1
1449+
else
1450+
val tpw = tpd.derivedAndType(dropEnumValue(tp1), dropEnumValue(tp2))
1451+
if tpw ne tpd then tpw else tp
1452+
case _ =>
1453+
tp
1454+
14421455
// Widen rhs type and eliminate `|' but keep ConstantTypes if
14431456
// definition is inline (i.e. final in Scala2) and keep module singleton types
14441457
// instead of widening to the underlying module class types.
@@ -1447,7 +1460,9 @@ class Namer { typer: Typer =>
14471460
def widenRhs(tp: Type): Type =
14481461
tp.widenTermRefExpr.simplified match
14491462
case ctp: ConstantType if isInlineVal => ctp
1450-
case tp => ctx.typeComparer.widenInferred(tp, rhsProto)
1463+
case tp =>
1464+
val tp1 = ctx.typeComparer.widenInferred(tp, rhsProto)
1465+
if sym.is(Enum) then dropEnumValue(tp1) else tp1
14511466

14521467
// Replace aliases to Unit by Unit itself. If we leave the alias in
14531468
// it would be erased to BoxedUnit.

docs/docs/reference/enums/desugarEnums.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ map into `case class`es or `val`s.
126126
where `n` is the ordinal number of the case in the companion object,
127127
starting from 0. The statement `$values.register(this)` registers the value
128128
as one of the `values` of the enumeration (see below). `$values` is a
129-
compiler-defined private value in the companion object.
129+
compiler-defined private value in the companion object. The anonymous class also
130+
implements the abstract `Product` methods that it inherits from `Enum`.
131+
130132

131133
It is an error if a value case refers to a type parameter of the enclosing `enum`
132134
in a type argument of `<parents>`.
@@ -178,6 +180,7 @@ Companion objects of enumerations that contain at least one simple case define i
178180
}
179181
```
180182

183+
The anonymous class also implements the abstract `Product` methods that it inherits from `Enum`.
181184
The `$ordinal` method above is used to generate the `ordinal` method if the enum does not extend a `java.lang.Enum` (as Scala enums do not extend `java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as `java.lang.Enum` defines it.
182185

183186
### Scopes for Enum Cases

docs/docs/reference/enums/enums.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ If you want to use the Scala-defined enums as Java enums, you can do so by exten
9595
enum Color extends java.lang.Enum[Color] { case Red, Green, Blue }
9696
```
9797

98-
The type parameter comes from the Java enum [definition](https://docs.oracle.com/javase/8/docs/api/index.html?java/lang/Enum.html) and should be the same as the type of the enum.
98+
The type parameter comes from the Java enum [definition](https://docs.oracle.com/javase/8/docs/api/index.html?java/lang/Enum.html) and should be the same as the type of the enum.
9999
There is no need to provide constructor arguments (as defined in the Java API docs) to `java.lang.Enum` when extending it – the compiler will generate them automatically.
100100

101101
After defining `Color` like that, you can use it like you would a Java enum:
@@ -116,7 +116,7 @@ This trait defines a single public method, `ordinal`:
116116
package scala
117117

118118
/** A base trait of all enum classes */
119-
trait Enum {
119+
trait Enum extends Product with Serializable {
120120

121121
/** A number uniquely identifying a case of an enum */
122122
def ordinal: Int
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package scala
2+
3+
/** A base trait of all enum classes */
4+
trait Enum extends Product, Serializable:
5+
6+
/** A number uniquely identifying a case of an enum */
7+
def ordinal: Int
8+
protected def $ordinal: Int
9+
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
package scala
22

33
/** A base trait of all enum classes */
4-
trait Enum {
4+
trait Enum:
55

66
/** A number uniquely identifying a case of an enum */
77
def ordinal: Int
88
protected def $ordinal: Int
9-
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package scala.runtime
2+
3+
trait EnumValue extends Product, Serializable:
4+
override def canEqual(that: Any) = this eq that.asInstanceOf[AnyRef]
5+
override def productArity: Int = 0
6+
override def productPrefix: String = toString
7+
override def productElement(n: Int): Any =
8+
throw IndexOutOfBoundsException(n.toString)
9+
override def productElementName(n: Int): String =
10+
throw IndexOutOfBoundsException(n.toString)

tests/fuzzy/b82054893e0db44e31ae82d696c19c1fbc7be55c.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
object main {
1+
object _main {
22
def i0 = {
33
class i1 {
44
private[i0] var i2: _ > 0 private def i3: List[Int]

tests/neg/enumvalues.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
enum Color:
2+
case Red, Green, Blue
3+
4+
enum Option[+T]:
5+
case None extends Option[Nothing]
6+
7+
import scala.runtime.EnumValue
8+
9+
@main def Test(c: Boolean) =
10+
// Verify that enum constants don't leak the scala.runtime.EnumValue trait
11+
val x: EnumValue = if c then Color.Red else Color.Blue // error // error
12+
val y: EnumValue = Color.Green // error
13+
val z: EnumValue = Option.None // error
14+
15+

tests/pos/enum-List-control.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@ abstract sealed class List[T] extends Enum
22
object List {
33
final class Cons[T](x: T, xs: List[T]) extends List[T] {
44
def $ordinal = 0
5+
def canEqual(that: Any): Boolean = that.isInstanceOf[Cons[_]]
6+
def productArity: Int = 2
7+
def productElement(n: Int): Any = n match
8+
case 0 => x
9+
case 1 => xs
510
}
611
object Cons {
712
def apply[T](x: T, xs: List[T]): List[T] = new Cons(x, xs)
813
}
9-
final class Nil[T]() extends List[T] {
14+
final class Nil[T]() extends List[T], runtime.EnumValue {
1015
def $ordinal = 1
1116
}
1217
object Nil {

tests/pos/localmodules.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package test;
22

3-
object main {
3+
object _main {
44

55
class a {
66

tests/pos/t0002.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
object main {
1+
object _main {
22
def main(args: Array[String]) = {
33
val b = true;
44
while (b == true) { }

tests/pos/t789.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
object main { // don't do this at home
1+
object _main { // don't do this at home
22

33
trait Impl
44

tests/pos/typealiases.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ trait Test[T] {
1111
def check2[S](xs: Array[S], c: Check[S]) = c(xs)
1212
}
1313

14-
object main extends Test[Int] {
14+
object _main extends Test[Int] {
1515
val pair1 = (1,1)
1616

1717
implicit def topair(x: Int): Tuple2[Int, Int] = (x,x)

tests/run/i9011.scala

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
enum Opt[+T] derives Eq:
2+
case Sm(t: T)
3+
case Nn
4+
5+
import scala.deriving._
6+
import scala.compiletime.{erasedValue, summonInline}
7+
8+
trait Eq[T] {
9+
def eqv(x: T, y: T): Boolean
10+
}
11+
12+
object Eq {
13+
given Eq[Int] {
14+
def eqv(x: Int, y: Int) = x == y
15+
}
16+
17+
inline def summonAll[T <: Tuple]: List[Eq[_]] = inline erasedValue[T] match {
18+
case _: Unit => Nil
19+
case _: (t *: ts) => summonInline[Eq[t]] :: summonAll[ts]
20+
}
21+
22+
def check(elem: Eq[_])(x: Any, y: Any): Boolean =
23+
elem.asInstanceOf[Eq[Any]].eqv(x, y)
24+
25+
def iterator[T](p: T) = p.asInstanceOf[Product].productIterator
26+
27+
def eqSum[T](s: Mirror.SumOf[T], elems: List[Eq[_]]): Eq[T] =
28+
new Eq[T] {
29+
def eqv(x: T, y: T): Boolean = {
30+
val ordx = s.ordinal(x)
31+
(s.ordinal(y) == ordx) && check(elems(ordx))(x, y)
32+
}
33+
}
34+
35+
def eqProduct[T](p: Mirror.ProductOf[T], elems: List[Eq[_]]): Eq[T] =
36+
new Eq[T] {
37+
def eqv(x: T, y: T): Boolean =
38+
iterator(x).zip(iterator(y)).zip(elems.iterator).forall {
39+
case ((x, y), elem) => check(elem)(x, y)
40+
}
41+
}
42+
43+
inline given derived[T](using m: Mirror.Of[T]) as Eq[T] = {
44+
val elemInstances = summonAll[m.MirroredElemTypes]
45+
inline m match {
46+
case s: Mirror.SumOf[T] => eqSum(s, elemInstances)
47+
case p: Mirror.ProductOf[T] => eqProduct(p, elemInstances)
48+
}
49+
}
50+
}
51+
52+
object Test extends App {
53+
import Opt._
54+
val eqoi = summon[Eq[Opt[Int]]]
55+
assert(eqoi.eqv(Sm(23), Sm(23)))
56+
assert(eqoi.eqv(Nn, Nn))
57+
}

0 commit comments

Comments
 (0)