Skip to content

Commit d89e6be

Browse files
smarteringarabrretronym
committed
Fix #4440: Do not serialize the content of static objects
In #4450 and Scala 2.12, readResolve is used to make sure deserializing an object returns the singleton instance of the object, but this doesn't prevent the fields of the objects from being serialized in the first place even though they're not used. Scala 2.13 switched to using writeReplace to completely bypass serialization of the object in scala/scala#7297. This commit adapts this to Dotty. Co-Authored-By: Ingar Abrahamsen <[email protected]> Co-Authored-By: Jason Zaugg <[email protected]>
1 parent f9e58c9 commit d89e6be

10 files changed

+150
-15
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+7
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,9 @@ class Definitions {
575575
case List(pt) => (pt isRef StringClass)
576576
case _ => false
577577
}).symbol.asTerm
578+
578579
lazy val JavaSerializableClass: ClassSymbol = ctx.requiredClass("java.io.Serializable")
580+
579581
lazy val ComparableClass: ClassSymbol = ctx.requiredClass("java.lang.Comparable")
580582

581583
lazy val SystemClass: ClassSymbol = ctx.requiredClass("java.lang.System")
@@ -656,6 +658,11 @@ class Definitions {
656658
lazy val Product_productPrefixR: TermRef = ProductClass.requiredMethodRef(nme.productPrefix)
657659
def Product_productPrefix(implicit ctx: Context): Symbol = Product_productPrefixR.symbol
658660

661+
lazy val ModuleSerializationProxyType: TypeRef = ctx.requiredClassRef("scala.runtime.ModuleSerializationProxy")
662+
def ModuleSerializationProxyClass(implicit ctx: Context): ClassSymbol = ModuleSerializationProxyType.symbol.asClass
663+
lazy val ModuleSerializationProxyConstructor: TermSymbol =
664+
ModuleSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(WildcardType)))
665+
659666
lazy val GenericType: TypeRef = ctx.requiredClassRef("scala.reflect.Generic")
660667
def GenericClass(implicit ctx: Context): ClassSymbol = GenericType.symbol.asClass
661668
lazy val ShapeType: TypeRef = ctx.requiredClassRef("scala.compiletime.Shape")

