@@ -1102,21 +1102,31 @@ trait Checking {
1102
1102
report.error(ClassCannotExtendEnum (cls, firstParent), cdef.sourcePos)
1103
1103
}
1104
1104
1105
- /** Check that the firstParent derives from the declaring enum class.
1105
+ /** Check that the firstParent derives from the declaring enum class, if not, adds it as a parent after emitting an
1106
+ * error.
1106
1107
*/
1107
- def checkEnumParent (cls : Symbol , firstParent : Symbol )(using Context ): Boolean = {
1108
+ def checkEnumParent (cls : Symbol , firstParent : Symbol )(using Context ): Unit =
1109
+
1110
+ extension (sym : Symbol ) def typeRefApplied (using Context ): Type =
1111
+ typeRef.appliedTo(typeParams.map(_.info.loBound))
1112
+
1113
+ def ensureParentDerivesFrom (enumCase : Symbol )(using Context ) =
1114
+ val enumCls = enumCase.owner.linkedClass
1115
+ if ! firstParent.derivesFrom(enumCls) then
1116
+ report.error(i " enum case does not extend its enum $enumCls" , enumCase.sourcePos)
1117
+ cls.info match
1118
+ case info : ClassInfo =>
1119
+ cls.info = info.derivedClassInfo(classParents = enumCls.typeRefApplied :: info.classParents)
1120
+ case _ =>
1121
+
1108
1122
val enumCase =
1109
1123
if cls.isAllOf(EnumCase ) then cls
1110
1124
else if cls.isAnonymousClass && cls.owner.isAllOf(EnumCase ) then cls.owner
1111
1125
else NoSymbol
1112
- def parentDerivesFrom (enumCls : Symbol )(using Context ) =
1113
- if ! firstParent.derivesFrom(enumCls) then
1114
- report.error(i " enum case does not extend its enum $enumCls" , enumCase.sourcePos)
1115
- false
1116
- else
1117
- true
1118
- ! enumCase.exists || parentDerivesFrom(enumCase.owner.linkedClass)
1119
- }
1126
+ if enumCase.exists then
1127
+ ensureParentDerivesFrom(enumCase)
1128
+
1129
+ end checkEnumParent
1120
1130
1121
1131
/** Check that all references coming from enum cases in an enum companion object
1122
1132
* are legal.
@@ -1211,7 +1221,7 @@ trait Checking {
1211
1221
1212
1222
trait ReChecking extends Checking {
1213
1223
import tpd ._
1214
- override def checkEnumParent (cls : Symbol , firstParent : Symbol )(using Context ): Boolean = true
1224
+ override def checkEnumParent (cls : Symbol , firstParent : Symbol )(using Context ): Unit = ()
1215
1225
override def checkEnum (cdef : untpd.TypeDef , cls : Symbol , firstParent : Symbol )(using Context ): Unit = ()
1216
1226
override def checkRefsLegal (tree : tpd.Tree , badOwner : Symbol , allowed : (Name , Symbol ) => Boolean , where : String )(using Context ): Unit = ()
1217
1227
override def checkFullyAppliedType (tree : Tree )(using Context ): Unit = ()
0 commit comments