diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 94779bca680c..596b1d6bf45f 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -87,6 +87,7 @@ class Compiler { new InterceptedMethods, // Special handling of `==`, `|=`, `getClass` methods new Getters, // Replace non-private vals and vars with getter defs (fields are added later) new ElimByName, // Expand by-name parameter references + new CollectNullableFields, // Collect fields that can be nulled out after use in lazy initialization new ElimOuterSelect, // Expand outer selections new AugmentScala2Traits, // Expand traits defined in Scala 2.x to simulate old-style rewritings new ResolveSuper, // Implement super accessors and add forwarders to trait methods @@ -97,7 +98,7 @@ class Compiler { List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements. List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations - new Mixin, // Expand trait fields and trait initializers + new Mixin, // Expand trait fields and trait initializers new LazyVals, // Expand lazy vals new Memoize, // Add private fields to getters and setters new NonLocalReturns, // Expand non-local returns diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index 239e8a4fe9d1..3894ef445c47 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -211,6 +211,7 @@ object Phases { private[this] var myTyperPhase: Phase = _ private[this] var mySbtExtractDependenciesPhase: Phase = _ private[this] var myPicklerPhase: Phase = _ + private[this] var myCollectNullableFieldsPhase: Phase = _ private[this] var myRefChecksPhase: Phase = _ private[this] var myPatmatPhase: Phase = _ private[this] var myElimRepeatedPhase: Phase = _ @@ -226,6 +227,7 @@ object Phases { final def typerPhase = myTyperPhase final def sbtExtractDependenciesPhase = mySbtExtractDependenciesPhase final def picklerPhase = myPicklerPhase + final def collectNullableFieldsPhase = myCollectNullableFieldsPhase final def refchecksPhase = myRefChecksPhase final def patmatPhase = myPatmatPhase final def elimRepeatedPhase = myElimRepeatedPhase @@ -244,6 +246,7 @@ object Phases { myTyperPhase = phaseOfClass(classOf[FrontEnd]) mySbtExtractDependenciesPhase = phaseOfClass(classOf[sbt.ExtractDependencies]) myPicklerPhase = phaseOfClass(classOf[Pickler]) + myCollectNullableFieldsPhase = phaseOfClass(classOf[CollectNullableFields]) myRefChecksPhase = phaseOfClass(classOf[RefChecks]) myElimRepeatedPhase = phaseOfClass(classOf[ElimRepeated]) myExtensionMethodsPhase = phaseOfClass(classOf[ExtensionMethods]) diff --git a/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala b/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala new file mode 100644 index 000000000000..201ae4e21def --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala @@ -0,0 +1,112 @@ +package dotty.tools.dotc.transform + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Flags._ +import dotty.tools.dotc.core.Symbols.Symbol +import dotty.tools.dotc.core.Types.{Type, ExprType} +import dotty.tools.dotc.transform.MegaPhase.MiniPhase +import dotty.tools.dotc.transform.SymUtils._ + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import java.util.IdentityHashMap + +object CollectNullableFields { + val name = "collectNullableFields" +} + +/** Collect fields that can be nulled out after use in lazy initialization. + * + * This information is used during lazy val transformation to assign null to private + * fields that are only used within a lazy val initializer. This is not just an optimization, + * but is needed for correctness to prevent memory leaks. E.g. + * + * ```scala + * class TestByNameLazy(byNameMsg: => String) { + * lazy val byLazyValMsg = byNameMsg + * } + * ``` + * + * Here `byNameMsg` should be null out once `byLazyValMsg` is + * initialised. + * + * A field is nullable if all the conditions below hold: + * - belongs to a non trait-class + * - is private[this] + * - is not lazy + * - its type is nullable + * - is only used in a lazy val initializer + * - defined in the same class as the lazy val + */ +class CollectNullableFields extends MiniPhase { + import tpd._ + + override def phaseName = CollectNullableFields.name + + /** Running after `ElimByName` to see by names as nullable types. */ + override def runsAfter = Set(ElimByName.name) + + private[this] sealed trait FieldInfo + private[this] case object NotNullable extends FieldInfo + private[this] case class Nullable(by: Symbol) extends FieldInfo + + /** Whether or not a field is nullable */ + private[this] var nullability: IdentityHashMap[Symbol, FieldInfo] = _ + + override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { + if (nullability == null) nullability = new IdentityHashMap + ctx + } + + private def recordUse(tree: Tree)(implicit ctx: Context): Tree = { + val sym = tree.symbol + val isNullablePrivateField = + sym.isField && + !sym.is(Lazy) && + !sym.owner.is(Trait) && + sym.initial.is(PrivateLocal) && + sym.info.widenDealias.typeSymbol.isNullableClass + + if (isNullablePrivateField) + nullability.get(sym) match { + case Nullable(from) if from != ctx.owner => // used in multiple lazy val initializers + nullability.put(sym, NotNullable) + case null => // not in the map + val from = ctx.owner + val isNullable = + from.is(Lazy, butNot = Module) && // is lazy val + from.owner.isClass && // is field + from.owner.eq(sym.owner) // is lazy val and field defined in the same class + val info = if (isNullable) Nullable(from) else NotNullable + nullability.put(sym, info) + case _ => + // Do nothing for: + // - NotNullable + // - Nullable(ctx.owner) + } + + tree + } + + override def transformIdent(tree: Ident)(implicit ctx: Context) = + recordUse(tree) + + override def transformSelect(tree: Select)(implicit ctx: Context) = + recordUse(tree) + + /** Map lazy values to the fields they should null after initialization. */ + def lazyValNullables(implicit ctx: Context): IdentityHashMap[Symbol, mutable.ListBuffer[Symbol]] = { + val result = new IdentityHashMap[Symbol, mutable.ListBuffer[Symbol]] + + nullability.forEach { + case (sym, Nullable(from)) => + val bldr = result.computeIfAbsent(from, _ => new mutable.ListBuffer) + bldr += sym + case _ => + } + + result + } +} diff --git a/compiler/src/dotty/tools/dotc/transform/LazyVals.scala b/compiler/src/dotty/tools/dotc/transform/LazyVals.scala index 620bd53bc948..7e5cc6a277d9 100644 --- a/compiler/src/dotty/tools/dotc/transform/LazyVals.scala +++ b/compiler/src/dotty/tools/dotc/transform/LazyVals.scala @@ -26,6 +26,8 @@ import dotty.tools.dotc.core.SymDenotations.SymDenotation import dotty.tools.dotc.core.DenotTransformers.{SymTransformer, IdentityDenotTransformer, DenotTransformer} import Erasure.Boxing.adaptToType +import java.util.IdentityHashMap + class LazyVals extends MiniPhase with IdentityDenotTransformer { import LazyVals._ import tpd._ @@ -39,7 +41,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { /** List of names of phases that should have finished processing of tree * before this phase starts processing same tree */ - override def runsAfter = Set(Mixin.name) + override def runsAfter = Set(Mixin.name, CollectNullableFields.name) override def changesMembers = true // the phase adds lazy val accessors @@ -50,6 +52,22 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { val containerFlagsMask = Flags.Method | Flags.Lazy | Flags.Accessor | Flags.Module + /** A map of lazy values to the fields they should null after initialization. */ + private[this] var lazyValNullables: IdentityHashMap[Symbol, mutable.ListBuffer[Symbol]] = _ + private def nullableFor(sym: Symbol)(implicit ctx: Context) = { + // optimisation: value only used once, we can remove the value from the map + val nullables = lazyValNullables.remove(sym) + if (nullables == null) Nil + else nullables.toList + } + + + override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { + if (lazyValNullables == null) + lazyValNullables = ctx.collectNullableFieldsPhase.asInstanceOf[CollectNullableFields].lazyValNullables + ctx + } + override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context): tpd.Tree = transformLazyVal(tree) @@ -117,51 +135,51 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { Thicket(field, getter) } - /** Replace a local lazy val inside a method, - * with a LazyHolder from - * dotty.runtime(eg dotty.runtime.LazyInt) - */ - def transformLocalDef(x: ValOrDefDef)(implicit ctx: Context) = { - val valueInitter = x.rhs - val xname = x.name.asTermName - val holderName = LazyLocalName.fresh(xname) - val initName = LazyLocalInitName.fresh(xname) - val tpe = x.tpe.widen.resultType.widen - - val holderType = - if (tpe isRef defn.IntClass) "LazyInt" - else if (tpe isRef defn.LongClass) "LazyLong" - else if (tpe isRef defn.BooleanClass) "LazyBoolean" - else if (tpe isRef defn.FloatClass) "LazyFloat" - else if (tpe isRef defn.DoubleClass) "LazyDouble" - else if (tpe isRef defn.ByteClass) "LazyByte" - else if (tpe isRef defn.CharClass) "LazyChar" - else if (tpe isRef defn.ShortClass) "LazyShort" - else "LazyRef" - - - val holderImpl = ctx.requiredClass("dotty.runtime." + holderType) - - val holderSymbol = ctx.newSymbol(x.symbol.owner, holderName, containerFlags, holderImpl.typeRef, coord = x.pos) - val initSymbol = ctx.newSymbol(x.symbol.owner, initName, initFlags, MethodType(Nil, tpe), coord = x.pos) - val result = ref(holderSymbol).select(lazyNme.value).withPos(x.pos) - val flag = ref(holderSymbol).select(lazyNme.initialized) - val initer = valueInitter.changeOwnerAfter(x.symbol, initSymbol, this) - val initBody = - adaptToType( - ref(holderSymbol).select(defn.Object_synchronized).appliedTo( - adaptToType(mkNonThreadSafeDef(result, flag, initer), defn.ObjectType)), - tpe) - val initTree = DefDef(initSymbol, initBody) - val holderTree = ValDef(holderSymbol, New(holderImpl.typeRef, List())) - val methodBody = tpd.If(flag.ensureApplied, - result.ensureApplied, - ref(initSymbol).ensureApplied).ensureConforms(tpe) - - val methodTree = DefDef(x.symbol.asTerm, methodBody) - ctx.debuglog(s"found a lazy val ${x.show},\nrewrote with ${holderTree.show}") - Thicket(holderTree, initTree, methodTree) - } + /** Replace a local lazy val inside a method, + * with a LazyHolder from + * dotty.runtime(eg dotty.runtime.LazyInt) + */ + def transformLocalDef(x: ValOrDefDef)(implicit ctx: Context) = { + val valueInitter = x.rhs + val xname = x.name.asTermName + val holderName = LazyLocalName.fresh(xname) + val initName = LazyLocalInitName.fresh(xname) + val tpe = x.tpe.widen.resultType.widen + + val holderType = + if (tpe isRef defn.IntClass) "LazyInt" + else if (tpe isRef defn.LongClass) "LazyLong" + else if (tpe isRef defn.BooleanClass) "LazyBoolean" + else if (tpe isRef defn.FloatClass) "LazyFloat" + else if (tpe isRef defn.DoubleClass) "LazyDouble" + else if (tpe isRef defn.ByteClass) "LazyByte" + else if (tpe isRef defn.CharClass) "LazyChar" + else if (tpe isRef defn.ShortClass) "LazyShort" + else "LazyRef" + + + val holderImpl = ctx.requiredClass("dotty.runtime." + holderType) + + val holderSymbol = ctx.newSymbol(x.symbol.owner, holderName, containerFlags, holderImpl.typeRef, coord = x.pos) + val initSymbol = ctx.newSymbol(x.symbol.owner, initName, initFlags, MethodType(Nil, tpe), coord = x.pos) + val result = ref(holderSymbol).select(lazyNme.value).withPos(x.pos) + val flag = ref(holderSymbol).select(lazyNme.initialized) + val initer = valueInitter.changeOwnerAfter(x.symbol, initSymbol, this) + val initBody = + adaptToType( + ref(holderSymbol).select(defn.Object_synchronized).appliedTo( + adaptToType(mkNonThreadSafeDef(result, flag, initer, nullables = Nil), defn.ObjectType)), + tpe) + val initTree = DefDef(initSymbol, initBody) + val holderTree = ValDef(holderSymbol, New(holderImpl.typeRef, List())) + val methodBody = tpd.If(flag.ensureApplied, + result.ensureApplied, + ref(initSymbol).ensureApplied).ensureConforms(tpe) + + val methodTree = DefDef(x.symbol.asTerm, methodBody) + ctx.debuglog(s"found a lazy val ${x.show},\nrewrote with ${holderTree.show}") + Thicket(holderTree, initTree, methodTree) + } override def transformStats(trees: List[tpd.Tree])(implicit ctx: Context): List[tpd.Tree] = { @@ -176,226 +194,259 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { holders:::stats } + private def nullOut(nullables: List[Symbol])(implicit ctx: Context): List[Tree] = { + val nullConst = Literal(Constants.Constant(null)) + nullables.map { field => + assert(field.isField) + field.setFlag(Flags.Mutable) + ref(field).becomes(nullConst) + } + } + /** Create non-threadsafe lazy accessor equivalent to such code - * def methodSymbol() = { - * if (flag) target - * else { - * target = rhs - * flag = true - * target - * } - * } - */ - - def mkNonThreadSafeDef(target: Tree, flag: Tree, rhs: Tree)(implicit ctx: Context) = { - val setFlag = flag.becomes(Literal(Constants.Constant(true))) - val setTargets = if (isWildcardArg(rhs)) Nil else target.becomes(rhs) :: Nil - val init = Block(setFlag :: setTargets, target.ensureApplied) - If(flag.ensureApplied, target.ensureApplied, init) + * ``` + * def methodSymbol() = { + * if (flag) target + * else { + * target = rhs + * flag = true + * nullable = null + * target + * } + * } + * } + * ``` + */ + def mkNonThreadSafeDef(target: Tree, flag: Tree, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = { + val setFlag = flag.becomes(Literal(Constants.Constant(true))) + val setNullables = nullOut(nullables) + val setTargetAndNullable = if (isWildcardArg(rhs)) setNullables else target.becomes(rhs) :: setNullables + val init = Block(setFlag :: setTargetAndNullable, target.ensureApplied) + If(flag.ensureApplied, target.ensureApplied, init) + } + + /** Create non-threadsafe lazy accessor for not-nullable types equivalent to such code + * ``` + * def methodSymbol() = { + * if (target eq null) { + * target = rhs + * nullable = null + * target + * } else target + * } + * ``` + */ + def mkDefNonThreadSafeNonNullable(target: Symbol, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = { + val cond = ref(target).select(nme.eq).appliedTo(Literal(Constant(null))) + val exp = ref(target) + val setTarget = exp.becomes(rhs) + val setNullables = nullOut(nullables) + val init = Block(setTarget :: setNullables, exp) + If(cond, init, exp) + } + + def transformMemberDefNonVolatile(x: ValOrDefDef)(implicit ctx: Context) = { + val claz = x.symbol.owner.asClass + val tpe = x.tpe.widen.resultType.widen + assert(!(x.symbol is Flags.Mutable)) + val containerName = LazyLocalName.fresh(x.name.asTermName) + val containerSymbol = ctx.newSymbol(claz, containerName, + x.symbol.flags &~ containerFlagsMask | containerFlags | Flags.Private, + tpe, coord = x.symbol.coord + ).enteredAfter(this) + + val containerTree = ValDef(containerSymbol, defaultValue(tpe)) + if (x.tpe.isNotNull && tpe <:< defn.ObjectType) { // can use 'null' value instead of flag + val slowPath = DefDef(x.symbol.asTerm, mkDefNonThreadSafeNonNullable(containerSymbol, x.rhs, nullableFor(x.symbol))) + Thicket(containerTree, slowPath) } + else { + val flagName = LazyBitMapName.fresh(x.name.asTermName) + val flagSymbol = ctx.newSymbol(x.symbol.owner, flagName, containerFlags | Flags.Private, defn.BooleanType).enteredAfter(this) + val flag = ValDef(flagSymbol, Literal(Constants.Constant(false))) + val slowPath = DefDef(x.symbol.asTerm, mkNonThreadSafeDef(ref(containerSymbol), ref(flagSymbol), x.rhs, nullableFor(x.symbol))) + Thicket(containerTree, flag, slowPath) + } + } - /** Create non-threadsafe lazy accessor for not-nullable types equivalent to such code - * def methodSymbol() = { - * if (target eq null) { - * target = rhs - * target - * } else target - * } - */ - def mkDefNonThreadSafeNonNullable(target: Symbol, rhs: Tree)(implicit ctx: Context) = { - val cond = ref(target).select(nme.eq).appliedTo(Literal(Constant(null))) - val exp = ref(target) - val setTarget = exp.becomes(rhs) - val init = Block(List(setTarget), exp) - If(cond, init, exp) + /** Create a threadsafe lazy accessor equivalent to such code + * ``` + * def methodSymbol(): Int = { + * val result: Int = 0 + * val retry: Boolean = true + * var flag: Long = 0L + * while retry do { + * flag = dotty.runtime.LazyVals.get(this, $claz.$OFFSET) + * dotty.runtime.LazyVals.STATE(flag, 0) match { + * case 0 => + * if dotty.runtime.LazyVals.CAS(this, $claz.$OFFSET, flag, 1, $ord) { + * try {result = rhs} catch { + * case x: Throwable => + * dotty.runtime.LazyVals.setFlag(this, $claz.$OFFSET, 0, $ord) + * throw x + * } + * $target = result + * dotty.runtime.LazyVals.setFlag(this, $claz.$OFFSET, 3, $ord) + * retry = false + * } + * case 1 => + * dotty.runtime.LazyVals.wait4Notification(this, $claz.$OFFSET, flag, $ord) + * case 2 => + * dotty.runtime.LazyVals.wait4Notification(this, $claz.$OFFSET, flag, $ord) + * case 3 => + * retry = false + * result = $target + * } + * } + * nullable = null + * result + * } + * ``` + */ + def mkThreadSafeDef(methodSymbol: TermSymbol, + claz: ClassSymbol, + ord: Int, + target: Symbol, + rhs: Tree, + tp: Types.Type, + offset: Tree, + getFlag: Tree, + stateMask: Tree, + casFlag: Tree, + setFlagState: Tree, + waitOnLock: Tree, + nullables: List[Symbol])(implicit ctx: Context) = { + val initState = Literal(Constants.Constant(0)) + val computeState = Literal(Constants.Constant(1)) + val notifyState = Literal(Constants.Constant(2)) + val computedState = Literal(Constants.Constant(3)) + val flagSymbol = ctx.newSymbol(methodSymbol, lazyNme.flag, containerFlags, defn.LongType) + val flagDef = ValDef(flagSymbol, Literal(Constant(0L))) + + val thiz = This(claz)(ctx.fresh.setOwner(claz)) + + val resultSymbol = ctx.newSymbol(methodSymbol, lazyNme.result, containerFlags, tp) + val resultDef = ValDef(resultSymbol, defaultValue(tp)) + + val retrySymbol = ctx.newSymbol(methodSymbol, lazyNme.retry, containerFlags, defn.BooleanType) + val retryDef = ValDef(retrySymbol, Literal(Constants.Constant(true))) + + val whileCond = ref(retrySymbol) + + val compute = { + val handlerSymbol = ctx.newSymbol(methodSymbol, nme.ANON_FUN, Flags.Synthetic, + MethodType(List(nme.x_1), List(defn.ThrowableType), defn.IntType)) + val caseSymbol = ctx.newSymbol(methodSymbol, nme.DEFAULT_EXCEPTION_NAME, Flags.Synthetic, defn.ThrowableType) + val triggerRetry = setFlagState.appliedTo(thiz, offset, initState, Literal(Constant(ord))) + val complete = setFlagState.appliedTo(thiz, offset, computedState, Literal(Constant(ord))) + + val handler = CaseDef(Bind(caseSymbol, ref(caseSymbol)), EmptyTree, + Block(List(triggerRetry), Throw(ref(caseSymbol)) + )) + + val compute = ref(resultSymbol).becomes(rhs) + val tr = Try(compute, List(handler), EmptyTree) + val assign = ref(target).becomes(ref(resultSymbol)) + val noRetry = ref(retrySymbol).becomes(Literal(Constants.Constant(false))) + val body = If(casFlag.appliedTo(thiz, offset, ref(flagSymbol), computeState, Literal(Constant(ord))), + Block(tr :: assign :: complete :: noRetry :: Nil, Literal(Constant(()))), + Literal(Constant(()))) + + CaseDef(initState, EmptyTree, body) } - def transformMemberDefNonVolatile(x: ValOrDefDef)(implicit ctx: Context) = { - val claz = x.symbol.owner.asClass - val tpe = x.tpe.widen.resultType.widen - assert(!(x.symbol is Flags.Mutable)) - val containerName = LazyLocalName.fresh(x.name.asTermName) - val containerSymbol = ctx.newSymbol(claz, containerName, - x.symbol.flags &~ containerFlagsMask | containerFlags | Flags.Private, - tpe, coord = x.symbol.coord - ).enteredAfter(this) - - val containerTree = ValDef(containerSymbol, defaultValue(tpe)) - if (x.tpe.isNotNull && tpe <:< defn.ObjectType) { // can use 'null' value instead of flag - val slowPath = DefDef(x.symbol.asTerm, mkDefNonThreadSafeNonNullable(containerSymbol, x.rhs)) - Thicket(containerTree, slowPath) - } - else { - val flagName = LazyBitMapName.fresh(x.name.asTermName) - val flagSymbol = ctx.newSymbol(x.symbol.owner, flagName, containerFlags | Flags.Private, defn.BooleanType).enteredAfter(this) - val flag = ValDef(flagSymbol, Literal(Constants.Constant(false))) - val slowPath = DefDef(x.symbol.asTerm, mkNonThreadSafeDef(ref(containerSymbol), ref(flagSymbol), x.rhs)) - Thicket(containerTree, flag, slowPath) - } + val waitFirst = { + val wait = waitOnLock.appliedTo(thiz, offset, ref(flagSymbol), Literal(Constant(ord))) + CaseDef(computeState, EmptyTree, wait) } - /** Create a threadsafe lazy accessor equivalent to such code - * - * def methodSymbol(): Int = { - * val result: Int = 0 - * val retry: Boolean = true - * var flag: Long = 0L - * while retry do { - * flag = dotty.runtime.LazyVals.get(this, $claz.$OFFSET) - * dotty.runtime.LazyVals.STATE(flag, 0) match { - * case 0 => - * if dotty.runtime.LazyVals.CAS(this, $claz.$OFFSET, flag, 1, $ord) { - * try {result = rhs} catch { - * case x: Throwable => - * dotty.runtime.LazyVals.setFlag(this, $claz.$OFFSET, 0, $ord) - * throw x - * } - * $target = result - * dotty.runtime.LazyVals.setFlag(this, $claz.$OFFSET, 3, $ord) - * retry = false - * } - * case 1 => - * dotty.runtime.LazyVals.wait4Notification(this, $claz.$OFFSET, flag, $ord) - * case 2 => - * dotty.runtime.LazyVals.wait4Notification(this, $claz.$OFFSET, flag, $ord) - * case 3 => - * retry = false - * result = $target - * } - * } - * result - * } - */ - def mkThreadSafeDef(methodSymbol: TermSymbol, claz: ClassSymbol, ord: Int, target: Symbol, rhs: Tree, tp: Types.Type, offset: Tree, getFlag: Tree, stateMask: Tree, casFlag: Tree, setFlagState: Tree, waitOnLock: Tree)(implicit ctx: Context) = { - val initState = Literal(Constants.Constant(0)) - val computeState = Literal(Constants.Constant(1)) - val notifyState = Literal(Constants.Constant(2)) - val computedState = Literal(Constants.Constant(3)) - val flagSymbol = ctx.newSymbol(methodSymbol, lazyNme.flag, containerFlags, defn.LongType) - val flagDef = ValDef(flagSymbol, Literal(Constant(0L))) - - val thiz = This(claz)(ctx.fresh.setOwner(claz)) - - val resultSymbol = ctx.newSymbol(methodSymbol, lazyNme.result, containerFlags, tp) - val resultDef = ValDef(resultSymbol, defaultValue(tp)) - - val retrySymbol = ctx.newSymbol(methodSymbol, lazyNme.retry, containerFlags, defn.BooleanType) - val retryDef = ValDef(retrySymbol, Literal(Constants.Constant(true))) - - val whileCond = ref(retrySymbol) - - val compute = { - val handlerSymbol = ctx.newSymbol(methodSymbol, nme.ANON_FUN, Flags.Synthetic, - MethodType(List(nme.x_1), List(defn.ThrowableType), defn.IntType)) - val caseSymbol = ctx.newSymbol(methodSymbol, nme.DEFAULT_EXCEPTION_NAME, Flags.Synthetic, defn.ThrowableType) - val triggerRetry = setFlagState.appliedTo(thiz, offset, initState, Literal(Constant(ord))) - val complete = setFlagState.appliedTo(thiz, offset, computedState, Literal(Constant(ord))) - - val handler = CaseDef(Bind(caseSymbol, ref(caseSymbol)), EmptyTree, - Block(List(triggerRetry), Throw(ref(caseSymbol)) - )) - - val compute = ref(resultSymbol).becomes(rhs) - val tr = Try(compute, List(handler), EmptyTree) - val assign = ref(target).becomes(ref(resultSymbol)) - val noRetry = ref(retrySymbol).becomes(Literal(Constants.Constant(false))) - val body = If(casFlag.appliedTo(thiz, offset, ref(flagSymbol), computeState, Literal(Constant(ord))), - Block(tr :: assign :: complete :: noRetry :: Nil, Literal(Constant(()))), - Literal(Constant(()))) - - CaseDef(initState, EmptyTree, body) - } + val waitSecond = { + val wait = waitOnLock.appliedTo(thiz, offset, ref(flagSymbol), Literal(Constant(ord))) + CaseDef(notifyState, EmptyTree, wait) + } - val waitFirst = { - val wait = waitOnLock.appliedTo(thiz, offset, ref(flagSymbol), Literal(Constant(ord))) - CaseDef(computeState, EmptyTree, wait) - } + val computed = { + val noRetry = ref(retrySymbol).becomes(Literal(Constants.Constant(false))) + val result = ref(resultSymbol).becomes(ref(target)) + val body = Block(noRetry :: result :: Nil, Literal(Constant(()))) + CaseDef(computedState, EmptyTree, body) + } - val waitSecond = { - val wait = waitOnLock.appliedTo(thiz, offset, ref(flagSymbol), Literal(Constant(ord))) - CaseDef(notifyState, EmptyTree, wait) - } + val default = CaseDef(Underscore(defn.LongType), EmptyTree, Literal(Constant(()))) - val computed = { - val noRetry = ref(retrySymbol).becomes(Literal(Constants.Constant(false))) - val result = ref(resultSymbol).becomes(ref(target)) - val body = Block(noRetry :: result :: Nil, Literal(Constant(()))) - CaseDef(computedState, EmptyTree, body) - } + val cases = Match(stateMask.appliedTo(ref(flagSymbol), Literal(Constant(ord))), + List(compute, waitFirst, waitSecond, computed, default)) //todo: annotate with @switch - val default = CaseDef(Underscore(defn.LongType), EmptyTree, Literal(Constant(()))) + val whileBody = List(ref(flagSymbol).becomes(getFlag.appliedTo(thiz, offset)), cases) + val cycle = WhileDo(methodSymbol, whileCond, whileBody) + val setNullables = nullOut(nullables) + DefDef(methodSymbol, Block(resultDef :: retryDef :: flagDef :: cycle :: setNullables, ref(resultSymbol))) + } - val cases = Match(stateMask.appliedTo(ref(flagSymbol), Literal(Constant(ord))), - List(compute, waitFirst, waitSecond, computed, default)) //todo: annotate with @switch + def transformMemberDefVolatile(x: ValOrDefDef)(implicit ctx: Context) = { + assert(!(x.symbol is Flags.Mutable)) + + val tpe = x.tpe.widen.resultType.widen + val claz = x.symbol.owner.asClass + val thizClass = Literal(Constant(claz.info)) + val helperModule = ctx.requiredModule("dotty.runtime.LazyVals") + val getOffset = Select(ref(helperModule), lazyNme.RLazyVals.getOffset) + var offsetSymbol: TermSymbol = null + var flag: Tree = EmptyTree + var ord = 0 + + def offsetName(id: Int) = (StdNames.nme.LAZY_FIELD_OFFSET + (if(x.symbol.owner.is(Flags.Module)) "_m_" else "") + id.toString).toTermName + + // compute or create appropriate offsetSymol, bitmap and bits used by current ValDef + appendOffsetDefs.get(claz) match { + case Some(info) => + val flagsPerLong = (64 / dotty.runtime.LazyVals.BITS_PER_LAZY_VAL).toInt + info.ord += 1 + ord = info.ord % flagsPerLong + val id = info.ord / flagsPerLong + val offsetById = offsetName(id) + if (ord != 0) { // there are unused bits in already existing flag + offsetSymbol = claz.info.decl(offsetById) + .suchThat(sym => (sym is Flags.Synthetic) && sym.isTerm) + .symbol.asTerm + } else { // need to create a new flag + offsetSymbol = ctx.newSymbol(claz, offsetById, Flags.Synthetic, defn.LongType).enteredAfter(this) + offsetSymbol.addAnnotation(Annotation(defn.ScalaStaticAnnot)) + val flagName = (StdNames.nme.BITMAP_PREFIX + id.toString).toTermName + val flagSymbol = ctx.newSymbol(claz, flagName, containerFlags, defn.LongType).enteredAfter(this) + flag = ValDef(flagSymbol, Literal(Constants.Constant(0L))) + val offsetTree = ValDef(offsetSymbol, getOffset.appliedTo(thizClass, Literal(Constant(flagName.toString)))) + info.defs = offsetTree :: info.defs + } - val whileBody = List(ref(flagSymbol).becomes(getFlag.appliedTo(thiz, offset)), cases) - val cycle = WhileDo(methodSymbol, whileCond, whileBody) - DefDef(methodSymbol, Block(resultDef :: retryDef :: flagDef :: cycle :: Nil, ref(resultSymbol))) + case None => + offsetSymbol = ctx.newSymbol(claz, offsetName(0), Flags.Synthetic, defn.LongType).enteredAfter(this) + offsetSymbol.addAnnotation(Annotation(defn.ScalaStaticAnnot)) + val flagName = (StdNames.nme.BITMAP_PREFIX + "0").toTermName + val flagSymbol = ctx.newSymbol(claz, flagName, containerFlags, defn.LongType).enteredAfter(this) + flag = ValDef(flagSymbol, Literal(Constants.Constant(0L))) + val offsetTree = ValDef(offsetSymbol, getOffset.appliedTo(thizClass, Literal(Constant(flagName.toString)))) + appendOffsetDefs += (claz -> new OffsetInfo(List(offsetTree), ord)) } - def transformMemberDefVolatile(x: ValOrDefDef)(implicit ctx: Context) = { - assert(!(x.symbol is Flags.Mutable)) - - val tpe = x.tpe.widen.resultType.widen - val claz = x.symbol.owner.asClass - val thizClass = Literal(Constant(claz.info)) - val helperModule = ctx.requiredModule("dotty.runtime.LazyVals") - val getOffset = Select(ref(helperModule), lazyNme.RLazyVals.getOffset) - var offsetSymbol: TermSymbol = null - var flag: Tree = EmptyTree - var ord = 0 - - def offsetName(id: Int) = (StdNames.nme.LAZY_FIELD_OFFSET + (if(x.symbol.owner.is(Flags.Module)) "_m_" else "") + id.toString).toTermName - - // compute or create appropriate offsetSymol, bitmap and bits used by current ValDef - appendOffsetDefs.get(claz) match { - case Some(info) => - val flagsPerLong = (64 / dotty.runtime.LazyVals.BITS_PER_LAZY_VAL).toInt - info.ord += 1 - ord = info.ord % flagsPerLong - val id = info.ord / flagsPerLong - val offsetById = offsetName(id) - if (ord != 0) { // there are unused bits in already existing flag - offsetSymbol = claz.info.decl(offsetById) - .suchThat(sym => (sym is Flags.Synthetic) && sym.isTerm) - .symbol.asTerm - } else { // need to create a new flag - offsetSymbol = ctx.newSymbol(claz, offsetById, Flags.Synthetic, defn.LongType).enteredAfter(this) - offsetSymbol.addAnnotation(Annotation(defn.ScalaStaticAnnot)) - val flagName = (StdNames.nme.BITMAP_PREFIX + id.toString).toTermName - val flagSymbol = ctx.newSymbol(claz, flagName, containerFlags, defn.LongType).enteredAfter(this) - flag = ValDef(flagSymbol, Literal(Constants.Constant(0L))) - val offsetTree = ValDef(offsetSymbol, getOffset.appliedTo(thizClass, Literal(Constant(flagName.toString)))) - info.defs = offsetTree :: info.defs - } - - case None => - offsetSymbol = ctx.newSymbol(claz, offsetName(0), Flags.Synthetic, defn.LongType).enteredAfter(this) - offsetSymbol.addAnnotation(Annotation(defn.ScalaStaticAnnot)) - val flagName = (StdNames.nme.BITMAP_PREFIX + "0").toTermName - val flagSymbol = ctx.newSymbol(claz, flagName, containerFlags, defn.LongType).enteredAfter(this) - flag = ValDef(flagSymbol, Literal(Constants.Constant(0L))) - val offsetTree = ValDef(offsetSymbol, getOffset.appliedTo(thizClass, Literal(Constant(flagName.toString)))) - appendOffsetDefs += (claz -> new OffsetInfo(List(offsetTree), ord)) - } - - val containerName = LazyLocalName.fresh(x.name.asTermName) - val containerSymbol = ctx.newSymbol(claz, containerName, x.symbol.flags &~ containerFlagsMask | containerFlags, tpe, coord = x.symbol.coord).enteredAfter(this) + val containerName = LazyLocalName.fresh(x.name.asTermName) + val containerSymbol = ctx.newSymbol(claz, containerName, x.symbol.flags &~ containerFlagsMask | containerFlags, tpe, coord = x.symbol.coord).enteredAfter(this) - val containerTree = ValDef(containerSymbol, defaultValue(tpe)) + val containerTree = ValDef(containerSymbol, defaultValue(tpe)) - val offset = ref(offsetSymbol) - val getFlag = Select(ref(helperModule), lazyNme.RLazyVals.get) - val setFlag = Select(ref(helperModule), lazyNme.RLazyVals.setFlag) - val wait = Select(ref(helperModule), lazyNme.RLazyVals.wait4Notification) - val state = Select(ref(helperModule), lazyNme.RLazyVals.state) - val cas = Select(ref(helperModule), lazyNme.RLazyVals.cas) + val offset = ref(offsetSymbol) + val getFlag = Select(ref(helperModule), lazyNme.RLazyVals.get) + val setFlag = Select(ref(helperModule), lazyNme.RLazyVals.setFlag) + val wait = Select(ref(helperModule), lazyNme.RLazyVals.wait4Notification) + val state = Select(ref(helperModule), lazyNme.RLazyVals.state) + val cas = Select(ref(helperModule), lazyNme.RLazyVals.cas) + val nullables = nullableFor(x.symbol) - val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait) - if (flag eq EmptyTree) - Thicket(containerTree, accessor) - else Thicket(containerTree, flag, accessor) - } + val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait, nullables) + if (flag eq EmptyTree) + Thicket(containerTree, accessor) + else Thicket(containerTree, flag, accessor) + } } object LazyVals { diff --git a/tests/run/i1692.scala b/tests/run/i1692.scala new file mode 100644 index 000000000000..f70cd1b2ed16 --- /dev/null +++ b/tests/run/i1692.scala @@ -0,0 +1,141 @@ +class VCInt(val x: Int) extends AnyVal +class VCString(val x: String) extends AnyVal + +class LazyNullable(a: => Int) { + lazy val l0 = a // null out a + + private[this] val b = "B" + lazy val l1 = b // null out b + + private[this] val c = "C" + @volatile lazy val l2 = c // null out c + + private[this] val d = "D" + lazy val l3 = d + d // null out d (Scalac require single use?) +} + +object LazyNullable2 { + private[this] val a = "A" + lazy val l0 = a // null out a +} + +class LazyNotNullable { + private[this] val a = 'A'.toInt // not nullable type + lazy val l0 = a + + private[this] val b = new VCInt('B'.toInt) // not nullable type + lazy val l1 = b + + private[this] val c = new VCString("C") // should be nullable but is not?? + lazy val l2 = c + + private[this] lazy val d = "D" // not nullable because lazy + lazy val l3 = d + + private val e = "E" // not nullable because not private[this] + lazy val l4 = e + + private[this] val f = "F" // not nullable because used in mutiple lazy vals + lazy val l5 = f + lazy val l6 = f + + private[this] val g = "G" // not nullable because used outside a lazy val initializer + def foo = g + lazy val l7 = g + + private[this] val h = "H" // not nullable because field and lazy val not defined in the same class + class Inner { + lazy val l8 = h + } +} + +trait LazyTrait { + private val a = "A" + lazy val l0 = a +} + +class Foo(val x: String) + +class LazyNotNullable2(x: String) extends Foo(x) { + lazy val y = x // not nullable. Here x is super.x +} + + +object Test { + def main(args: Array[String]): Unit = { + nullableTests() + notNullableTests() + } + + def nullableTests() = { + val lz = new LazyNullable('A'.toInt) + + def assertNull(fieldName: String) = { + val value = readField(fieldName, lz) + assert(value == null, s"$fieldName was $value, null expected") + } + + assert(lz.l0 == 'A'.toInt) + assertNull("a") + + assert(lz.l1 == "B") + assertNull("b") + + assert(lz.l2 == "C") + assertNull("c") + + assert(lz.l3 == "DD") + assertNull("d") + + assert(LazyNullable2.l0 == "A") + assert(readField("a", LazyNullable2) == null) + } + + def notNullableTests() = { + val lz = new LazyNotNullable + + def assertNotNull(fieldName: String) = { + val value = readField(fieldName, lz) + assert(value != null, s"$fieldName was null") + } + + assert(lz.l0 == 'A'.toInt) + assertNotNull("a") + + assert(lz.l1 == new VCInt('B'.toInt)) + assertNotNull("b") + + assert(lz.l2 == new VCString("C")) + assertNotNull("c") + + assert(lz.l3 == "D") + + assert(lz.l4 == "E") + assertNotNull("e") + + assert(lz.l5 == "F") + assert(lz.l6 == "F") + assertNotNull("f") + + assert(lz.l7 == "G") + assertNotNull("g") + + val inner = new lz.Inner + assert(inner.l8 == "H") + assertNotNull("LazyNotNullable$$h") // fragile: test will break if compiler generated names change + + val fromTrait = new LazyTrait {} + assert(fromTrait.l0 == "A") + assert(readField("LazyTrait$$a", fromTrait) != null) // fragile: test will break if compiler generated names change + + val lz2 = new LazyNotNullable2("Hello") + assert(lz2.y == "Hello") + assert(lz2.x == "Hello") + } + + def readField(fieldName: String, target: Any): Any = { + val field = target.getClass.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(target) + } +}