Skip to content

Commit b4bcd05

Browse files
committed
Fix #1692: Null out fields after use in lazy initialization
Private fields that are only used during lzyy val initialization can be assigned null once the lazy val is initialized. This is not just an optimization, but is needed for correctness to prevent memory leaks.
1 parent 07fa870 commit b4bcd05

File tree

4 files changed

+150
-13
lines changed

4 files changed

+150
-13
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class Compiler {
7373
new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods
7474
new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope
7575
new ClassOf, // Expand `Predef.classOf` calls.
76+
new CollectNullableFields, // Collect fields that can be null out after use in lazy initialization
7677
new RefChecks) :: // Various checks mostly related to abstract members and overriding
7778
List(new TryCatchPatterns, // Compile cases in try/catch
7879
new PatternMatcher, // Compile pattern matches
@@ -97,7 +98,7 @@ class Compiler {
9798
List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements.
9899
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
99100
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations
100-
new Mixin, // Expand trait fields and trait initializers
101+
new Mixin, // Expand trait fields and trait initializers
101102
new LazyVals, // Expand lazy vals
102103
new Memoize, // Add private fields to getters and setters
103104
new NonLocalReturns, // Expand non-local returns

compiler/src/dotty/tools/dotc/core/Phases.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ object Phases {
210210

211211
private[this] var myTyperPhase: Phase = _
212212
private[this] var myPicklerPhase: Phase = _
213+
private[this] var myCollectNullableFieldsPhase: Phase = _
213214
private[this] var myRefChecksPhase: Phase = _
214215
private[this] var myPatmatPhase: Phase = _
215216
private[this] var myElimRepeatedPhase: Phase = _
@@ -224,6 +225,7 @@ object Phases {
224225

225226
final def typerPhase = myTyperPhase
226227
final def picklerPhase = myPicklerPhase
228+
final def collectNullableFieldsPhase = myCollectNullableFieldsPhase
227229
final def refchecksPhase = myRefChecksPhase
228230
final def patmatPhase = myPatmatPhase
229231
final def elimRepeatedPhase = myElimRepeatedPhase
@@ -241,6 +243,7 @@ object Phases {
241243

242244
myTyperPhase = phaseOfClass(classOf[FrontEnd])
243245
myPicklerPhase = phaseOfClass(classOf[Pickler])
246+
myCollectNullableFieldsPhase = phaseOfClass(classOf[CollectNullableFields])
244247
myRefChecksPhase = phaseOfClass(classOf[RefChecks])
245248
myElimRepeatedPhase = phaseOfClass(classOf[ElimRepeated])
246249
myExtensionMethodsPhase = phaseOfClass(classOf[ExtensionMethods])
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package dotty.tools.dotc.transform
2+
3+
import dotty.tools.dotc.ast.tpd
4+
import dotty.tools.dotc.core.Contexts.Context
5+
import dotty.tools.dotc.core.Flags._
6+
import dotty.tools.dotc.core.Symbols.Symbol
7+
import dotty.tools.dotc.core.Types.{Type, ExprType}
8+
import dotty.tools.dotc.transform.MegaPhase.MiniPhase
9+
import dotty.tools.dotc.transform.SymUtils._
10+
11+
import scala.collection.JavaConverters._
12+
import scala.collection.mutable
13+
14+
import java.util.IdentityHashMap
15+
16+
object CollectNullableFields {
17+
val name = "collectNullableFields"
18+
}
19+
20+
/** Collect fields that can be null out after use in lazy initialization.
21+
*
22+
* This information is used during lazy val transformation to assign null to private
23+
* fields that are only used within a lazy val initializer. This is not just an optimization,
24+
* but is needed for correctness to prevent memory leaks. E.g.
25+
*
26+
* {{{
27+
* class TestByNameLazy(byNameMsg: => String) {
28+
* lazy val byLazyValMsg = byNameMsg
29+
* }
30+
* }}}
31+
*
32+
* Here `byNameMsg` should be null out once `byLazyValMsg` is
33+
* initialised.
34+
*/
35+
class CollectNullableFields extends MiniPhase {
36+
import tpd._
37+
38+
override def phaseName = CollectNullableFields.name
39+
40+
private[this] sealed trait FieldInfo
41+
private[this] case object NotNullable extends FieldInfo
42+
private[this] case class Nullable(by: Symbol) extends FieldInfo
43+
44+
/** Whether or not a field is nullable */
45+
private[this] var nullability: IdentityHashMap[Symbol, FieldInfo] = _
46+
47+
override def prepareForUnit(tree: Tree)(implicit ctx: Context) = {
48+
nullability = new IdentityHashMap
49+
ctx
50+
}
51+
52+
private def recordUse(tree: Tree)(implicit ctx: Context): Tree = {
53+
val sym = tree.symbol
54+
55+
def isNullableType(tpe: Type) =
56+
tpe.isInstanceOf[ExprType] ||
57+
tpe.widenDealias.typeSymbol.isNullableClass
58+
val isNullablePrivateField = sym.isField && sym.is(Private, butNot = Lazy) && isNullableType(sym.info)
59+
60+
if (isNullablePrivateField)
61+
nullability.get(sym) match {
62+
case Nullable(from) if from != ctx.owner => // used in multiple lazy val initializers
63+
nullability.put(sym, NotNullable)
64+
case null => // not in the map
65+
val from = ctx.owner
66+
val inLazyValInitializer = from.is(Lazy, butNot = Module)
67+
val info = if (inLazyValInitializer) Nullable(from) else NotNullable
68+
nullability.put(sym, info)
69+
case _ =>
70+
// Do nothing for:
71+
// - NotNullable
72+
// - Nullable(ctx.owner)
73+
}
74+
75+
tree
76+
}
77+
78+
override def transformIdent(tree: Ident)(implicit ctx: Context) =
79+
recordUse(tree)
80+
81+
override def transformSelect(tree: Select)(implicit ctx: Context) =
82+
recordUse(tree)
83+
84+
/** Map lazy values to the fields they should null after initialization. */
85+
def lazyValNullables(implicit ctx: Context): Map[Symbol, List[Symbol]] = {
86+
val result = new mutable.HashMap[Symbol, mutable.ListBuffer[Symbol]]
87+
88+
nullability.forEach {
89+
case (sym, Nullable(from)) =>
90+
val bldr = result.getOrElseUpdate(from, new mutable.ListBuffer)
91+
bldr += sym
92+
case _ =>
93+
}
94+
95+
result.mapValues(_.toList).toMap
96+
}
97+
}

compiler/src/dotty/tools/dotc/transform/LazyVals.scala

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
3939

4040
/** List of names of phases that should have finished processing of tree
4141
* before this phase starts processing same tree */
42-
override def runsAfter = Set(Mixin.name)
42+
override def runsAfter = Set(Mixin.name, CollectNullableFields.name)
4343

4444
override def changesMembers = true // the phase adds lazy val accessors
4545

@@ -50,6 +50,15 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
5050

5151
val containerFlagsMask = Flags.Method | Flags.Lazy | Flags.Accessor | Flags.Module
5252

53+
/** A map of lazy values to the fields they should null after initialization. */
54+
private[this] var lazyValNullables: Map[Symbol, List[Symbol]] = _
55+
private def nullableFor(sym: Symbol) = lazyValNullables.getOrElse(sym, Nil)
56+
57+
override def prepareForUnit(tree: Tree)(implicit ctx: Context) = {
58+
lazyValNullables = ctx.collectNullableFieldsPhase.asInstanceOf[CollectNullableFields].lazyValNullables
59+
ctx
60+
}
61+
5362
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context): tpd.Tree =
5463
transformLazyVal(tree)
5564

@@ -150,7 +159,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
150159
val initBody =
151160
adaptToType(
152161
ref(holderSymbol).select(defn.Object_synchronized).appliedTo(
153-
adaptToType(mkNonThreadSafeDef(result, flag, initer), defn.ObjectType)),
162+
adaptToType(mkNonThreadSafeDef(result, flag, initer, nullableFor(x.symbol)), defn.ObjectType)),
154163
tpe)
155164
val initTree = DefDef(initSymbol, initBody)
156165
val holderTree = ValDef(holderSymbol, New(holderImpl.typeRef, List()))
@@ -176,37 +185,50 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
176185
holders:::stats
177186
}
178187

188+
private def nullOut(nullables: List[Symbol])(implicit ctx: Context): List[Tree] = {
189+
val nullConst = Literal(Constants.Constant(null))
190+
nullables.map { sym =>
191+
val field = if (sym.isGetter) sym.field else sym
192+
field.setFlag(Flags.Mutable)
193+
ref(field).becomes(nullConst)
194+
}
195+
}
196+
179197
/** Create non-threadsafe lazy accessor equivalent to such code
180198
* def methodSymbol() = {
181199
* if (flag) target
182200
* else {
183201
* target = rhs
184202
* flag = true
203+
* nullable = null
185204
* target
186205
* }
187206
* }
188207
*/
189208

190-
def mkNonThreadSafeDef(target: Tree, flag: Tree, rhs: Tree)(implicit ctx: Context) = {
209+
def mkNonThreadSafeDef(target: Tree, flag: Tree, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = {
191210
val setFlag = flag.becomes(Literal(Constants.Constant(true)))
192-
val setTargets = if (isWildcardArg(rhs)) Nil else target.becomes(rhs) :: Nil
193-
val init = Block(setFlag :: setTargets, target.ensureApplied)
211+
val setNullables = nullOut(nullables)
212+
val setTargetAndNullable = if (isWildcardArg(rhs)) setNullables else target.becomes(rhs) :: setNullables
213+
val init = Block(setFlag :: setTargetAndNullable, target.ensureApplied)
194214
If(flag.ensureApplied, target.ensureApplied, init)
195215
}
196216

197217
/** Create non-threadsafe lazy accessor for not-nullable types equivalent to such code
198218
* def methodSymbol() = {
199219
* if (target eq null) {
200220
* target = rhs
221+
* nullable = null
201222
* target
202223
* } else target
203224
* }
204225
*/
205-
def mkDefNonThreadSafeNonNullable(target: Symbol, rhs: Tree)(implicit ctx: Context) = {
226+
def mkDefNonThreadSafeNonNullable(target: Symbol, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = {
206227
val cond = ref(target).select(nme.eq).appliedTo(Literal(Constant(null)))
207228
val exp = ref(target)
208229
val setTarget = exp.becomes(rhs)
209-
val init = Block(List(setTarget), exp)
230+
val setNullables = nullOut(nullables)
231+
val init = Block(setTarget :: setNullables, exp)
210232
If(cond, init, exp)
211233
}
212234

@@ -222,14 +244,14 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
222244

223245
val containerTree = ValDef(containerSymbol, defaultValue(tpe))
224246
if (x.tpe.isNotNull && tpe <:< defn.ObjectType) { // can use 'null' value instead of flag
225-
val slowPath = DefDef(x.symbol.asTerm, mkDefNonThreadSafeNonNullable(containerSymbol, x.rhs))
247+
val slowPath = DefDef(x.symbol.asTerm, mkDefNonThreadSafeNonNullable(containerSymbol, x.rhs, nullableFor(x.symbol)))
226248
Thicket(containerTree, slowPath)
227249
}
228250
else {
229251
val flagName = LazyBitMapName.fresh(x.name.asTermName)
230252
val flagSymbol = ctx.newSymbol(x.symbol.owner, flagName, containerFlags | Flags.Private, defn.BooleanType).enteredAfter(this)
231253
val flag = ValDef(flagSymbol, Literal(Constants.Constant(false)))
232-
val slowPath = DefDef(x.symbol.asTerm, mkNonThreadSafeDef(ref(containerSymbol), ref(flagSymbol), x.rhs))
254+
val slowPath = DefDef(x.symbol.asTerm, mkNonThreadSafeDef(ref(containerSymbol), ref(flagSymbol), x.rhs, nullableFor(x.symbol)))
233255
Thicket(containerTree, flag, slowPath)
234256
}
235257
}
@@ -263,10 +285,23 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
263285
* result = $target
264286
* }
265287
* }
288+
* nullable = null
266289
* result
267290
* }
268291
*/
269-
def mkThreadSafeDef(methodSymbol: TermSymbol, claz: ClassSymbol, ord: Int, target: Symbol, rhs: Tree, tp: Types.Type, offset: Tree, getFlag: Tree, stateMask: Tree, casFlag: Tree, setFlagState: Tree, waitOnLock: Tree)(implicit ctx: Context) = {
292+
def mkThreadSafeDef(methodSymbol: TermSymbol,
293+
claz: ClassSymbol,
294+
ord: Int,
295+
target: Symbol,
296+
rhs: Tree,
297+
tp: Types.Type,
298+
offset: Tree,
299+
getFlag: Tree,
300+
stateMask: Tree,
301+
casFlag: Tree,
302+
setFlagState: Tree,
303+
waitOnLock: Tree,
304+
nullables: List[Symbol])(implicit ctx: Context) = {
270305
val initState = Literal(Constants.Constant(0))
271306
val computeState = Literal(Constants.Constant(1))
272307
val notifyState = Literal(Constants.Constant(2))
@@ -330,7 +365,8 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
330365

331366
val whileBody = List(ref(flagSymbol).becomes(getFlag.appliedTo(thiz, offset)), cases)
332367
val cycle = WhileDo(methodSymbol, whileCond, whileBody)
333-
DefDef(methodSymbol, Block(resultDef :: retryDef :: flagDef :: cycle :: Nil, ref(resultSymbol)))
368+
val setNullables = nullOut(nullables)
369+
DefDef(methodSymbol, Block(resultDef :: retryDef :: flagDef :: cycle :: setNullables, ref(resultSymbol)))
334370
}
335371

336372
def transformMemberDefVolatile(x: ValOrDefDef)(implicit ctx: Context) = {
@@ -391,7 +427,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
391427
val state = Select(ref(helperModule), lazyNme.RLazyVals.state)
392428
val cas = Select(ref(helperModule), lazyNme.RLazyVals.cas)
393429

394-
val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait)
430+
val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait, nullableFor(x.symbol))
395431
if (flag eq EmptyTree)
396432
Thicket(containerTree, accessor)
397433
else Thicket(containerTree, flag, accessor)

0 commit comments

Comments
 (0)