Skip to content

Commit 3b8100e

Browse files
committed
Fix #1692: Null out fields after use in lazy initialization
Private fields that are only used during lazy val initialization can be assigned null once the lazy val is initialized. This is not just an optimization, but is needed for correctness to prevent memory leaks.
1 parent 07fa870 commit 3b8100e

File tree

5 files changed

+286
-13
lines changed

5 files changed

+286
-13
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class Compiler {
7373
new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods
7474
new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope
7575
new ClassOf, // Expand `Predef.classOf` calls.
76+
new CollectNullableFields, // Collect fields that can be null out after use in lazy initialization
7677
new RefChecks) :: // Various checks mostly related to abstract members and overriding
7778
List(new TryCatchPatterns, // Compile cases in try/catch
7879
new PatternMatcher, // Compile pattern matches
@@ -97,7 +98,7 @@ class Compiler {
9798
List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements.
9899
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
99100
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations
100-
new Mixin, // Expand trait fields and trait initializers
101+
new Mixin, // Expand trait fields and trait initializers
101102
new LazyVals, // Expand lazy vals
102103
new Memoize, // Add private fields to getters and setters
103104
new NonLocalReturns, // Expand non-local returns

compiler/src/dotty/tools/dotc/core/Phases.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ object Phases {
210210

211211
private[this] var myTyperPhase: Phase = _
212212
private[this] var myPicklerPhase: Phase = _
213+
private[this] var myCollectNullableFieldsPhase: Phase = _
213214
private[this] var myRefChecksPhase: Phase = _
214215
private[this] var myPatmatPhase: Phase = _
215216
private[this] var myElimRepeatedPhase: Phase = _
@@ -224,6 +225,7 @@ object Phases {
224225

225226
final def typerPhase = myTyperPhase
226227
final def picklerPhase = myPicklerPhase
228+
final def collectNullableFieldsPhase = myCollectNullableFieldsPhase
227229
final def refchecksPhase = myRefChecksPhase
228230
final def patmatPhase = myPatmatPhase
229231
final def elimRepeatedPhase = myElimRepeatedPhase
@@ -241,6 +243,7 @@ object Phases {
241243

242244
myTyperPhase = phaseOfClass(classOf[FrontEnd])
243245
myPicklerPhase = phaseOfClass(classOf[Pickler])
246+
myCollectNullableFieldsPhase = phaseOfClass(classOf[CollectNullableFields])
244247
myRefChecksPhase = phaseOfClass(classOf[RefChecks])
245248
myElimRepeatedPhase = phaseOfClass(classOf[ElimRepeated])
246249
myExtensionMethodsPhase = phaseOfClass(classOf[ExtensionMethods])
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package dotty.tools.dotc.transform
2+
3+
import dotty.tools.dotc.ast.tpd
4+
import dotty.tools.dotc.core.Contexts.Context
5+
import dotty.tools.dotc.core.Flags._
6+
import dotty.tools.dotc.core.Symbols.Symbol
7+
import dotty.tools.dotc.core.Types.{Type, ExprType}
8+
import dotty.tools.dotc.transform.MegaPhase.MiniPhase
9+
import dotty.tools.dotc.transform.SymUtils._
10+
11+
import scala.collection.JavaConverters._
12+
import scala.collection.mutable
13+
14+
import java.util.IdentityHashMap
15+
16+
object CollectNullableFields {
17+
val name = "collectNullableFields"
18+
}
19+
20+
/** Collect fields that can be null out after use in lazy initialization.
21+
*
22+
* This information is used during lazy val transformation to assign null to private
23+
* fields that are only used within a lazy val initializer. This is not just an optimization,
24+
* but is needed for correctness to prevent memory leaks. E.g.
25+
*
26+
* {{{
27+
* class TestByNameLazy(byNameMsg: => String) {
28+
* lazy val byLazyValMsg = byNameMsg
29+
* }
30+
* }}}
31+
*
32+
* Here `byNameMsg` should be null out once `byLazyValMsg` is
33+
* initialised.
34+
*
35+
* A field is nullable if all the conditions below hold:
36+
* - is private
37+
* - is not lazy
38+
* - its type is nullable, or is an expression type (e.g. => Int)
39+
* - is on used in a lazy val initializer
40+
* - defined in the same class as the lazy val
41+
* - TODO from Scalac? from a non-trait class
42+
*/
43+
class CollectNullableFields extends MiniPhase {
44+
import tpd._
45+
46+
override def phaseName = CollectNullableFields.name
47+
48+
private[this] sealed trait FieldInfo
49+
private[this] case object NotNullable extends FieldInfo
50+
private[this] case class Nullable(by: Symbol) extends FieldInfo
51+
52+
/** Whether or not a field is nullable */
53+
private[this] var nullability: IdentityHashMap[Symbol, FieldInfo] = _
54+
55+
override def prepareForUnit(tree: Tree)(implicit ctx: Context) = {
56+
nullability = new IdentityHashMap
57+
ctx
58+
}
59+
60+
private def recordUse(tree: Tree)(implicit ctx: Context): Tree = {
61+
val sym = tree.symbol
62+
63+
def isNullableType(tpe: Type) =
64+
tpe.isInstanceOf[ExprType] ||
65+
tpe.widenDealias.typeSymbol.isNullableClass
66+
val isNullablePrivateField = sym.isField && sym.is(Private, butNot = Lazy) && isNullableType(sym.info)
67+
68+
if (isNullablePrivateField)
69+
nullability.get(sym) match {
70+
case Nullable(from) if from != ctx.owner => // used in multiple lazy val initializers
71+
nullability.put(sym, NotNullable)
72+
case null => // not in the map
73+
val from = ctx.owner
74+
val isNullable =
75+
from.is(Lazy) && from.isField && // used in lazy field initializer
76+
from.owner.eq(sym.owner) // lazy val and field in the same class
77+
val info = if (isNullable) Nullable(from) else NotNullable
78+
nullability.put(sym, info)
79+
case _ =>
80+
// Do nothing for:
81+
// - NotNullable
82+
// - Nullable(ctx.owner)
83+
}
84+
85+
tree
86+
}
87+
88+
override def transformIdent(tree: Ident)(implicit ctx: Context) =
89+
recordUse(tree)
90+
91+
override def transformSelect(tree: Select)(implicit ctx: Context) =
92+
recordUse(tree)
93+
94+
/** Map lazy values to the fields they should null after initialization. */
95+
def lazyValNullables(implicit ctx: Context): Map[Symbol, List[Symbol]] = {
96+
val result = new mutable.HashMap[Symbol, mutable.ListBuffer[Symbol]]
97+
98+
nullability.forEach {
99+
case (sym, Nullable(from)) =>
100+
val bldr = result.getOrElseUpdate(from, new mutable.ListBuffer)
101+
bldr += sym
102+
case _ =>
103+
}
104+
105+
result.mapValues(_.toList).toMap
106+
}
107+
}

compiler/src/dotty/tools/dotc/transform/LazyVals.scala

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
3939

4040
/** List of names of phases that should have finished processing of tree
4141
* before this phase starts processing same tree */
42-
override def runsAfter = Set(Mixin.name)
42+
override def runsAfter = Set(Mixin.name, CollectNullableFields.name)
4343

4444
override def changesMembers = true // the phase adds lazy val accessors
4545

@@ -50,6 +50,18 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
5050

5151
val containerFlagsMask = Flags.Method | Flags.Lazy | Flags.Accessor | Flags.Module
5252

53+
/** A map of lazy values to the fields they should null after initialization. */
54+
private[this] var lazyValNullables: Map[Symbol, List[Symbol]] = _
55+
private def nullableFor(sym: Symbol)(implicit ctx: Context) =
56+
if (sym.is(Flags.Module)) Nil
57+
else lazyValNullables.getOrElse(sym, Nil)
58+
59+
60+
override def prepareForUnit(tree: Tree)(implicit ctx: Context) = {
61+
lazyValNullables = ctx.collectNullableFieldsPhase.asInstanceOf[CollectNullableFields].lazyValNullables
62+
ctx
63+
}
64+
5365
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context): tpd.Tree =
5466
transformLazyVal(tree)
5567

@@ -150,7 +162,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
150162
val initBody =
151163
adaptToType(
152164
ref(holderSymbol).select(defn.Object_synchronized).appliedTo(
153-
adaptToType(mkNonThreadSafeDef(result, flag, initer), defn.ObjectType)),
165+
adaptToType(mkNonThreadSafeDef(result, flag, initer, nullables = Nil), defn.ObjectType)),
154166
tpe)
155167
val initTree = DefDef(initSymbol, initBody)
156168
val holderTree = ValDef(holderSymbol, New(holderImpl.typeRef, List()))
@@ -176,37 +188,51 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
176188
holders:::stats
177189
}
178190

191+
private def nullOut(nullables: List[Symbol])(implicit ctx: Context): List[Tree] = {
192+
val nullConst = Literal(Constants.Constant(null))
193+
nullables.map { sym =>
194+
val field = if (sym.isGetter) sym.field else sym
195+
assert(field.isField)
196+
field.setFlag(Flags.Mutable)
197+
ref(field).becomes(nullConst)
198+
}
199+
}
200+
179201
/** Create non-threadsafe lazy accessor equivalent to such code
180202
* def methodSymbol() = {
181203
* if (flag) target
182204
* else {
183205
* target = rhs
184206
* flag = true
207+
* nullable = null
185208
* target
186209
* }
187210
* }
188211
*/
189212

190-
def mkNonThreadSafeDef(target: Tree, flag: Tree, rhs: Tree)(implicit ctx: Context) = {
213+
def mkNonThreadSafeDef(target: Tree, flag: Tree, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = {
191214
val setFlag = flag.becomes(Literal(Constants.Constant(true)))
192-
val setTargets = if (isWildcardArg(rhs)) Nil else target.becomes(rhs) :: Nil
193-
val init = Block(setFlag :: setTargets, target.ensureApplied)
215+
val setNullables = nullOut(nullables)
216+
val setTargetAndNullable = if (isWildcardArg(rhs)) setNullables else target.becomes(rhs) :: setNullables
217+
val init = Block(setFlag :: setTargetAndNullable, target.ensureApplied)
194218
If(flag.ensureApplied, target.ensureApplied, init)
195219
}
196220

197221
/** Create non-threadsafe lazy accessor for not-nullable types equivalent to such code
198222
* def methodSymbol() = {
199223
* if (target eq null) {
200224
* target = rhs
225+
* nullable = null
201226
* target
202227
* } else target
203228
* }
204229
*/
205-
def mkDefNonThreadSafeNonNullable(target: Symbol, rhs: Tree)(implicit ctx: Context) = {
230+
def mkDefNonThreadSafeNonNullable(target: Symbol, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = {
206231
val cond = ref(target).select(nme.eq).appliedTo(Literal(Constant(null)))
207232
val exp = ref(target)
208233
val setTarget = exp.becomes(rhs)
209-
val init = Block(List(setTarget), exp)
234+
val setNullables = nullOut(nullables)
235+
val init = Block(setTarget :: setNullables, exp)
210236
If(cond, init, exp)
211237
}
212238

@@ -222,14 +248,14 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
222248

223249
val containerTree = ValDef(containerSymbol, defaultValue(tpe))
224250
if (x.tpe.isNotNull && tpe <:< defn.ObjectType) { // can use 'null' value instead of flag
225-
val slowPath = DefDef(x.symbol.asTerm, mkDefNonThreadSafeNonNullable(containerSymbol, x.rhs))
251+
val slowPath = DefDef(x.symbol.asTerm, mkDefNonThreadSafeNonNullable(containerSymbol, x.rhs, nullableFor(x.symbol)))
226252
Thicket(containerTree, slowPath)
227253
}
228254
else {
229255
val flagName = LazyBitMapName.fresh(x.name.asTermName)
230256
val flagSymbol = ctx.newSymbol(x.symbol.owner, flagName, containerFlags | Flags.Private, defn.BooleanType).enteredAfter(this)
231257
val flag = ValDef(flagSymbol, Literal(Constants.Constant(false)))
232-
val slowPath = DefDef(x.symbol.asTerm, mkNonThreadSafeDef(ref(containerSymbol), ref(flagSymbol), x.rhs))
258+
val slowPath = DefDef(x.symbol.asTerm, mkNonThreadSafeDef(ref(containerSymbol), ref(flagSymbol), x.rhs, nullableFor(x.symbol)))
233259
Thicket(containerTree, flag, slowPath)
234260
}
235261
}
@@ -263,10 +289,23 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
263289
* result = $target
264290
* }
265291
* }
292+
* nullable = null
266293
* result
267294
* }
268295
*/
269-
def mkThreadSafeDef(methodSymbol: TermSymbol, claz: ClassSymbol, ord: Int, target: Symbol, rhs: Tree, tp: Types.Type, offset: Tree, getFlag: Tree, stateMask: Tree, casFlag: Tree, setFlagState: Tree, waitOnLock: Tree)(implicit ctx: Context) = {
296+
def mkThreadSafeDef(methodSymbol: TermSymbol,
297+
claz: ClassSymbol,
298+
ord: Int,
299+
target: Symbol,
300+
rhs: Tree,
301+
tp: Types.Type,
302+
offset: Tree,
303+
getFlag: Tree,
304+
stateMask: Tree,
305+
casFlag: Tree,
306+
setFlagState: Tree,
307+
waitOnLock: Tree,
308+
nullables: List[Symbol])(implicit ctx: Context) = {
270309
val initState = Literal(Constants.Constant(0))
271310
val computeState = Literal(Constants.Constant(1))
272311
val notifyState = Literal(Constants.Constant(2))
@@ -330,7 +369,8 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
330369

331370
val whileBody = List(ref(flagSymbol).becomes(getFlag.appliedTo(thiz, offset)), cases)
332371
val cycle = WhileDo(methodSymbol, whileCond, whileBody)
333-
DefDef(methodSymbol, Block(resultDef :: retryDef :: flagDef :: cycle :: Nil, ref(resultSymbol)))
372+
val setNullables = nullOut(nullables)
373+
DefDef(methodSymbol, Block(resultDef :: retryDef :: flagDef :: cycle :: setNullables, ref(resultSymbol)))
334374
}
335375

336376
def transformMemberDefVolatile(x: ValOrDefDef)(implicit ctx: Context) = {
@@ -390,8 +430,9 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
390430
val wait = Select(ref(helperModule), lazyNme.RLazyVals.wait4Notification)
391431
val state = Select(ref(helperModule), lazyNme.RLazyVals.state)
392432
val cas = Select(ref(helperModule), lazyNme.RLazyVals.cas)
433+
val nullables = nullableFor(x.symbol)
393434

394-
val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait)
435+
val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait, nullables)
395436
if (flag eq EmptyTree)
396437
Thicket(containerTree, accessor)
397438
else Thicket(containerTree, flag, accessor)

0 commit comments

Comments
 (0)