compiler/src/dotty/tools/dotc/core/StdNames.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,6 @@ object StdNames {
495495
val productIterator: N = "productIterator"
496496
val productPrefix: N = "productPrefix"
497497
val raw_ : N = "raw"
498-
val readResolve: N = "readResolve"
499498
val reflect: N = "reflect"
500499
val reflectiveSelectable: N = "reflectiveSelectable"
501500
val reify : N = "reify"
@@ -558,6 +557,7 @@ object StdNames {
558557
val withFilterIfRefutable: N = "withFilterIfRefutable$"
559558
val WorksheetWrapper: N = "WorksheetWrapper"
560559
val wrap: N = "wrap"
560+
val writeReplace: N = "writeReplace"
561561
val zero: N = "zero"
562562
val zip: N = "zip"
563563
val nothingRuntimeClass: N = "scala.runtime.Nothing$"

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

+4
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,10 @@ object SymDenotations {
680680
*/
681681
def derivesFrom(base: Symbol)(implicit ctx: Context): Boolean = false
682682

683+
/** Is this symbol a class that extends `java.io.Serializable` ? */
684+
def isSerializable(implicit ctx: Context): Boolean =
685+
isClass && derivesFrom(defn.JavaSerializableClass)
686+
683687
/** Is this symbol a class that extends `AnyVal`? */
684688
final def isValueClass(implicit ctx: Context): Boolean = {
685689
val di = initial

compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala

+38-9
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ import ValueClasses.isDerivedValueClass
2323
* def productArity: Int
2424
* def productPrefix: String
2525
*
26-
* Special handling:
27-
* protected def readResolve(): AnyRef
26+
* Add to serializable static objects, unless an implementation
27+
* already exists:
28+
* private def writeReplace(): AnyRef
2829
*
2930
* Selectively added to value classes, unless a non-default
3031
* implementation already exists:
@@ -50,8 +51,10 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
5051
def caseSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseSymbols }
5152
def caseModuleSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseModuleSymbols }
5253

53-
/** The synthetic methods of the case or value class `clazz`. */
54-
def syntheticMethods(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = {
54+
/** If this is a case or value class, return the appropriate additional methods,
55+
* otherwise return nothing.
56+
*/
57+
def caseAndValueMethods(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = {
5558
val clazzType = clazz.appliedRef
5659
lazy val accessors =
5760
if (isDerivedValueClass(clazz)) clazz.paramAccessors.take(1) // Tail parameters can only be `erased`
@@ -255,12 +258,38 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
255258
*/
256259
def canEqualBody(that: Tree): Tree = that.isInstance(AnnotatedType(clazzType, Annotation(defn.UncheckedAnnot)))
257260

258-
symbolsToSynthesize flatMap syntheticDefIfMissing
261+
symbolsToSynthesize.flatMap(syntheticDefIfMissing)
259262
}
260263

261-
def addSyntheticMethods(impl: Template)(implicit ctx: Context): Template =
262-
if (ctx.owner.is(Case) || isDerivedValueClass(ctx.owner))
263-
cpy.Template(impl)(body = impl.body ++ syntheticMethods(ctx.owner.asClass))
264+
/** If this is a serializable static object `Foo`, add the method:
265+
*
266+
* private def writeReplace(): AnyRef =
267+
* new scala.runtime.ModuleSerializationProxy(classOf[Foo$])
268+
*
269+
* unless an implementation already exists, otherwise do nothing.
270+
*/
271+
def serializableObjectMethod(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = {
272+
def hasWriteReplace: Boolean =
273+
clazz.membersNamed(nme.writeReplace)
274+
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
275+
.exists
276+
if (clazz.is(Module) && clazz.isStatic && clazz.isSerializable && !hasWriteReplace) {
277+
val writeReplace = ctx.newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
278+
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
279+
List(
280+
DefDef(writeReplace,
281+
_ => New(defn.ModuleSerializationProxyType,
282+
defn.ModuleSerializationProxyConstructor,
283+
List(Literal(Constant(clazz.typeRef)))))
284+
.withSpan(ctx.owner.span.focus))
285+
}
264286
else
265-
impl
287+
Nil
288+
}
289+
290+
def addSyntheticMethods(impl: Template)(implicit ctx: Context): Template = {
291+
val clazz = ctx.owner.asClass
292+
cpy.Template(impl)(body = serializableObjectMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body)
293+
}
294+
266295
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copied from https://github.com/scala/scala/blob/2.13.x/src/library/scala/runtime/ModuleSerializationProxy.java
2+
// TODO: Remove this file once we switch to the Scala 2.13 stdlib since it already contains it.
3+
4+
/*
5+
* Scala (https://www.scala-lang.org)
6+
*
7+
* Copyright EPFL and Lightbend, Inc.
8+
*
9+
* Licensed under Apache License 2.0
10+
* (http://www.apache.org/licenses/LICENSE-2.0).
11+
*
12+
* See the NOTICE file distributed with this work for
13+
* additional information regarding copyright ownership.
14+
*/
15+
16+
package scala.runtime;
17+
18+
import java.io.Serializable;
19+
import java.security.AccessController;
20+
import java.security.PrivilegedActionException;
21+
import java.security.PrivilegedExceptionAction;
22+
import java.util.HashSet;
23+
import java.util.Set;
24+
25+
/** A serialization proxy for singleton objects */
26+
public final class ModuleSerializationProxy implements Serializable {
27+
private static final long serialVersionUID = 1L;
28+
private final Class<?> moduleClass;
29+
private static final ClassValue<Object> instances = new ClassValue<Object>() {
30+
@Override
31+
protected Object computeValue(Class<?> type) {
32+
try {
33+
return AccessController.doPrivileged((PrivilegedExceptionAction<Object>) () -> type.getField("MODULE$").get(null));
34+
} catch (PrivilegedActionException e) {
35+
return rethrowRuntime(e.getCause());
36+
}
37+
}
38+
};
39+
40+
private static Object rethrowRuntime(Throwable e) {
41+
Throwable cause = e.getCause();
42+
if (cause instanceof RuntimeException) throw (RuntimeException) cause;
43+
else throw new RuntimeException(cause);
44+
}
45+
46+
public ModuleSerializationProxy(Class<?> moduleClass) {
47+
this.moduleClass = moduleClass;
48+
}
49+
50+
@SuppressWarnings("unused")
51+
private Object readResolve() {
52+
return instances.get(moduleClass);
53+
}
54+
}

tests/run/literals.decompiled

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
object Test {
33
def αρετη: java.lang.String = "alpha rho epsilon tau eta"
44
case class GGG(i: scala.Int) {
5-
def αα(that: Test.GGG): scala.Int = GGG.this.i.+(that.i)
65
override def hashCode(): scala.Int = {
76
var acc: scala.Int = 767242539
87
acc = scala.runtime.Statics.mix(acc, GGG.this.i)
@@ -24,6 +23,7 @@ object Test {
2423
case _ =>
2524
throw new java.lang.IndexOutOfBoundsException(n.toString())
2625
}
26+
def αα(that: Test.GGG): scala.Int = GGG.this.i.+(that.i)
2727
}
2828
object GGG extends scala.Function1[scala.Int, Test.GGG]
2929
def check_success[a](name: scala.Predef.String, closure: => a, expected: a): scala.Unit = {
@@ -95,4 +95,4 @@ object Test {
9595
val ggg: scala.Int = Test.GGG.apply(1).αα(Test.GGG.apply(2))
9696
Test.check_success[scala.Int]("ggg == 3", ggg, 3)
9797
}
98-
}
98+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import java.io.File
2+
3+
object Module {
4+
val data = new Array[Byte](32 * 1024 * 1024)
5+
}
6+
7+
object Test {
8+
private val readResolve = classOf[scala.runtime.ModuleSerializationProxy].getDeclaredMethod("readResolve")
9+
readResolve.setAccessible(true)
10+
11+
val testClassesDir = new File(Module.getClass.getClassLoader.getResource("Module.class").toURI).getParentFile
12+
def main(args: Array[String]): Unit = {
13+
for (i <- 1 to 256) {
14+
// This would "java.lang.OutOfMemoryError: Java heap space" if ModuleSerializationProxy
15+
// prevented class unloading.
16+
deserializeDynamicLoadedClass()
17+
}
18+
}
19+
20+
def deserializeDynamicLoadedClass(): Unit = {
21+
val loader = new java.net.URLClassLoader(Array(testClassesDir.toURI.toURL), ClassLoader.getSystemClassLoader)
22+
val moduleClass = loader.loadClass("Module$")
23+
assert(moduleClass ne Module.getClass)
24+
val result = readResolve.invoke(new scala.runtime.ModuleSerializationProxy(moduleClass))
25+
assert(result.getClass == moduleClass)
26+
}
27+
}

tests/run/serialize.scala

+14
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,23 @@ object Test {
88
in.readObject.asInstanceOf[T]
99
}
1010

11+
object Foo extends Serializable {}
12+
13+
object Baz extends Serializable {
14+
private def writeReplace(): AnyRef = {
15+
this
16+
}
17+
}
18+
1119
def main(args: Array[String]): Unit = {
1220
val x: PartialFunction[Int, Int] = { case x => x + 1 }
1321
val adder = serializeDeserialize(x)
1422
assert(adder(1) == 2)
23+
24+
val foo = serializeDeserialize(Foo)
25+
assert(foo eq Foo)
26+
27+
val baz = serializeDeserialize(Baz)
28+
assert(baz ne Baz)
1529
}
1630
}

tests/run/tasty-extractors-2.check

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ Type.SymRef(IsClassSymbol(<scala.Unit>), Type.ThisType(Type.SymRef(IsPackageSymb
4949
Term.Inlined(None, Nil, Term.Block(List(ClassDef("Foo", DefDef("<init>", Nil, List(Nil), TypeTree.Inferred(), None), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), "<init>"), Nil)), Nil, None, List(DefDef("a", Nil, Nil, TypeTree.Inferred(), Some(Term.Literal(Constant.Int(0))))))), Term.Literal(Constant.Unit())))
5050
Type.SymRef(IsClassSymbol(<scala.Unit>), Type.ThisType(Type.SymRef(IsPackageSymbol(<scala>), NoPrefix())))
5151

52-
Term.Inlined(None, Nil, Term.Block(List(ClassDef("Foo", DefDef("<init>", Nil, List(Nil), TypeTree.Inferred(), None), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), "<init>"), Nil), TypeTree.Select(Term.Select(Term.Ident("_root_"), "scala"), "Product"), TypeTree.Select(Term.Select(Term.Ident("_root_"), "scala"), "Serializable")), Nil, None, List(DefDef("productElementName", Nil, List(List(ValDef("x$1", TypeTree.Select(Term.Select(Term.Ident("_root_"), "scala"), "Int"), None))), TypeTree.Select(Term.Select(Term.Ident("java"), "lang"), "String"), Some(Term.Match(Term.Ident("x$1"), List(CaseDef(Pattern.Value(Term.Ident("_")), None, Term.Apply(Term.Ident("throw"), List(Term.Apply(Term.Select(Term.New(TypeTree.Select(Term.Select(Term.Ident("java"), "lang"), "IndexOutOfBoundsException")), "<init>"), List(Term.Apply(Term.Select(Term.Select(Term.Select(Term.Ident("java"), "lang"), "String"), "valueOf"), List(Term.Ident("x$1")))))))))))), DefDef("copy", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), "<init>"), Nil))), DefDef("hashCode", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Literal(Constant.Int(394005536)))), DefDef("equals", Nil, List(List(ValDef("x$0", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.Apply(Term.Select(Term.Apply(Term.Select(Term.This(Some(Id("Foo"))), "eq"), List(Term.TypeApply(Term.Select(Term.Ident("x$0"), "asInstanceOf"), List(TypeTree.Inferred())))), "||"), List(Term.Match(Term.Ident("x$0"), List(CaseDef(Pattern.Bind("x$0", Pattern.TypeTest(TypeTree.Inferred())), None, Term.Literal(Constant.Boolean(true))), CaseDef(Pattern.Value(Term.Ident("_")), None, Term.Literal(Constant.Boolean(false))))))))), DefDef("toString", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Apply(Term.Ident("_toString"), List(Term.This(Some(Id("Foo"))))))), DefDef("canEqual", Nil, List(List(ValDef("that", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.TypeApply(Term.Select(Term.Ident("that"), "isInstanceOf"), List(TypeTree.Inferred())))), DefDef("productArity", Nil, Nil, TypeTree.Inferred(), Some(Term.Literal(Constant.Int(0)))), DefDef("productPrefix", Nil, Nil, TypeTree.Inferred(), Some(Term.Literal(Constant.String("Foo")))), DefDef("productElement", Nil, List(List(ValDef("n", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.Match(Term.Ident("n"), List(CaseDef(Pattern.Value(Term.Ident("_")), None, Term.Apply(Term.Ident("throw"), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), "<init>"), List(Term.Apply(Term.Select(Term.Ident("n"), "toString"), Nil)))))))))))), ValDef("Foo", TypeTree.Ident("Foo$"), Some(Term.Apply(Term.Select(Term.New(TypeTree.Ident("Foo$")), "<init>"), Nil))), ClassDef("Foo$", DefDef("<init>", Nil, List(Nil), TypeTree.Inferred(), None), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), "<init>"), Nil), TypeTree.Applied(TypeTree.Inferred(), List(TypeTree.Inferred()))), Nil, Some(ValDef("_", TypeTree.Singleton(Term.Ident("Foo")), None)), List(DefDef("apply", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), "<init>"), Nil))), DefDef("unapply", Nil, List(List(ValDef("x$1", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.Literal(Constant.Boolean(true))))))), Term.Literal(Constant.Unit())))
52+
Term.Inlined(None, Nil, Term.Block(List(ClassDef("Foo", DefDef("<init>", Nil, List(Nil), TypeTree.Inferred(), None), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), "<init>"), Nil), TypeTree.Select(Term.Select(Term.Ident("_root_"), "scala"), "Product"), TypeTree.Select(Term.Select(Term.Ident("_root_"), "scala"), "Serializable")), Nil, None, List(DefDef("hashCode", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Literal(Constant.Int(394005536)))), DefDef("equals", Nil, List(List(ValDef("x$0", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.Apply(Term.Select(Term.Apply(Term.Select(Term.This(Some(Id("Foo"))), "eq"), List(Term.TypeApply(Term.Select(Term.Ident("x$0"), "asInstanceOf"), List(TypeTree.Inferred())))), "||"), List(Term.Match(Term.Ident("x$0"), List(CaseDef(Pattern.Bind("x$0", Pattern.TypeTest(TypeTree.Inferred())), None, Term.Literal(Constant.Boolean(true))), CaseDef(Pattern.Value(Term.Ident("_")), None, Term.Literal(Constant.Boolean(false))))))))), DefDef("toString", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Apply(Term.Ident("_toString"), List(Term.This(Some(Id("Foo"))))))), DefDef("canEqual", Nil, List(List(ValDef("that", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.TypeApply(Term.Select(Term.Ident("that"), "isInstanceOf"), List(TypeTree.Inferred())))), DefDef("productArity", Nil, Nil, TypeTree.Inferred(), Some(Term.Literal(Constant.Int(0)))), DefDef("productPrefix", Nil, Nil, TypeTree.Inferred(), Some(Term.Literal(Constant.String("Foo")))), DefDef("productElement", Nil, List(List(ValDef("n", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.Match(Term.Ident("n"), List(CaseDef(Pattern.Value(Term.Ident("_")), None, Term.Apply(Term.Ident("throw"), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), "<init>"), List(Term.Apply(Term.Select(Term.Ident("n"), "toString"), Nil)))))))))), DefDef("productElementName", Nil, List(List(ValDef("x$1", TypeTree.Select(Term.Select(Term.Ident("_root_"), "scala"), "Int"), None))), TypeTree.Select(Term.Select(Term.Ident("java"), "lang"), "String"), Some(Term.Match(Term.Ident("x$1"), List(CaseDef(Pattern.Value(Term.Ident("_")), None, Term.Apply(Term.Ident("throw"), List(Term.Apply(Term.Select(Term.New(TypeTree.Select(Term.Select(Term.Ident("java"), "lang"), "IndexOutOfBoundsException")), "<init>"), List(Term.Apply(Term.Select(Term.Select(Term.Select(Term.Ident("java"), "lang"), "String"), "valueOf"), List(Term.Ident("x$1")))))))))))), DefDef("copy", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), "<init>"), Nil))))), ValDef("Foo", TypeTree.Ident("Foo$"), Some(Term.Apply(Term.Select(Term.New(TypeTree.Ident("Foo$")), "<init>"), Nil))), ClassDef("Foo$", DefDef("<init>", Nil, List(Nil), TypeTree.Inferred(), None), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), "<init>"), Nil), TypeTree.Applied(TypeTree.Inferred(), List(TypeTree.Inferred()))), Nil, Some(ValDef("_", TypeTree.Singleton(Term.Ident("Foo")), None)), List(DefDef("apply", Nil, List(Nil), TypeTree.Inferred(), Some(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), "<init>"), Nil))), DefDef("unapply", Nil, List(List(ValDef("x$1", TypeTree.Inferred(), None))), TypeTree.Inferred(), Some(Term.Literal(Constant.Boolean(true))))))), Term.Literal(Constant.Unit())))
5353
Type.SymRef(IsClassSymbol(<scala.Unit>), Type.ThisType(Type.SymRef(IsPackageSymbol(<scala>), NoPrefix())))
5454

