Skip to content

Commit b527f64

Browse files
committed
check enum parents before typing body
1 parent 2fa3d79 commit b527f64

File tree

6 files changed

+28
-30
lines changed

6 files changed

+28
-30
lines changed

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ object DesugarEnums {
7878
Select(Ident(nme.DOLLAR_VALUES), name.toTermName)
7979

8080
private def registerCall(using Context): Tree =
81-
val asRaw = TypeApply(Select(This(EmptyTypeIdent), nme.asInstanceOf_), rawRef(enumClass.typeRef) :: Nil) // safe to cast due to refchecks
82-
Apply(valuesDot("register"), asRaw :: Nil)
81+
Apply(valuesDot("register"), This(EmptyTypeIdent) :: Nil)
8382

8483
/** The following lists of definitions for an enum type E:
8584
*

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ class Definitions {
655655

656656
@tu lazy val EnumValueSerializationProxyClass: ClassSymbol = requiredClass("scala.runtime.EnumValueSerializationProxy")
657657
@tu lazy val EnumValueSerializationProxyConstructor: TermSymbol =
658-
EnumValueSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(TypeBounds.empty)))
658+
EnumValueSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(TypeBounds.empty), IntType))
659659

660660
@tu lazy val ProductClass: ClassSymbol = requiredClass("scala.Product")
661661
@tu lazy val Product_canEqual : Symbol = ProductClass.requiredMethod(nme.canEqual_)

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,6 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
294294
cpy.Inlined(tree)(callTrace, transformSub(bindings), transform(expansion)(using inlineContext(call)))
295295
case templ: Template =>
296296
withNoCheckNews(templ.parents.flatMap(newPart)) {
297-
Checking.checkEnumParentOK(templ.symbol.owner)
298297
forwardParamAccessors(templ)
299298
synthMbr.addSyntheticMembers(
300299
superAcc.wrapTemplate(templ)(

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -630,16 +630,10 @@ object Checking {
630630
}
631631
}
632632

633-
/** Check that an enum case extends its enum class */
634-
def checkEnumParentOK(cls: Symbol)(using Context): Unit =
635-
val enumCase =
636-
if cls.isAllOf(EnumCase) then cls
637-
else if cls.isAnonymousClass && cls.owner.isAllOf(EnumCase) then cls.owner
638-
else NoSymbol
639-
if enumCase.exists then
640-
val enumCls = enumCase.owner.linkedClass
641-
if !cls.info.parents.exists(_.typeSymbol == enumCls) then
642-
report.error(i"enum case does not extend its enum $enumCls", enumCase.sourcePos)
633+
def optEnumCase(cls: Symbol)(using Context): Symbol =
634+
(if cls.isAllOf(EnumCase) then cls
635+
else if cls.isAnonymousClass && cls.owner.isAllOf(EnumCase) then cls.owner
636+
else NoSymbol).filter(s => s.exists && s.owner.linkedClass.derivesFrom(defn.EnumClass))
643637

644638
/** Check the inline override methods only use inline parameters if they override an inline parameter. */
645639
def checkInlineOverrideParameters(sym: Symbol)(using Context): Unit =
@@ -1092,8 +1086,9 @@ trait Checking {
10921086

10931087
/** 1. Check that all case classes that extend `scala.Enum` are `enum` cases
10941088
* 2. Check that case class `enum` cases do not extend java.lang.Enum.
1089+
* 3. Check that the firstParent derives from the declaring enum class.
10951090
*/
1096-
def checkEnum(cdef: untpd.TypeDef, cls: Symbol, firstParent: Symbol)(using Context): Unit = {
1091+
def checkEnum(cdef: untpd.TypeDef, cls: Symbol, enumCase: Symbol, firstParent: Symbol)(using Context): Boolean = {
10971092
def isEnumAnonCls =
10981093
cls.isAnonymousClass &&
10991094
cls.owner.isTerm &&
@@ -1110,6 +1105,12 @@ trait Checking {
11101105
// Unlike firstParent.derivesFrom(defn.EnumClass), this test allows inheriting from `Enum` by hand;
11111106
// see enum-List-control.scala.
11121107
report.error(ClassCannotExtendEnum(cls, firstParent), cdef.sourcePos)
1108+
val enumCls = enumCase.owner.linkedClass
1109+
if !firstParent.derivesFrom(enumCls) then
1110+
report.error(i"enum case does not extend its enum $enumCls", enumCase.sourcePos)
1111+
false
1112+
else
1113+
true
11131114
}
11141115

11151116
/** Check that all references coming from enum cases in an enum companion object
@@ -1205,7 +1206,7 @@ trait Checking {
12051206

12061207
trait ReChecking extends Checking {
12071208
import tpd._
1208-
override def checkEnum(cdef: untpd.TypeDef, cls: Symbol, firstParent: Symbol)(using Context): Unit = ()
1209+
override def checkEnum(cdef: untpd.TypeDef, cls: Symbol, enumCase: Symbol, firstParent: Symbol)(using Context): Boolean = true
12091210
override def checkRefsLegal(tree: tpd.Tree, badOwner: Symbol, allowed: (Name, Symbol) => Boolean, where: String)(using Context): Unit = ()
12101211
override def checkFullyAppliedType(tree: Tree)(using Context): Unit = ()
12111212
override def checkEnumCaseRefsLegal(cdef: TypeDef, enumCtx: Context)(using Context): Unit = ()

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2071,24 +2071,27 @@ class Typer extends Namer
20712071
val parentsWithClass = ensureFirstTreeIsClass(parents.mapconserve(typedParent).filterConserve(!_.isEmpty), cdef.nameSpan)
20722072
val parents1 = ensureConstrCall(cls, parentsWithClass)(using superCtx)
20732073

2074+
var forceEmptyBody: Boolean = false
2075+
val enumCaseDef = optEnumCase(cls)
2076+
if enumCaseDef.exists then
2077+
val firstParent = parents1.head.tpe.dealias.typeSymbol
2078+
forceEmptyBody = !checkEnum(cdef, cls, enumCaseDef, firstParent) // don't bother looking inside the template
2079+
20742080
val self1 = typed(self)(using ctx.outer).asInstanceOf[ValDef] // outer context where class members are not visible
20752081
if (self1.tpt.tpe.isError || classExistsOnSelf(cls.unforcedDecls, self1))
20762082
// fail fast to avoid typing the body with an error type
20772083
cdef.withType(UnspecifiedErrorType)
20782084
else {
20792085
val dummy = localDummy(cls, impl)
2080-
val body1 = addAccessorDefs(cls,
2081-
typedStats(impl.body, dummy)(using ctx.inClassContext(self1.symbol))._1)
2086+
val body1 =
2087+
if forceEmptyBody then Nil
2088+
else addAccessorDefs(cls, typedStats(impl.body, dummy)(using ctx.inClassContext(self1.symbol))._1)
20822089

20832090
checkNoDoubleDeclaration(cls)
20842091
val impl1 = cpy.Template(impl)(constr1, parents1, Nil, self1, body1)
20852092
.withType(dummy.termRef)
20862093
if (!cls.isOneOf(AbstractOrTrait) && !ctx.isAfterTyper)
20872094
checkRealizableBounds(cls, cdef.sourcePos.withSpan(cdef.nameSpan))
2088-
if (cls.derivesFrom(defn.EnumClass)) {
2089-
val firstParent = parents1.head.tpe.dealias.typeSymbol
2090-
checkEnum(cdef, cls, firstParent)
2091-
}
20922095
val cdef1 = assignType(cpy.TypeDef(cdef)(name, impl1), cls)
20932096

20942097
val reportDynamicInheritance =

library/src/scala/runtime/EnumValueSerializationProxy.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import java.security.PrivilegedActionException;
66
import java.security.PrivilegedExceptionAction;
77

8-
/** A serialization proxy for singleton enum values */
8+
/** A serialization proxy for singleton enum values, based on `scala.runtime.ModuleSerializationProxy` */
99
public final class EnumValueSerializationProxy implements Serializable {
1010
private static final long serialVersionUID = 1L;
1111
private final Class<?> enumClass;
@@ -17,17 +17,13 @@ protected Object[] computeValue(Class<?> type) {
1717
return AccessController.doPrivileged((PrivilegedExceptionAction<Object[]>) () ->
1818
(Object[])type.getMethod("values").invoke(null));
1919
} catch (PrivilegedActionException e) {
20-
return rethrowRuntime(e.getCause());
20+
Throwable cause = e.getCause();
21+
if (cause instanceof RuntimeException) throw (RuntimeException) cause;
22+
else throw new RuntimeException(cause);
2123
}
2224
}
2325
};
2426

25-
private static <T> T rethrowRuntime(Throwable e) {
26-
Throwable cause = e.getCause();
27-
if (cause instanceof RuntimeException) throw (RuntimeException) cause;
28-
else throw new RuntimeException(cause);
29-
}
30-
3127
public EnumValueSerializationProxy(Class<?> enumClass, int ordinal) {
3228
this.enumClass = enumClass;
3329
this.ordinal = ordinal;

0 commit comments

Comments
 (0)