From 04563beda2d86f1c5f40cb8308a124476a83fb5b Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Thu, 21 Jul 2022 11:24:19 +1000 Subject: [PATCH] Improve positions of async-transformed code Don't position the wrapping try / while trees at the position of the result expression anymore. --- .../nsc/transform/async/ExprBuilder.scala | 22 +- .../nsc/transform/async/TransformUtils.scala | 3 +- test/async/run/positions.scala | 16 ++ .../nsc/async/AnnotationDrivenAsyncTest.scala | 249 +++++++++++++----- 4 files changed, 217 insertions(+), 73 deletions(-) create mode 100644 test/async/run/positions.scala diff --git a/src/compiler/scala/tools/nsc/transform/async/ExprBuilder.scala b/src/compiler/scala/tools/nsc/transform/async/ExprBuilder.scala index fd2affe54268..0d2b1927acc1 100644 --- a/src/compiler/scala/tools/nsc/transform/async/ExprBuilder.scala +++ b/src/compiler/scala/tools/nsc/transform/async/ExprBuilder.scala @@ -32,7 +32,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis { case ap @ Apply(i @ Ident(_), Nil) if isCaseLabel(i.symbol) || isMatchEndLabel(i.symbol) => currentTransformState.labelDefStates.get(i.symbol) match { case Some(state) => - Block(StateTransitionStyle.UpdateAndContinue.trees(state, new StateSet), typed(literalUnit)).setType(definitions.UnitTpe) + Block(StateTransitionStyle.UpdateAndContinue.trees(state, new StateSet), typedCurrentPos(literalUnit)).setType(definitions.UnitTpe) case None => ap } case tree => tree @@ -60,7 +60,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis { val stats1 = mutable.ListBuffer[Tree]() def addNullAssigments(syms: Iterator[Symbol]): Unit = { for (fieldSym <- syms) { - stats1 += typed(Assign(currentTransformState.memberRef(fieldSym), gen.mkZero(fieldSym.info))) + stats1 += typedCurrentPos(Assign(currentTransformState.memberRef(fieldSym), gen.mkZero(fieldSym.info))) } } // Add pre-state null assigments at the beginning. @@ -148,7 +148,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis { } allNextStates += nextState - stats += typed(Return(literalUnit).setSymbol(currentTransformState.applySym)) + stats += typedCurrentPos(Return(literalUnit).setSymbol(currentTransformState.applySym)) } if (state == StateAssigner.Terminal) { // noop @@ -462,7 +462,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis { val asyncStatesInit = asyncStates.init // drop the terminal state which has no code. val throww = Throw(Apply(Select(New(Ident(IllegalStateExceptionClass)), IllegalStateExceptionClass_NEW_String), List(gen.mkMethodCall(currentRun.runDefinitions.String_valueOf_Int, stateMemberRef :: Nil)))) val body = - typed(Match(stateMemberRef, + typedBasePos(Match(stateMemberRef, asyncStatesInit.map(_.mkHandlerCaseForState) ++ List(CaseDef(Ident(nme.WILDCARD), EmptyTree, throww)))) @@ -480,7 +480,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis { ) ), EmptyTree) } - typed(LabelDef(transformState.whileLabel, Nil, Block(stateMatch :: Nil, Apply(Ident(transformState.whileLabel), Nil)))) + typedBasePos (LabelDef(transformState.whileLabel, Nil, Block(stateMatch :: Nil, Apply(Ident(transformState.whileLabel), Nil)))) } private def compactStates = true @@ -557,7 +557,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis { } else { val temp = awaitableResult.symbol.newTermSymbol(nme.trGetResult).setInfo(definitions.ObjectTpe) val tempVd = ValDef(temp, gen.mkMethodCall(currentTransformState.memberRef(currentTransformState.stateTryGet), tryyReference :: Nil)) - typed(Block( + typedCurrentPos(Block( tempVd :: Nil, If(Apply(gen.mkAttributedSelect(currentTransformState.stateMachineRef(), definitions.Object_eq), gen.mkAttributedIdent(temp) :: Nil), Return(literalUnit), @@ -571,7 +571,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis { // Comlete the Promise in the `result` field with the final successful result of this async block. private def completeSuccess(expr: Tree): Tree = { deriveTree(expr, definitions.UnitTpe) { expr => - typed(Apply(currentTransformState.memberRef(currentTransformState.stateCompleteSuccess), expr :: Nil)) + typedCurrentPos(Apply(currentTransformState.memberRef(currentTransformState.stateCompleteSuccess), expr :: Nil)) } } @@ -581,7 +581,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis { protected def mkStateTree(nextState: Int): Tree = { val transformState = currentTransformState val callSetter = Apply(transformState.memberRef(transformState.stateSetter), Literal(Constant(nextState)) :: Nil) - typed(callSetter.updateAttachment(StateTransitionTree)) + typedCurrentPos(callSetter.updateAttachment(StateTransitionTree)) } } @@ -625,9 +625,9 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis { If(Apply(null_ne, Ident(transformState.applyTrParam) :: Nil), Apply(Ident(transformState.whileLabel), Nil), Block(toStats(callOnComplete(gen.mkAttributedIdent(tempAwaitableSym))), Return(literalUnit).setSymbol(transformState.applySym))) - typed(initAwaitableTemp) :: typed(initTempCompleted) :: mkStateTree(nextState) :: typed(ifTree) :: Nil + typedCurrentPos(initAwaitableTemp) :: typedCurrentPos(initTempCompleted) :: mkStateTree(nextState) :: typedCurrentPos(ifTree) :: Nil } else { - mkStateTree(nextState) :: toStats(typed(callOnComplete(awaitable))) ::: typed(Return(literalUnit)) :: Nil + mkStateTree(nextState) :: toStats(typedCurrentPos(callOnComplete(awaitable))) ::: typedCurrentPos(Return(literalUnit)) :: Nil } } } @@ -636,7 +636,7 @@ trait ExprBuilder extends TransformUtils with AsyncAnalysis { case object UpdateAndContinue extends StateTransitionStyle { def trees(nextState: Int, stateSet: StateSet): List[Tree] = { stateSet += nextState - List(mkStateTree(nextState), typed(Apply(Ident(currentTransformState.whileLabel), Nil))) + List(mkStateTree(nextState), typedCurrentPos(Apply(Ident(currentTransformState.whileLabel), Nil))) } } } diff --git a/src/compiler/scala/tools/nsc/transform/async/TransformUtils.scala b/src/compiler/scala/tools/nsc/transform/async/TransformUtils.scala index 70cc6e317171..be14f1d0e017 100644 --- a/src/compiler/scala/tools/nsc/transform/async/TransformUtils.scala +++ b/src/compiler/scala/tools/nsc/transform/async/TransformUtils.scala @@ -24,7 +24,8 @@ private[async] trait TransformUtils extends AsyncTransformStates { private[async] val asyncNames: AsyncNames[global.type] def typedPos(pos: Position)(tree: Tree): Tree = currentTransformState.localTyper.typedPos(pos)(tree: Tree) - def typed(tree: Tree): Tree = typedPos(currentTransformState.currentPos)(tree) + def typedCurrentPos(tree: Tree): Tree = typedPos(currentTransformState.currentPos)(tree) + def typedBasePos(tree: Tree): Tree = typedPos(currentTransformState.applySym.pos)(tree) lazy val IllegalStateExceptionClass: Symbol = rootMirror.staticClass("java.lang.IllegalStateException") lazy val IllegalStateExceptionClass_NEW_String: Symbol = IllegalStateExceptionClass.info.decl(nme.CONSTRUCTOR).suchThat( diff --git a/test/async/run/positions.scala b/test/async/run/positions.scala new file mode 100644 index 000000000000..900697307d0d --- /dev/null +++ b/test/async/run/positions.scala @@ -0,0 +1,16 @@ +// scalac: -Xasync -Xprint:parser,typer,async -Xprint-pos -Yrangepos + +import scala.tools.partest.async.OptionAwait._ +import org.junit.Assert._ + +object Test { + def main(args: Array[String]): Unit = { + testBasic() + } + + private def testBasic() = optionally { + val x = value(Some(1)) + val y = value(Some(2)) + x + y + } +} \ No newline at end of file diff --git a/test/junit/scala/tools/nsc/async/AnnotationDrivenAsyncTest.scala b/test/junit/scala/tools/nsc/async/AnnotationDrivenAsyncTest.scala index b179dd2d2da5..cda0f0f5e35b 100644 --- a/test/junit/scala/tools/nsc/async/AnnotationDrivenAsyncTest.scala +++ b/test/junit/scala/tools/nsc/async/AnnotationDrivenAsyncTest.scala @@ -9,6 +9,7 @@ import org.junit.Assert.assertEquals import org.junit.{Assert, Ignore, Test} import scala.annotation.{StaticAnnotation, nowarn, unused} +import scala.collection.mutable import scala.concurrent.duration.Duration import scala.reflect.internal.util.Position import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader @@ -367,6 +368,98 @@ class AnnotationDrivenAsyncTest { assertEquals(classOf[Array[String]], result.getClass) } + @Test + def testPositions(): Unit = { + val code = + """ + |import scala.tools.nsc.async.{autoawait, customAsync} + |object Test { + | @autoawait def id(a: Int) = a + | @customAsync def test = { + | val x = id(1) + | val y = id(2) + | x + y + | } + |}""".stripMargin + + val result = compile(code) + + import result.global._ + // settings.Xprintpos.value = true // enable to help debugging + val methpdParseTree = result.parseTree.find { case dt: DefTree => dt.name.string_==("test") case _ => false } + methpdParseTree.get match { + case DefDef(_, _, _, _, _, Block(stats, expr)) => + val parseTreeStats: List[Tree] = (expr :: stats) + val fsmTree = result.fsmTree + val posMap = mutable.LinkedHashMap[Tree, mutable.Buffer[Tree]]() + val parentMap = mutable.LinkedHashMap[Tree, Tree]() + def collectParents(t: Tree): Unit = { + for (child <- t.children) { + parentMap(child) = t + collectParents(child) + } + } + def isAncestor(child: Tree, parent: Tree): Boolean = { + parentMap.get(child) match { + case None => false + case Some(p) => parent == p || isAncestor(p, parent) + } + } + collectParents(fsmTree.get) + + for { + parseTreeStat <- parseTreeStats + pos = parseTreeStat.pos + tree <- fsmTree.get + } { + if (pos.includes(tree.pos)) { + posMap.get(parseTreeStat) match { + case Some(existing) => + if (!existing.exists(t => isAncestor(tree, t))) { + val (retained, discarded) = existing.toList.partition(t => isAncestor(t, tree) || !isAncestor(tree, t)) + existing.clear() + existing ++= retained + existing += tree + } + case None => + posMap(parseTreeStat) = mutable.ListBuffer(tree) + } + } + } + + val incorrectlyContainedTryOrWhileLoop = posMap.values.flatMap(_.collect { case t: Try => t; case ld: LabelDef if ld.name.containsName(nme.WHILE_PREFIX) => ld}) + assert(incorrectlyContainedTryOrWhileLoop.isEmpty, incorrectlyContainedTryOrWhileLoop) + + def oneliner(s: String) = s.replace(System.lineSeparator(), "\\n") + val actual = posMap.toList.map { case (orig, corresponding) => s"${oneliner(orig.toString)}\n${"-" * 80}\n${corresponding.map(t => oneliner(t.toString)).mkString("\n")}"}.mkString("\n" * 3) + val expected = + """x.$plus(y) + |-------------------------------------------------------------------------------- + |self.completeSuccess(scala.Int.box(self.x.+(y))) + |return () + | + | + |val x = id(1) + |-------------------------------------------------------------------------------- + |case 0 => {\n val awaitable$async: scala.tools.nsc.async.CustomFuture = scala.tools.nsc.async.CustomFuture._successful(scala.Int.box(Test.this.id(1)));\n tr = self.getCompleted(awaitable$async);\n self.state_=(1);\n if (null.!=(tr))\n while$()\n else\n {\n self.onComplete(awaitable$async);\n return ()\n }\n} + | val await$1: Object = {\n val tryGetResult$async: Object = self.tryGet(tr);\n if (self.eq(tryGetResult$async))\n return ()\n else\n tryGetResult$async.$asInstanceOf[Object]()\n} + |self.x = scala.Int.unbox(await$1) + | + | + |val y = id(2) + |-------------------------------------------------------------------------------- + |val awaitable$async: scala.tools.nsc.async.CustomFuture = scala.tools.nsc.async.CustomFuture._successful(scala.Int.box(Test.this.id(2))) + |tr = self.getCompleted(awaitable$async) + |self.state_=(2) + |if (null.!=(tr))\n while$()\nelse\n {\n self.onComplete(awaitable$async);\n return ()\n } + | val await$2: Object = {\n val tryGetResult$async: Object = self.tryGet(tr);\n if (self.eq(tryGetResult$async))\n return ()\n else\n tryGetResult$async.$asInstanceOf[Object]()\n} + |val y: Int = scala.Int.unbox(await$2)""".stripMargin + assertEquals( + expected, actual) + } + } + + // Handy to debug the compiler or to collect code coverage statistics in IntelliJ. @Test @Ignore @@ -389,76 +482,110 @@ class AnnotationDrivenAsyncTest { f } + abstract class CompileResult { + val global: Global + val tree: global.Tree + val parseTree: global.Tree + def run(): Any + def close(): Unit + def fsmTree: Option[global.Tree] = tree.find { case dd: global.DefDef => dd.symbol.name.containsName("fsm"); case _ => false } + } + def run(code: String, compileOnly: Boolean = false): Any = { + val compileResult = compile(code, compileOnly) + try + if (!compileOnly) compileResult.run() + finally { + compileResult.close() + } + } + + def compile(code: String, compileOnly: Boolean = false): CompileResult = { val out = createTempDir() - try { - val reporter = new StoreReporter(new Settings) { - override def doReport(pos: Position, msg: String, severity: Severity): Unit = - if (severity == INFO) println(msg) - else super.doReport(pos, msg, severity) - } - val settings = new Settings(println(_)) - settings.async.value = true - settings.outdir.value = out.getAbsolutePath - settings.embeddedDefaults(getClass.getClassLoader) - // settings.debug.value = true - // settings.uniqid.value = true - // settings.processArgumentString("-Xprint:typer,posterasure,async -nowarn") - // settings.log.value = List("async") + val reporter = new StoreReporter(new Settings) { + override def doReport(pos: Position, msg: String, severity: Severity): Unit = + if (severity == INFO) println(msg) + else super.doReport(pos, msg, severity) + } + val settings = new Settings(println(_)) + settings.async.value = true + settings.outdir.value = out.getAbsolutePath + settings.embeddedDefaults(getClass.getClassLoader) - // NOTE: edit ANFTransform.traceAsync to `= true` to get additional diagnostic tracing. + // settings.debug.value = true + // settings.uniqid.value = true + // settings.processArgumentString("-Xprint:typer,posterasure,async -nowarn") + // settings.log.value = List("async") - val isInSBT = !settings.classpath.isSetByUser - if (isInSBT) settings.usejavacp.value = true - val global = new Global(settings, reporter) { - self => + // NOTE: edit ANFTransform.traceAsync to `= true` to get additional diagnostic tracing. - @nowarn("cat=deprecation&msg=early initializers") - object late extends { - val global: self.type = self - } with AnnotationDrivenAsyncPlugin + val isInSBT = !settings.classpath.isSetByUser + if (isInSBT) settings.usejavacp.value = true + val g = new Global(settings, reporter) { + self => - override protected def loadPlugins(): List[Plugin] = late :: Nil - } - import global._ - - val run = new Run - val source = newSourceFile(code) - run.compileSources(source :: Nil) - if (compileOnly) return null - def showInfo(info: StoreReporter#Info): String = { - Position.formatMessage(info.pos, info.severity.toString.toLowerCase + " : " + info.msg, false) - } - Assert.assertTrue(reporter.infos.map(showInfo).mkString("\n"), !reporter.hasErrors) - Assert.assertTrue(reporter.infos.map(showInfo).mkString("\n"), !reporter.hasWarnings) - val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader) - val cls = loader.loadClass("Test") - val result = try { - cls.getMethod("test").invoke(null) - } catch { - case ite: InvocationTargetException => throw ite.getCause - case _: NoSuchMethodException => - cls.getMethod("main", classOf[Array[String]]).invoke(null, null) + @nowarn("cat=deprecation&msg=early initializers") + object late extends { + val global: self.type = self + } with AnnotationDrivenAsyncPlugin + + override protected def loadPlugins(): List[Plugin] = late :: Nil + } + + import g._ + + val run = new Run + val source = newSourceFile(code) + run.compileSources(source :: Nil) + + def showInfo(info: StoreReporter#Info): String = { + Position.formatMessage(info.pos, info.severity.toString.toLowerCase + " : " + info.msg, false) + } + + Assert.assertTrue(reporter.infos.map(showInfo).mkString("\n"), !reporter.hasErrors) + Assert.assertTrue(reporter.infos.map(showInfo).mkString("\n"), !reporter.hasWarnings) + + val unit: CompilationUnit = run.units.next() + val parseTree0 = newUnitParser(unit).parse() + new CompileResult { + val global: g.type = g + + val tree = unit.body + override val parseTree: global.Tree = parseTree0 + + def run(): Any = { + try { + val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader) + val cls = loader.loadClass("Test") + val result = try { + cls.getMethod("test").invoke(null) + } catch { + case ite: InvocationTargetException => throw ite.getCause + case _: NoSuchMethodException => + cls.getMethod("main", classOf[Array[String]]).invoke(null, null) + } + result match { + case t: scala.concurrent.Future[_] => + scala.concurrent.Await.result(t, Duration.Inf) + case cf: CustomFuture[_] => + cf._block + case cf: CompletableFuture[_] => + cf.get() + case value => value + } + } catch { + case ve: VerifyError => + val asm = out.listFiles().flatMap { file => + val asmp = AsmUtils.textify(AsmUtils.readClass(file.getAbsolutePath)) + asmp :: Nil + }.mkString("\n\n") + throw new AssertionError(asm, ve) + } } - result match { - case t: scala.concurrent.Future[_] => - scala.concurrent.Await.result(t, Duration.Inf) - case cf: CustomFuture[_] => - cf._block - case cf: CompletableFuture[_] => - cf.get() - case value => value + override def close(): Unit = { + scala.reflect.io.Path.apply(out).deleteRecursively() } - } catch { - case ve: VerifyError => - val asm = out.listFiles().flatMap { file => - val asmp = AsmUtils.textify(AsmUtils.readClass(file.getAbsolutePath)) - asmp :: Nil - }.mkString("\n\n") - throw new AssertionError(asm, ve) - } finally { - scala.reflect.io.Path.apply(out).deleteRecursively() } } }