From 814a6d15e80bea735d17e1e02d3a125f1a9fba08 Mon Sep 17 00:00:00 2001 From: Ingar Abrahamsen Date: Thu, 3 May 2018 14:39:47 +0200 Subject: [PATCH] add readResolve to objects with serializable interface If the object implements the serializable interface and not includes the readResolve method then we add it. Since it's an object we reference it to the objects singleton instance. Resolves issue #4440 --- .../dotty/tools/dotc/core/Definitions.scala | 2 + .../dotc/transform/SyntheticMethods.scala | 38 +++++++++++++------ tests/run/serialize.scala | 14 +++++++ 3 files changed, 43 insertions(+), 11 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index ef5d0d1f7022..86597087586d 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -551,6 +551,8 @@ class Definitions { case _ => false }).symbol.asTerm lazy val JavaSerializableClass = ctx.requiredClass("java.io.Serializable") + def readResolve(cls: ClassSymbol, flags: FlagSet) = enterMethod(cls, nme.readResolve, MethodType(Nil, AnyRefType), flags) + lazy val ComparableClass = ctx.requiredClass("java.lang.Comparable") lazy val SystemClass = ctx.requiredClass("java.lang.System") diff --git a/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala b/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala index 86dcbbed2196..83962dc71de6 100644 --- a/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala +++ b/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala @@ -2,18 +2,13 @@ package dotty.tools.dotc package transform import core._ -import Symbols._, Types._, Contexts._, Names._, StdNames._, Constants._, SymUtils._ -import scala.collection.{ mutable, immutable } +import Symbols._, Types._, Contexts._, StdNames._, Constants._, SymUtils._ import Flags._ -import MegaPhase._ import DenotTransformers._ -import ast.Trees._ -import ast.untpd import Decorators._ import NameOps._ import Annotations.Annotation import ValueClasses.isDerivedValueClass -import scala.collection.mutable.ListBuffer import scala.language.postfixOps /** Synthetic method implementations for case classes, case objects, @@ -57,7 +52,7 @@ class SyntheticMethods(thisPhase: DenotTransformer) { def caseModuleSymbols(implicit ctx: Context) = { initSymbols; myCaseModuleSymbols } /** The synthetic methods of the case or value class `clazz`. */ - def syntheticMethods(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = { + def syntheticMethods(clazz: ClassSymbol, isSerializableObject: Boolean)(implicit ctx: Context): List[Tree] = { val clazzType = clazz.appliedRef lazy val accessors = if (isDerivedValueClass(clazz)) clazz.paramAccessors.take(1) // Tail parameters can only be `erased` @@ -261,12 +256,33 @@ class SyntheticMethods(thisPhase: DenotTransformer) { */ def canEqualBody(that: Tree): Tree = that.isInstance(AnnotatedType(clazzType, Annotation(defn.UncheckedAnnot))) - symbolsToSynthesize flatMap syntheticDefIfMissing + val methods = symbolsToSynthesize flatMap syntheticDefIfMissing + + def createReadResolveMethod(implicit ctx: Context): Tree = { + ctx.log(s"adding readResolve to $clazz at ${ctx.phase}") + val readResolve = defn.readResolve(clazz, Private | Synthetic) + DefDef(readResolve, _ => ref(clazz.sourceModule)).withPos(ctx.owner.pos.focus) + } + + if (isSerializableObject) + createReadResolveMethod :: methods + else + methods } - def addSyntheticMethods(impl: Template)(implicit ctx: Context) = - if (ctx.owner.is(Case) || isDerivedValueClass(ctx.owner)) - cpy.Template(impl)(body = impl.body ++ syntheticMethods(ctx.owner.asClass)) + def addSyntheticMethods(impl: Template)(implicit ctx: Context) = { + val isSerializableObject = + (ctx.owner.is(Module) + && ctx.owner.isStatic + && ctx.owner.derivesFrom(defn.JavaSerializableClass) + && !ctx.owner.asClass.membersNamed(nme.readResolve) + .filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false)) + .exists) + + if (ctx.owner.is(Case) || isDerivedValueClass(ctx.owner) || isSerializableObject) + cpy.Template(impl)(body = impl.body ++ syntheticMethods(ctx.owner.asClass, isSerializableObject)) else impl + } + } diff --git a/tests/run/serialize.scala b/tests/run/serialize.scala index 3c97892ae5ec..8061e2a0c6c8 100644 --- a/tests/run/serialize.scala +++ b/tests/run/serialize.scala @@ -8,9 +8,23 @@ object Test { in.readObject.asInstanceOf[T] } + object Foo extends Serializable {} + + object Baz extends Serializable { + private def readResolve(): AnyRef = { + this + } + } + def main(args: Array[String]): Unit = { val x: PartialFunction[Int, Int] = { case x => x + 1 } val adder = serializeDeserialize(x) assert(adder(1) == 2) + + val foo = serializeDeserialize(Foo) + assert(foo eq Foo) + + val baz = serializeDeserialize(Baz) + assert(baz ne Baz) } }