From d89e6be8e2c8e3b99b58747f104d5b3b188bdfb8 Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Thu, 3 May 2018 14:39:47 +0200 Subject: [PATCH 1/7] Fix #4440: Do not serialize the content of static objects In #4450 and Scala 2.12, readResolve is used to make sure deserializing an object returns the singleton instance of the object, but this doesn't prevent the fields of the objects from being serialized in the first place even though they're not used. Scala 2.13 switched to using writeReplace to completely bypass serialization of the object in https://github.com/scala/scala/pull/7297. This commit adapts this to Dotty. Co-Authored-By: Ingar Abrahamsen Co-Authored-By: Jason Zaugg --- .../dotty/tools/dotc/core/Definitions.scala | 7 +++ .../src/dotty/tools/dotc/core/StdNames.scala | 2 +- .../tools/dotc/core/SymDenotations.scala | 4 ++ .../dotc/transform/SyntheticMethods.scala | 47 ++++++++++++---- .../runtime/ModuleSerializationProxy.java | 54 +++++++++++++++++++ tests/run/literals.decompiled | 4 +- ...ule-serialization-proxy-class-unload.scala | 27 ++++++++++ tests/run/serialize.scala | 14 +++++ tests/run/tasty-extractors-2.check | 2 +- tests/run/valueclasses-pavlov.decompiled | 4 +- 10 files changed, 150 insertions(+), 15 deletions(-) create mode 100644 library/src/scala/runtime/ModuleSerializationProxy.java create mode 100644 tests/run/module-serialization-proxy-class-unload.scala diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 6f63373f064c..e3956b3c8e72 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -575,7 +575,9 @@ class Definitions { case List(pt) => (pt isRef StringClass) case _ => false }).symbol.asTerm + lazy val JavaSerializableClass: ClassSymbol = ctx.requiredClass("java.io.Serializable") + lazy val ComparableClass: ClassSymbol = ctx.requiredClass("java.lang.Comparable") lazy val SystemClass: ClassSymbol = ctx.requiredClass("java.lang.System") @@ -656,6 +658,11 @@ class Definitions { lazy val Product_productPrefixR: TermRef = ProductClass.requiredMethodRef(nme.productPrefix) def Product_productPrefix(implicit ctx: Context): Symbol = Product_productPrefixR.symbol + lazy val ModuleSerializationProxyType: TypeRef = ctx.requiredClassRef("scala.runtime.ModuleSerializationProxy") + def ModuleSerializationProxyClass(implicit ctx: Context): ClassSymbol = ModuleSerializationProxyType.symbol.asClass + lazy val ModuleSerializationProxyConstructor: TermSymbol = + ModuleSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(WildcardType))) + lazy val GenericType: TypeRef = ctx.requiredClassRef("scala.reflect.Generic") def GenericClass(implicit ctx: Context): ClassSymbol = GenericType.symbol.asClass lazy val ShapeType: TypeRef = ctx.requiredClassRef("scala.compiletime.Shape") diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 339d811f26d0..4dfbf09cf827 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -495,7 +495,6 @@ object StdNames { val productIterator: N = "productIterator" val productPrefix: N = "productPrefix" val raw_ : N = "raw" - val readResolve: N = "readResolve" val reflect: N = "reflect" val reflectiveSelectable: N = "reflectiveSelectable" val reify : N = "reify" @@ -558,6 +557,7 @@ object StdNames { val withFilterIfRefutable: N = "withFilterIfRefutable$" val WorksheetWrapper: N = "WorksheetWrapper" val wrap: N = "wrap" + val writeReplace: N = "writeReplace" val zero: N = "zero" val zip: N = "zip" val nothingRuntimeClass: N = "scala.runtime.Nothing$" diff --git a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala index 34a6870a4574..3f50bf37cdf8 100644 --- a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala @@ -680,6 +680,10 @@ object SymDenotations { */ def derivesFrom(base: Symbol)(implicit ctx: Context): Boolean = false + /** Is this symbol a class that extends `java.io.Serializable` ? */ + def isSerializable(implicit ctx: Context): Boolean = + isClass && derivesFrom(defn.JavaSerializableClass) + /** Is this symbol a class that extends `AnyVal`? */ final def isValueClass(implicit ctx: Context): Boolean = { val di = initial diff --git a/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala b/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala index 04b68a932ac0..0a35ad44d3fe 100644 --- a/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala +++ b/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala @@ -23,8 +23,9 @@ import ValueClasses.isDerivedValueClass * def productArity: Int * def productPrefix: String * - * Special handling: - * protected def readResolve(): AnyRef + * Add to serializable static objects, unless an implementation + * already exists: + * private def writeReplace(): AnyRef * * Selectively added to value classes, unless a non-default * implementation already exists: @@ -50,8 +51,10 @@ class SyntheticMethods(thisPhase: DenotTransformer) { def caseSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseSymbols } def caseModuleSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseModuleSymbols } - /** The synthetic methods of the case or value class `clazz`. */ - def syntheticMethods(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = { + /** If this is a case or value class, return the appropriate additional methods, + * otherwise return nothing. + */ + def caseAndValueMethods(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = { val clazzType = clazz.appliedRef lazy val accessors = if (isDerivedValueClass(clazz)) clazz.paramAccessors.take(1) // Tail parameters can only be `erased` @@ -255,12 +258,38 @@ class SyntheticMethods(thisPhase: DenotTransformer) { */ def canEqualBody(that: Tree): Tree = that.isInstance(AnnotatedType(clazzType, Annotation(defn.UncheckedAnnot))) - symbolsToSynthesize flatMap syntheticDefIfMissing + symbolsToSynthesize.flatMap(syntheticDefIfMissing) } - def addSyntheticMethods(impl: Template)(implicit ctx: Context): Template = - if (ctx.owner.is(Case) || isDerivedValueClass(ctx.owner)) - cpy.Template(impl)(body = impl.body ++ syntheticMethods(ctx.owner.asClass)) + /** If this is a serializable static object `Foo`, add the method: + * + * private def writeReplace(): AnyRef = + * new scala.runtime.ModuleSerializationProxy(classOf[Foo$]) + * + * unless an implementation already exists, otherwise do nothing. + */ + def serializableObjectMethod(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = { + def hasWriteReplace: Boolean = + clazz.membersNamed(nme.writeReplace) + .filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false)) + .exists + if (clazz.is(Module) && clazz.isStatic && clazz.isSerializable && !hasWriteReplace) { + val writeReplace = ctx.newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic, + MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm + List( + DefDef(writeReplace, + _ => New(defn.ModuleSerializationProxyType, + defn.ModuleSerializationProxyConstructor, + List(Literal(Constant(clazz.typeRef))))) + .withSpan(ctx.owner.span.focus)) + } else - impl + Nil + } + + def addSyntheticMethods(impl: Template)(implicit ctx: Context): Template = { + val clazz = ctx.owner.asClass + cpy.Template(impl)(body = serializableObjectMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body) + } + } diff --git a/library/src/scala/runtime/ModuleSerializationProxy.java b/library/src/scala/runtime/ModuleSerializationProxy.java new file mode 100644 index 000000000000..3f82a65c5ba5 --- /dev/null +++ b/library/src/scala/runtime/ModuleSerializationProxy.java @@ -0,0 +1,54 @@ +// Copied from https://github.com/scala/scala/blob/2.13.x/src/library/scala/runtime/ModuleSerializationProxy.java +// TODO: Remove this file once we switch to the Scala 2.13 stdlib since it already contains it. + +/* + * Scala (https://www.scala-lang.org) + * + * Copyright EPFL and Lightbend, Inc. + * + * Licensed under Apache License 2.0 + * (http://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package scala.runtime; + +import java.io.Serializable; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.HashSet; +import java.util.Set; + +/** A serialization proxy for singleton objects */ +public final class ModuleSerializationProxy implements Serializable { + private static final long serialVersionUID = 1L; + private final Class moduleClass; + private static final ClassValue instances = new ClassValue() { + @Override + protected Object computeValue(Class type) { + try { + return AccessController.doPrivileged((PrivilegedExceptionAction) () -> type.getField("MODULE$").get(null)); + } catch (PrivilegedActionException e) { + return rethrowRuntime(e.getCause()); + } + } + }; + + private static Object rethrowRuntime(Throwable e) { + Throwable cause = e.getCause(); + if (cause instanceof RuntimeException) throw (RuntimeException) cause; + else throw new RuntimeException(cause); + } + + public ModuleSerializationProxy(Class moduleClass) { + this.moduleClass = moduleClass; + } + + @SuppressWarnings("unused") + private Object readResolve() { + return instances.get(moduleClass); + } +} diff --git a/tests/run/literals.decompiled b/tests/run/literals.decompiled index 67279692e269..a82862e56607 100644 --- a/tests/run/literals.decompiled +++ b/tests/run/literals.decompiled @@ -2,7 +2,6 @@ object Test { def αρετη: java.lang.String = "alpha rho epsilon tau eta" case class GGG(i: scala.Int) { - def αα(that: Test.GGG): scala.Int = GGG.this.i.+(that.i) override def hashCode(): scala.Int = { var acc: scala.Int = 767242539 acc = scala.runtime.Statics.mix(acc, GGG.this.i) @@ -24,6 +23,7 @@ object Test { case _ => throw new java.lang.IndexOutOfBoundsException(n.toString()) } + def αα(that: Test.GGG): scala.Int = GGG.this.i.+(that.i) } object GGG extends scala.Function1[scala.Int, Test.GGG] def check_success[a](name: scala.Predef.String, closure: => a, expected: a): scala.Unit = { @@ -95,4 +95,4 @@ object Test { val ggg: scala.Int = Test.GGG.apply(1).αα(Test.GGG.apply(2)) Test.check_success[scala.Int]("ggg == 3", ggg, 3) } -} +} \ No newline at end of file diff --git a/tests/run/module-serialization-proxy-class-unload.scala b/tests/run/module-serialization-proxy-class-unload.scala new file mode 100644 index 000000000000..ec47e92298be --- /dev/null +++ b/tests/run/module-serialization-proxy-class-unload.scala @@ -0,0 +1,27 @@ +import java.io.File + +object Module { + val data = new Array[Byte](32 * 1024 * 1024) +} + +object Test { + private val readResolve = classOf[scala.runtime.ModuleSerializationProxy].getDeclaredMethod("readResolve") + readResolve.setAccessible(true) + + val testClassesDir = new File(Module.getClass.getClassLoader.getResource("Module.class").toURI).getParentFile + def main(args: Array[String]): Unit = { + for (i <- 1 to 256) { + // This would "java.lang.OutOfMemoryError: Java heap space" if ModuleSerializationProxy + // prevented class unloading. + deserializeDynamicLoadedClass() + } + } + + def deserializeDynamicLoadedClass(): Unit = { + val loader = new java.net.URLClassLoader(Array(testClassesDir.toURI.toURL), ClassLoader.getSystemClassLoader) + val moduleClass = loader.loadClass("Module$") + assert(moduleClass ne Module.getClass) + val result = readResolve.invoke(new scala.runtime.ModuleSerializationProxy(moduleClass)) + assert(result.getClass == moduleClass) + } +} diff --git a/tests/run/serialize.scala b/tests/run/serialize.scala index 3c97892ae5ec..8e6ded17277c 100644 --- a/tests/run/serialize.scala +++ b/tests/run/serialize.scala @@ -8,9 +8,23 @@ object Test { in.readObject.asInstanceOf[T] } + object Foo extends Serializable {} + + object Baz extends Serializable { + private def writeReplace(): AnyRef = { + this + } + } + def main(args: Array[String]): Unit = { val x: PartialFunction[Int, Int] = { case x => x + 1 } val adder = serializeDeserialize(x) assert(adder(1) == 2) + + val foo = serializeDeserialize(Foo) + assert(foo eq Foo) + + val baz = serializeDeserialize(Baz) + assert(baz ne Baz) } } diff --git a/tests/run/tasty-extractors-2.check b/tests/run/tasty-extractors-2.check index 1e403bb71054..d566b57c87dc 100644 --- a/tests/run/tasty-extractors-2.check +++ b/tests/run/tasty-extractors-2.check @@ -49,7 +49,7 @@ Type.SymRef(IsClassSymbol(), Type.ThisType(Type.SymRef(IsPackageSymb Term.Inlined(None, Nil, Term.Block(List(ClassDef("Foo", DefDef("", Nil, List(Nil), TypeTree.Inferred(), None), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), ""), Nil)), Nil, None, List(DefDef("a", Nil, Nil, TypeTree.Inferred(), Some(Term.Literal(Constant.Int(0))))))), Term.Literal(Constant.Unit()))) Type.SymRef(IsClassSymbol(), Type.ThisType(Type.SymRef(IsPackageSymbol(), NoPrefix()))) -Term.Inlined(None, Nil, Term.Block(List(ClassDef("Foo", DefDef("", Nil, List(Nil), TypeTree.Inferred(), None), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), ""), Nil), TypeTree.Select(Term.Select(Term.Ident("_root_"), "scala"), "Product"), TypeTree.Select(Term.Select(Term.Ident("_root_"), "scala"), "Serializable")), Nil, None, List(DefDef("productElementName", Nil, List(List(ValDef("x$1", TypeTree.Select(Term.Select(Term.Ident("_root_"), "scala"), "Int"), None))), TypeTree.Select(Term.Select(Term.Ident("java"), "lang"), "String"), Some(Term.Match(Term.Ident("x$1"), List(CaseDef(Pattern.Value(Term.Ident("_")), None, Term.Apply(Term.Ident("throw"), List(Term.Apply(Term.Select(Term.New(TypeTree.Select(Term.Select(Term.Ident("java"), "lang"), "IndexOutOfBoundsException")), ""), List(Term.Apply(Term.Select(Term.Select(Term.Select(Term.Ident("java"), "lang"), "String"), "valueOf"), List(Term.Ident("x$1")))))))))))), DefDef("copy", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), ""), Nil))), DefDef("hashCode", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Literal(Constant.Int(394005536)))), DefDef("equals", Nil, List(List(ValDef("x$0", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.Apply(Term.Select(Term.Apply(Term.Select(Term.This(Some(Id("Foo"))), "eq"), List(Term.TypeApply(Term.Select(Term.Ident("x$0"), "asInstanceOf"), List(TypeTree.Inferred())))), "||"), List(Term.Match(Term.Ident("x$0"), List(CaseDef(Pattern.Bind("x$0", Pattern.TypeTest(TypeTree.Inferred())), None, Term.Literal(Constant.Boolean(true))), CaseDef(Pattern.Value(Term.Ident("_")), None, Term.Literal(Constant.Boolean(false))))))))), DefDef("toString", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Apply(Term.Ident("_toString"), List(Term.This(Some(Id("Foo"))))))), DefDef("canEqual", Nil, List(List(ValDef("that", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.TypeApply(Term.Select(Term.Ident("that"), "isInstanceOf"), List(TypeTree.Inferred())))), DefDef("productArity", Nil, Nil, TypeTree.Inferred(), Some(Term.Literal(Constant.Int(0)))), DefDef("productPrefix", Nil, Nil, TypeTree.Inferred(), Some(Term.Literal(Constant.String("Foo")))), DefDef("productElement", Nil, List(List(ValDef("n", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.Match(Term.Ident("n"), List(CaseDef(Pattern.Value(Term.Ident("_")), None, Term.Apply(Term.Ident("throw"), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), ""), List(Term.Apply(Term.Select(Term.Ident("n"), "toString"), Nil)))))))))))), ValDef("Foo", TypeTree.Ident("Foo$"), Some(Term.Apply(Term.Select(Term.New(TypeTree.Ident("Foo$")), ""), Nil))), ClassDef("Foo$", DefDef("", Nil, List(Nil), TypeTree.Inferred(), None), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), ""), Nil), TypeTree.Applied(TypeTree.Inferred(), List(TypeTree.Inferred()))), Nil, Some(ValDef("_", TypeTree.Singleton(Term.Ident("Foo")), None)), List(DefDef("apply", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), ""), Nil))), DefDef("unapply", Nil, List(List(ValDef("x$1", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.Literal(Constant.Boolean(true))))))), Term.Literal(Constant.Unit()))) +Term.Inlined(None, Nil, Term.Block(List(ClassDef("Foo", DefDef("", Nil, List(Nil), TypeTree.Inferred(), None), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), ""), Nil), TypeTree.Select(Term.Select(Term.Ident("_root_"), "scala"), "Product"), TypeTree.Select(Term.Select(Term.Ident("_root_"), "scala"), "Serializable")), Nil, None, List(DefDef("hashCode", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Literal(Constant.Int(394005536)))), DefDef("equals", Nil, List(List(ValDef("x$0", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.Apply(Term.Select(Term.Apply(Term.Select(Term.This(Some(Id("Foo"))), "eq"), List(Term.TypeApply(Term.Select(Term.Ident("x$0"), "asInstanceOf"), List(TypeTree.Inferred())))), "||"), List(Term.Match(Term.Ident("x$0"), List(CaseDef(Pattern.Bind("x$0", Pattern.TypeTest(TypeTree.Inferred())), None, Term.Literal(Constant.Boolean(true))), CaseDef(Pattern.Value(Term.Ident("_")), None, Term.Literal(Constant.Boolean(false))))))))), DefDef("toString", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Apply(Term.Ident("_toString"), List(Term.This(Some(Id("Foo"))))))), DefDef("canEqual", Nil, List(List(ValDef("that", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.TypeApply(Term.Select(Term.Ident("that"), "isInstanceOf"), List(TypeTree.Inferred())))), DefDef("productArity", Nil, Nil, TypeTree.Inferred(), Some(Term.Literal(Constant.Int(0)))), DefDef("productPrefix", Nil, Nil, TypeTree.Inferred(), Some(Term.Literal(Constant.String("Foo")))), DefDef("productElement", Nil, List(List(ValDef("n", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.Match(Term.Ident("n"), List(CaseDef(Pattern.Value(Term.Ident("_")), None, Term.Apply(Term.Ident("throw"), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), ""), List(Term.Apply(Term.Select(Term.Ident("n"), "toString"), Nil)))))))))), DefDef("productElementName", Nil, List(List(ValDef("x$1", TypeTree.Select(Term.Select(Term.Ident("_root_"), "scala"), "Int"), None))), TypeTree.Select(Term.Select(Term.Ident("java"), "lang"), "String"), Some(Term.Match(Term.Ident("x$1"), List(CaseDef(Pattern.Value(Term.Ident("_")), None, Term.Apply(Term.Ident("throw"), List(Term.Apply(Term.Select(Term.New(TypeTree.Select(Term.Select(Term.Ident("java"), "lang"), "IndexOutOfBoundsException")), ""), List(Term.Apply(Term.Select(Term.Select(Term.Select(Term.Ident("java"), "lang"), "String"), "valueOf"), List(Term.Ident("x$1")))))))))))), DefDef("copy", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), ""), Nil))))), ValDef("Foo", TypeTree.Ident("Foo$"), Some(Term.Apply(Term.Select(Term.New(TypeTree.Ident("Foo$")), ""), Nil))), ClassDef("Foo$", DefDef("", Nil, List(Nil), TypeTree.Inferred(), None), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), ""), Nil), TypeTree.Applied(TypeTree.Inferred(), List(TypeTree.Inferred()))), Nil, Some(ValDef("_", TypeTree.Singleton(Term.Ident("Foo")), None)), List(DefDef("apply", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), ""), Nil))), DefDef("unapply", Nil, List(List(ValDef("x$1", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.Literal(Constant.Boolean(true))))))), Term.Literal(Constant.Unit()))) Type.SymRef(IsClassSymbol(), Type.ThisType(Type.SymRef(IsPackageSymbol(), NoPrefix()))) Term.Inlined(None, Nil, Term.Block(List(ClassDef("Foo1", DefDef("", Nil, List(List(ValDef("a", TypeTree.Ident("Int"), None))), TypeTree.Inferred(), None), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), ""), Nil)), Nil, None, List(ValDef("a", TypeTree.Inferred(), None)))), Term.Literal(Constant.Unit()))) diff --git a/tests/run/valueclasses-pavlov.decompiled b/tests/run/valueclasses-pavlov.decompiled index 14579d1773fb..9f8453e1303c 100644 --- a/tests/run/valueclasses-pavlov.decompiled +++ b/tests/run/valueclasses-pavlov.decompiled @@ -11,8 +11,6 @@ final class Box1(val value: scala.Predef.String) extends scala.AnyVal { object Box1 extends scala.AnyRef /** Decompiled from out/runTestFromTasty/run/valueclasses-pavlov/Box2.tasty */ final class Box2(val value: scala.Predef.String) extends scala.AnyVal with Foo { - def box1(x: Box1): scala.Predef.String = "box1: ok" - def box2(x: Box2): scala.Predef.String = "box2: ok" override def hashCode(): scala.Int = Box2.this.value.hashCode() override def equals(x$0: scala.Any): scala.Boolean = x$0 match { case x$0: Box2 @scala.unchecked => @@ -20,6 +18,8 @@ final class Box2(val value: scala.Predef.String) extends scala.AnyVal with Foo { case _ => false } + def box1(x: Box1): scala.Predef.String = "box1: ok" + def box2(x: Box2): scala.Predef.String = "box2: ok" } object Box2 extends scala.AnyRef /** Decompiled from out/runTestFromTasty/run/valueclasses-pavlov/C.tasty */ From 9b3cccdb1cf6ddaf590cb767a501f74fb4148310 Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Sat, 26 Jan 2019 14:10:19 +0100 Subject: [PATCH 2/7] Allow `classOf[Foo.type]` if `Foo` is an object And use this in SyntheticMethods instead of the impossible-to-write-in-user-code `classOf[Foo$]`. This is necessary if we want decompilation to produce valid source code for objects. Note that the decompilation printer is currently wrong here and will print `classOf[Foo]` instead of `classOf[Foo.type]`. --- compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala | 2 +- .../src/dotty/tools/dotc/transform/SyntheticMethods.scala | 4 ++-- compiler/src/dotty/tools/dotc/typer/Applications.scala | 3 ++- tests/neg/classOf.scala | 2 +- tests/run/classof-object.scala | 5 +++++ 5 files changed, 11 insertions(+), 5 deletions(-) create mode 100644 tests/run/classof-object.scala diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index e45ab1c4a486..e3f22bd696f8 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -472,7 +472,7 @@ class PlainPrinter(_ctx: Context) extends Printer { def toText(const: Constant): Text = const.tag match { case StringTag => stringText("\"" + escapedString(const.value.toString) + "\"") - case ClazzTag => "classOf[" ~ toText(const.typeValue.classSymbol) ~ "]" + case ClazzTag => "classOf[" ~ toText(const.typeValue) ~ "]" case CharTag => literalText(s"'${escapedChar(const.charValue)}'") case LongTag => literalText(const.longValue.toString + "L") case EnumTag => literalText(const.symbolValue.name.toString) diff --git a/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala b/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala index 0a35ad44d3fe..937ce128e2e1 100644 --- a/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala +++ b/compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala @@ -264,7 +264,7 @@ class SyntheticMethods(thisPhase: DenotTransformer) { /** If this is a serializable static object `Foo`, add the method: * * private def writeReplace(): AnyRef = - * new scala.runtime.ModuleSerializationProxy(classOf[Foo$]) + * new scala.runtime.ModuleSerializationProxy(classOf[Foo.type]) * * unless an implementation already exists, otherwise do nothing. */ @@ -280,7 +280,7 @@ class SyntheticMethods(thisPhase: DenotTransformer) { DefDef(writeReplace, _ => New(defn.ModuleSerializationProxyType, defn.ModuleSerializationProxyConstructor, - List(Literal(Constant(clazz.typeRef))))) + List(Literal(Constant(clazz.sourceModule.termRef))))) .withSpan(ctx.owner.span.focus)) } else diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 0a3af9b047f0..e9fb3adb8e87 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -907,7 +907,8 @@ trait Applications extends Compatibility { self: Typer with Dynamic => if (typedArgs.length <= pt.paramInfos.length && !isNamed) if (typedFn.symbol == defn.Predef_classOf && typedArgs.nonEmpty) { val arg = typedArgs.head - checkClassType(arg.tpe, arg.sourcePos, traitReq = false, stablePrefixReq = false) + if (!arg.symbol.is(Module)) // Allow `classOf[Foo.type]` if `Foo` is an object + checkClassType(arg.tpe, arg.sourcePos, traitReq = false, stablePrefixReq = false) } case _ => } diff --git a/tests/neg/classOf.scala b/tests/neg/classOf.scala index e13cf71c43a2..ece08f086bb3 100644 --- a/tests/neg/classOf.scala +++ b/tests/neg/classOf.scala @@ -5,7 +5,7 @@ object Test { def f1[T] = classOf[T] // error def f2[T <: String] = classOf[T] // error - val x = classOf[Test.type] // error + val x = classOf[Test.type] // ok val y = classOf[C { type I = String }] // error val z = classOf[A] // ok } diff --git a/tests/run/classof-object.scala b/tests/run/classof-object.scala new file mode 100644 index 000000000000..3ad3b92242f3 --- /dev/null +++ b/tests/run/classof-object.scala @@ -0,0 +1,5 @@ +object Test { + def main(args: Array[String]): Unit = { + assert(classOf[Test.type] == Test.getClass) + } +} From 84ba7123d974c912733969798e77d76b7daf593b Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Sat, 26 Jan 2019 15:44:57 +0100 Subject: [PATCH 3/7] Do not include writeReplace when decompiling --- library/src/scala/tasty/reflect/Printers.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/library/src/scala/tasty/reflect/Printers.scala b/library/src/scala/tasty/reflect/Printers.scala index e9c10e6462c2..716c6101d414 100644 --- a/library/src/scala/tasty/reflect/Printers.scala +++ b/library/src/scala/tasty/reflect/Printers.scala @@ -652,6 +652,7 @@ trait Printers n == "copy" || n.matches("copy\\$default\\$[1-9][0-9]*") || // default parameters for the copy method n.matches("_[1-9][0-9]*") || // Getters from Product + n == "writeReplace" || // Case class serialization n == "productElementName" case _ => false }) From 72ee85565de5df1e893ffc6cf3138b4724f25f4e Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Mon, 28 Jan 2019 08:23:20 +0100 Subject: [PATCH 4/7] Add regression test for classOf on an object. --- tests/run/classof-object.decompiled | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 tests/run/classof-object.decompiled diff --git a/tests/run/classof-object.decompiled b/tests/run/classof-object.decompiled new file mode 100644 index 000000000000..cd10ecdba5f7 --- /dev/null +++ b/tests/run/classof-object.decompiled @@ -0,0 +1,4 @@ +/** Decompiled from out/runTestFromTasty/run/classof-object/Test.tasty */ +object Test { + def main(args: scala.Array[scala.Predef.String]): scala.Unit = if (scala.Predef.classOf[Test.type].==(Test.getClass()).unary_!) dotty.DottyPredef.assertFail() else () +} From 140f7e1871ba67d3ef9d63f4f61a34109338a133 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Mon, 28 Jan 2019 08:25:44 +0100 Subject: [PATCH 5/7] Update method name --- library/src/scala/tasty/reflect/Printers.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/src/scala/tasty/reflect/Printers.scala b/library/src/scala/tasty/reflect/Printers.scala index 716c6101d414..f3119af2e46c 100644 --- a/library/src/scala/tasty/reflect/Printers.scala +++ b/library/src/scala/tasty/reflect/Printers.scala @@ -643,7 +643,7 @@ trait Printers def keepDefinition(d: Definition): Boolean = { val flags = d.symbol.flags - def isCaseClassUnOverridableMethod: Boolean = { + def isUndecompilableCaseClassMethod: Boolean = { // Currently the compiler does not allow overriding some of the methods generated for case classes d.symbol.flags.is(Flags.Synthetic) && (d match { @@ -658,7 +658,7 @@ trait Printers }) } def isInnerModuleObject = d.symbol.flags.is(Flags.Lazy) && d.symbol.flags.is(Flags.Object) - !flags.is(Flags.Param) && !flags.is(Flags.ParamAccessor) && !flags.is(Flags.FieldAccessor) && !isCaseClassUnOverridableMethod && !isInnerModuleObject + !flags.is(Flags.Param) && !flags.is(Flags.ParamAccessor) && !flags.is(Flags.FieldAccessor) && !isUndecompilableCaseClassMethod && !isInnerModuleObject } val stats1 = stats.collect { case IsDefinition(stat) if keepDefinition(stat) => stat From 191442c46c05c2e2c7ae99ff20fece2de4cb0e25 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Mon, 28 Jan 2019 10:33:18 +0100 Subject: [PATCH 6/7] Do not decompile synthetic writeReplace in objects --- library/src/scala/tasty/reflect/Printers.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/src/scala/tasty/reflect/Printers.scala b/library/src/scala/tasty/reflect/Printers.scala index f3119af2e46c..9b3e8a72da2d 100644 --- a/library/src/scala/tasty/reflect/Printers.scala +++ b/library/src/scala/tasty/reflect/Printers.scala @@ -647,7 +647,7 @@ trait Printers // Currently the compiler does not allow overriding some of the methods generated for case classes d.symbol.flags.is(Flags.Synthetic) && (d match { - case DefDef("apply" | "unapply", _, _, _, _) if d.symbol.owner.flags.is(Flags.Object) => true + case DefDef("apply" | "unapply" | "writeReplace", _, _, _, _) if d.symbol.owner.flags.is(Flags.Object) => true case DefDef(n, _, _, _, _) if d.symbol.owner.flags.is(Flags.Case) => n == "copy" || n.matches("copy\\$default\\$[1-9][0-9]*") || // default parameters for the copy method From 95900de5a0ee920cae8ddc74eab9737c56e713bb Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Fri, 1 Feb 2019 10:56:47 +0100 Subject: [PATCH 7/7] Remove unnecessary check --- library/src/scala/tasty/reflect/Printers.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/library/src/scala/tasty/reflect/Printers.scala b/library/src/scala/tasty/reflect/Printers.scala index 9b3e8a72da2d..000eff2a8ddc 100644 --- a/library/src/scala/tasty/reflect/Printers.scala +++ b/library/src/scala/tasty/reflect/Printers.scala @@ -652,7 +652,6 @@ trait Printers n == "copy" || n.matches("copy\\$default\\$[1-9][0-9]*") || // default parameters for the copy method n.matches("_[1-9][0-9]*") || // Getters from Product - n == "writeReplace" || // Case class serialization n == "productElementName" case _ => false })