Skip to content

WIP position synthetic async code at the start of the async block, not last expr pos #122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: 2.13.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/compiler/scala/tools/nsc/transform/async/ExprBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis {
case ap @ Apply(i @ Ident(_), Nil) if isCaseLabel(i.symbol) || isMatchEndLabel(i.symbol) =>
currentTransformState.labelDefStates.get(i.symbol) match {
case Some(state) =>
Block(StateTransitionStyle.UpdateAndContinue.trees(state, new StateSet), typed(literalUnit)).setType(definitions.UnitTpe)
Block(StateTransitionStyle.UpdateAndContinue.trees(state, new StateSet), typedCurrentPos(literalUnit)).setType(definitions.UnitTpe)
case None => ap
}
case tree => tree
Expand Down Expand Up @@ -60,7 +60,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis {
val stats1 = mutable.ListBuffer[Tree]()
def addNullAssigments(syms: Iterator[Symbol]): Unit = {
for (fieldSym <- syms) {
stats1 += typed(Assign(currentTransformState.memberRef(fieldSym), gen.mkZero(fieldSym.info)))
stats1 += typedCurrentPos(Assign(currentTransformState.memberRef(fieldSym), gen.mkZero(fieldSym.info)))
}
}
// Add pre-state null assigments at the beginning.
Expand Down Expand Up @@ -148,7 +148,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis {
}

allNextStates += nextState
stats += typed(Return(literalUnit).setSymbol(currentTransformState.applySym))
stats += typedCurrentPos(Return(literalUnit).setSymbol(currentTransformState.applySym))
}
if (state == StateAssigner.Terminal) {
// noop
Expand Down Expand Up @@ -462,7 +462,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis {
val asyncStatesInit = asyncStates.init // drop the terminal state which has no code.
val throww = Throw(Apply(Select(New(Ident(IllegalStateExceptionClass)), IllegalStateExceptionClass_NEW_String), List(gen.mkMethodCall(currentRun.runDefinitions.String_valueOf_Int, stateMemberRef :: Nil))))
val body =
typed(Match(stateMemberRef,
typedBasePos(Match(stateMemberRef,
asyncStatesInit.map(_.mkHandlerCaseForState) ++
List(CaseDef(Ident(nme.WILDCARD), EmptyTree,
throww))))
Expand All @@ -480,7 +480,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis {
)
), EmptyTree)
}
typed(LabelDef(transformState.whileLabel, Nil, Block(stateMatch :: Nil, Apply(Ident(transformState.whileLabel), Nil))))
typedBasePos (LabelDef(transformState.whileLabel, Nil, Block(stateMatch :: Nil, Apply(Ident(transformState.whileLabel), Nil))))
}