5555
Term.Inlined(None, Nil, Term.Block(List(ClassDef("Foo1", DefDef("<init>", Nil, List(List(ValDef("a", TypeTree.Ident("Int"), None))), TypeTree.Inferred(), None), List(Term.Apply(Term.Select(Term.New(TypeTree.Inferred()), "<init>"), Nil)), Nil, None, List(ValDef("a", TypeTree.Inferred(), None)))), Term.Literal(Constant.Unit())))

tests/run/valueclasses-pavlov.decompiled

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ final class Box1(val value: scala.Predef.String) extends scala.AnyVal {
1111
object Box1 extends scala.AnyRef
1212
/** Decompiled from out/runTestFromTasty/run/valueclasses-pavlov/Box2.tasty */
1313
final class Box2(val value: scala.Predef.String) extends scala.AnyVal with Foo {
14-
def box1(x: Box1): scala.Predef.String = "box1: ok"
15-
def box2(x: Box2): scala.Predef.String = "box2: ok"
1614
override def hashCode(): scala.Int = Box2.this.value.hashCode()
1715
override def equals(x$0: scala.Any): scala.Boolean = x$0 match {
1816
case x$0: Box2 @scala.unchecked =>
1917
Box2.this.value.==(x$0.value)
2018
case _ =>
2119
false
2220
}
21+
def box1(x: Box1): scala.Predef.String = "box1: ok"
22+
def box2(x: Box2): scala.Predef.String = "box2: ok"
2323
}
2424
object Box2 extends scala.AnyRef
2525
/** Decompiled from out/runTestFromTasty/run/valueclasses-pavlov/C.tasty */

0 commit comments

Comments
 (0)