diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index ef5d0d1f7022..86597087586d 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -551,6 +551,8 @@ class Definitions { case _ => false }).symbol.asTerm lazy val JavaSerializableClass = ctx.requiredClass("java.io.Serializable") + def readResolve(cls: ClassSymbol, flags: FlagSet) = enterMethod(cls, nme.readResolve, MethodType(Nil, AnyRefType), flags) + lazy val ComparableClass = ctx.requiredClass("java.lang.Comparable") lazy val SystemClass = ctx.requiredClass("java.lang.System") diff --git a/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala b/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala index 86dcbbed2196..83962dc71de6 100644 --- a/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala +++ b/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala @@ -2,18 +2,13 @@ package dotty.tools.dotc package transform import core._ -import Symbols._, Types._, Contexts._, Names._, StdNames._, Constants._, SymUtils._ -import scala.collection.{ mutable, immutable } +import Symbols._, Types._, Contexts._, StdNames._, Constants._, SymUtils._ import Flags._ -import MegaPhase._ import DenotTransformers._ -import ast.Trees._ -import ast.untpd import Decorators._ import NameOps._ import Annotations.Annotation import ValueClasses.isDerivedValueClass -import scala.collection.mutable.ListBuffer import scala.language.postfixOps /** Synthetic method implementations for case classes, case objects, @@ -57,7 +52,7 @@ class SyntheticMethods(thisPhase: DenotTransformer) { def caseModuleSymbols(implicit ctx: Context) = { initSymbols; myCaseModuleSymbols } /** The synthetic methods of the case or value class `clazz`. */ - def syntheticMethods(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = { + def syntheticMethods(clazz: ClassSymbol, isSerializableObject: Boolean)(implicit ctx: Context): List[Tree] = { val clazzType = clazz.appliedRef lazy val accessors = if (isDerivedValueClass(clazz)) clazz.paramAccessors.take(1) // Tail parameters can only be `erased` @@ -261,12 +256,33 @@ class SyntheticMethods(thisPhase: DenotTransformer) { */ def canEqualBody(that: Tree): Tree = that.isInstance(AnnotatedType(clazzType, Annotation(defn.UncheckedAnnot))) - symbolsToSynthesize flatMap syntheticDefIfMissing + val methods = symbolsToSynthesize flatMap syntheticDefIfMissing + + def createReadResolveMethod(implicit ctx: Context): Tree = { + ctx.log(s"adding readResolve to $clazz at ${ctx.phase}") + val readResolve = defn.readResolve(clazz, Private | Synthetic) + DefDef(readResolve, _ => ref(clazz.sourceModule)).withPos(ctx.owner.pos.focus) + } + + if (isSerializableObject) + createReadResolveMethod :: methods + else + methods } - def addSyntheticMethods(impl: Template)(implicit ctx: Context) = - if (ctx.owner.is(Case) || isDerivedValueClass(ctx.owner)) - cpy.Template(impl)(body = impl.body ++ syntheticMethods(ctx.owner.asClass)) + def addSyntheticMethods(impl: Template)(implicit ctx: Context) = { + val isSerializableObject = + (ctx.owner.is(Module) + && ctx.owner.isStatic + && ctx.owner.derivesFrom(defn.JavaSerializableClass) + && !ctx.owner.asClass.membersNamed(nme.readResolve) + .filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false)) + .exists) + + if (ctx.owner.is(Case) || isDerivedValueClass(ctx.owner) || isSerializableObject) + cpy.Template(impl)(body = impl.body ++ syntheticMethods(ctx.owner.asClass, isSerializableObject)) else impl + } + } diff --git a/tests/run/serialize.scala b/tests/run/serialize.scala index 3c97892ae5ec..8061e2a0c6c8 100644 --- a/tests/run/serialize.scala +++ b/tests/run/serialize.scala @@ -8,9 +8,23 @@ object Test { in.readObject.asInstanceOf[T] } + object Foo extends Serializable {} + + object Baz extends Serializable { + private def readResolve(): AnyRef = { + this + } + } + def main(args: Array[String]): Unit = { val x: PartialFunction[Int, Int] = { case x => x + 1 } val adder = serializeDeserialize(x) assert(adder(1) == 2) + + val foo = serializeDeserialize(Foo) + assert(foo eq Foo) + + val baz = serializeDeserialize(Baz) + assert(baz ne Baz) } }