private def compactStates = true
Expand Down Expand Up @@ -557,7 +557,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis {
} else {
val temp = awaitableResult.symbol.newTermSymbol(nme.trGetResult).setInfo(definitions.ObjectTpe)
val tempVd = ValDef(temp, gen.mkMethodCall(currentTransformState.memberRef(currentTransformState.stateTryGet), tryyReference :: Nil))
typed(Block(
typedCurrentPos(Block(
tempVd :: Nil,
If(Apply(gen.mkAttributedSelect(currentTransformState.stateMachineRef(), definitions.Object_eq), gen.mkAttributedIdent(temp) :: Nil),
Return(literalUnit),
Expand All @@ -571,7 +571,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis {
// Comlete the Promise in the `result` field with the final successful result of this async block.
private def completeSuccess(expr: Tree): Tree = {
deriveTree(expr, definitions.UnitTpe) { expr =>
typed(Apply(currentTransformState.memberRef(currentTransformState.stateCompleteSuccess), expr :: Nil))
typedCurrentPos(Apply(currentTransformState.memberRef(currentTransformState.stateCompleteSuccess), expr :: Nil))
}
}

Expand All @@ -581,7 +581,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis {
protected def mkStateTree(nextState: Int): Tree = {
val transformState = currentTransformState
val callSetter = Apply(transformState.memberRef(transformState.stateSetter), Literal(Constant(nextState)) :: Nil)
typed(callSetter.updateAttachment(StateTransitionTree))
typedCurrentPos(callSetter.updateAttachment(StateTransitionTree))
}
}

Expand Down Expand Up @@ -625,9 +625,9 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis {
If(Apply(null_ne, Ident(transformState.applyTrParam) :: Nil),
Apply(Ident(transformState.whileLabel), Nil),
Block(toStats(callOnComplete(gen.mkAttributedIdent(tempAwaitableSym))), Return(literalUnit).setSymbol(transformState.applySym)))
typed(initAwaitableTemp) :: typed(initTempCompleted) :: mkStateTree(nextState) :: typed(ifTree) :: Nil
typedCurrentPos(initAwaitableTemp) :: typedCurrentPos(initTempCompleted) :: mkStateTree(nextState) :: typedCurrentPos(ifTree) :: Nil
} else {
mkStateTree(nextState) :: toStats(typed(callOnComplete(awaitable))) ::: typed(Return(literalUnit)) :: Nil
mkStateTree(nextState) :: toStats(typedCurrentPos(callOnComplete(awaitable))) ::: typedCurrentPos(Return(literalUnit)) :: Nil
}
}
}
Expand All @@ -636,7 +636,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis {
case object UpdateAndContinue extends StateTransitionStyle {
def trees(nextState: Int, stateSet: StateSet): List[Tree] = {
stateSet += nextState
List(mkStateTree(nextState), typed(Apply(Ident(currentTransformState.whileLabel), Nil)))
List(mkStateTree(nextState), typedCurrentPos(Apply(Ident(currentTransformState.whileLabel), Nil)))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ private[async] trait TransformUtils extends AsyncTransformStates {
private[async] val asyncNames: AsyncNames[global.type]

def typedPos(pos: Position)(tree: Tree): Tree = currentTransformState.localTyper.typedPos(pos)(tree: Tree)
def typed(tree: Tree): Tree = typedPos(currentTransformState.currentPos)(tree)
def typedCurrentPos(tree: Tree): Tree = typedPos(currentTransformState.currentPos)(tree)
def typedBasePos(tree: Tree): Tree = typedPos(currentTransformState.applySym.pos)(tree)

lazy val IllegalStateExceptionClass: Symbol = rootMirror.staticClass("java.lang.IllegalStateException")
lazy val IllegalStateExceptionClass_NEW_String: Symbol = IllegalStateExceptionClass.info.decl(nme.CONSTRUCTOR).suchThat(
Expand Down
16 changes: 16 additions & 0 deletions test/async/run/positions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// scalac: -Xasync -Xprint:parser,typer,async -Xprint-pos -Yrangepos

import scala.tools.partest.async.OptionAwait._
import org.junit.Assert._

object Test {
def main(args: Array[String]): Unit = {
testBasic()
}

private def testBasic() = optionally {
val x = value(Some(1))
val y = value(Some(2))
x + y
}
}
249 changes: 188 additions & 61 deletions test/junit/scala/tools/nsc/async/AnnotationDrivenAsyncTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.junit.Assert.assertEquals
import org.junit.{Assert, Ignore, Test}

import scala.annotation.{StaticAnnotation, nowarn, unused}
import scala.collection.mutable
import scala.concurrent.duration.Duration
import scala.reflect.internal.util.Position
import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader
Expand Down Expand Up @@ -367,6 +368,98 @@ class AnnotationDrivenAsyncTest {
assertEquals(classOf[Array[String]], result.getClass)
}

@Test
def testPositions(): Unit = {
val code =
"""
|import scala.tools.nsc.async.{autoawait, customAsync}
|object Test {
| @autoawait def id(a: Int) = a
| @customAsync def test = {
| val x = id(1)
| val y = id(2)
| x + y
| }
|}""".stripMargin

val result = compile(code)

import result.global._
// settings.Xprintpos.value = true // enable to help debugging
val methpdParseTree = result.parseTree.find { case dt: DefTree => dt.name.string_==("test") case _ => false }
methpdParseTree.get match {
case DefDef(_, _, _, _, _, Block(stats, expr)) =>
val parseTreeStats: List[Tree] = (expr :: stats)
val fsmTree = result.fsmTree
val posMap = mutable.LinkedHashMap[Tree, mutable.Buffer[Tree]]()
val parentMap = mutable.LinkedHashMap[Tree, Tree]()
def collectParents(t: Tree): Unit = {
for (child <- t.children) {
parentMap(child) = t
collectParents(child)
}
}
def isAncestor(child: Tree, parent: Tree): Boolean = {
parentMap.get(child) match {
case None => false
case Some(p) => parent == p || isAncestor(p, parent)
}
}
collectParents(fsmTree.get)

for {
parseTreeStat <- parseTreeStats
pos = parseTreeStat.pos
tree <- fsmTree.get
} {
if (pos.includes(tree.pos)) {
posMap.get(parseTreeStat) match {
case Some(existing) =>
if (!existing.exists(t => isAncestor(tree, t))) {
val (retained, discarded) = existing.toList.partition(t => isAncestor(t, tree) || !isAncestor(tree, t))
existing.clear()
existing ++= retained
existing += tree
}
case None =>
posMap(parseTreeStat) = mutable.ListBuffer(tree)
}
}
}

val incorrectlyContainedTryOrWhileLoop = posMap.values.flatMap(_.collect { case t: Try => t; case ld: LabelDef if ld.name.containsName(nme.WHILE_PREFIX) => ld})
assert(incorrectlyContainedTryOrWhileLoop.isEmpty, incorrectlyContainedTryOrWhileLoop)

def oneliner(s: String) = s.replace(System.lineSeparator(), "\\n")
val actual = posMap.toList.map { case (orig, corresponding) => s"${oneliner(orig.toString)}\n${"-" * 80}\n${corresponding.map(t => oneliner(t.toString)).mkString("\n")}"}.mkString("\n" * 3)
val expected =
"""x.$plus(y)
|--------------------------------------------------------------------------------
|self.completeSuccess(scala.Int.box(self.x.+(y)))
|return ()
|
|
|val x = id(1)
|--------------------------------------------------------------------------------
|case 0 => {\n val awaitable$async: scala.tools.nsc.async.CustomFuture = scala.tools.nsc.async.CustomFuture._successful(scala.Int.box(Test.this.id(1)));\n tr = self.getCompleted(awaitable$async);\n self.state_=(1);\n if (null.!=(tr))\n while$()\n else\n {\n self.onComplete(awaitable$async);\n return ()\n }\n}
|<synthetic> val await$1: Object = {\n val tryGetResult$async: Object = self.tryGet(tr);\n if (self.eq(tryGetResult$async))\n return ()\n else\n tryGetResult$async.$asInstanceOf[Object]()\n}
|self.x = scala.Int.unbox(await$1)
|
|
|val y = id(2)
|--------------------------------------------------------------------------------
|val awaitable$async: scala.tools.nsc.async.CustomFuture = scala.tools.nsc.async.CustomFuture._successful(scala.Int.box(Test.this.id(2)))
|tr = self.getCompleted(awaitable$async)
|self.state_=(2)
|if (null.!=(tr))\n while$()\nelse\n {\n self.onComplete(awaitable$async);\n return ()\n }
|<synthetic> val await$2: Object = {\n val tryGetResult$async: Object = self.tryGet(tr);\n if (self.eq(tryGetResult$async))\n return ()\n else\n tryGetResult$async.$asInstanceOf[Object]()\n}
|val y: Int = scala.Int.unbox(await$2)""".stripMargin
assertEquals(
expected, actual)
}
}


// Handy to debug the compiler or to collect code coverage statistics in IntelliJ.
@Test
@Ignore
Expand All @@ -389,76 +482,110 @@ class AnnotationDrivenAsyncTest {
f
}

abstract class CompileResult {
val global: Global
val tree: global.Tree
val parseTree: global.Tree
def run(): Any
def close(): Unit
def fsmTree: Option[global.Tree] = tree.find { case dd: global.DefDef => dd.symbol.name.containsName("fsm"); case _ => false }
}

def run(code: String, compileOnly: Boolean = false): Any = {
val compileResult = compile(code, compileOnly)
try
if (!compileOnly) compileResult.run()
finally {
compileResult.close()
}
}

def compile(code: String, compileOnly: Boolean = false): CompileResult = {
val out = createTempDir()
try {
val reporter = new StoreReporter(new Settings) {
override def doReport(pos: Position, msg: String, severity: Severity): Unit =
if (severity == INFO) println(msg)
else super.doReport(pos, msg, severity)
}
val settings = new Settings(println(_))
settings.async.value = true
settings.outdir.value = out.getAbsolutePath
settings.embeddedDefaults(getClass.getClassLoader)

// settings.debug.value = true
// settings.uniqid.value = true
// settings.processArgumentString("-Xprint:typer,posterasure,async -nowarn")
// settings.log.value = List("async")
val reporter = new StoreReporter(new Settings) {
override def doReport(pos: Position, msg: String, severity: Severity): Unit =
if (severity == INFO) println(msg)
else super.doReport(pos, msg, severity)
}
val settings = new Settings(println(_))
settings.async.value = true
settings.outdir.value = out.getAbsolutePath
settings.embeddedDefaults(getClass.getClassLoader)

// NOTE: edit ANFTransform.traceAsync to `= true` to get additional diagnostic tracing.
// settings.debug.value = true
// settings.uniqid.value = true
// settings.processArgumentString("-Xprint:typer,posterasure,async -nowarn")
// settings.log.value = List("async")

val isInSBT = !settings.classpath.isSetByUser
if (isInSBT) settings.usejavacp.value = true
val global = new Global(settings, reporter) {
self =>
// NOTE: edit ANFTransform.traceAsync to `= true` to get additional diagnostic tracing.

@nowarn("cat=deprecation&msg=early initializers")
object late extends {
val global: self.type = self
} with AnnotationDrivenAsyncPlugin
val isInSBT = !settings.classpath.isSetByUser
if (isInSBT) settings.usejavacp.value = true
val g = new Global(settings, reporter) {
self =>

override protected def loadPlugins(): List[Plugin] = late :: Nil
}
import global._

val run = new Run
val source = newSourceFile(code)
run.compileSources(source :: Nil)
if (compileOnly) return null
def showInfo(info: StoreReporter#Info): String = {
Position.formatMessage(info.pos, info.severity.toString.toLowerCase + " : " + info.msg, false)
}
Assert.assertTrue(reporter.infos.map(showInfo).mkString("\n"), !reporter.hasErrors)
Assert.assertTrue(reporter.infos.map(showInfo).mkString("\n"), !reporter.hasWarnings)
val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader)
val cls = loader.loadClass("Test")
val result = try {
cls.getMethod("test").invoke(null)
} catch {
case ite: InvocationTargetException => throw ite.getCause
case _: NoSuchMethodException =>
cls.getMethod("main", classOf[Array[String]]).invoke(null, null)
@nowarn("cat=deprecation&msg=early initializers")
object late extends {
val global: self.type = self
} with AnnotationDrivenAsyncPlugin

override protected def loadPlugins(): List[Plugin] = late :: Nil
}

import g._

val run = new Run
val source = newSourceFile(code)
run.compileSources(source :: Nil)

def showInfo(info: StoreReporter#Info): String = {
Position.formatMessage(info.pos, info.severity.toString.toLowerCase + " : " + info.msg, false)
}

Assert.assertTrue(reporter.infos.map(showInfo).mkString("\n"), !reporter.hasErrors)
Assert.assertTrue(reporter.infos.map(showInfo).mkString("\n"), !reporter.hasWarnings)

val unit: CompilationUnit = run.units.next()
val parseTree0 = newUnitParser(unit).parse()
new CompileResult {
val global: g.type = g

val tree = unit.body
override val parseTree: global.Tree = parseTree0

def run(): Any = {
try {
val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader)
val cls = loader.loadClass("Test")
val result = try {
cls.getMethod("test").invoke(null)
} catch {
case ite: InvocationTargetException => throw ite.getCause
case _: NoSuchMethodException =>
cls.getMethod("main", classOf[Array[String]]).invoke(null, null)
}
result match {
case t: scala.concurrent.Future[_] =>
scala.concurrent.Await.result(t, Duration.Inf)
case cf: CustomFuture[_] =>
cf._block
case cf: CompletableFuture[_] =>
cf.get()
case value => value
}
} catch {
case ve: VerifyError =>
val asm = out.listFiles().flatMap { file =>
val asmp = AsmUtils.textify(AsmUtils.readClass(file.getAbsolutePath))
asmp :: Nil
}.mkString("\n\n")
throw new AssertionError(asm, ve)
}
}
result match {
case t: scala.concurrent.Future[_] =>
scala.concurrent.Await.result(t, Duration.Inf)
case cf: CustomFuture[_] =>
cf._block
case cf: CompletableFuture[_] =>
cf.get()
case value => value
override def close(): Unit = {
scala.reflect.io.Path.apply(out).deleteRecursively()
}
} catch {
case ve: VerifyError =>
val asm = out.listFiles().flatMap { file =>
val asmp = AsmUtils.textify(AsmUtils.readClass(file.getAbsolutePath))
asmp :: Nil
}.mkString("\n\n")
throw new AssertionError(asm, ve)
} finally {
scala.reflect.io.Path.apply(out).deleteRecursively()
}
}
}
Expand Down