Skip to content

Commit 41ae707

Browse files
committed
Merge pull request #98 from retronym/ticket/73
Make `f(await(completedFuture))` execute `f` synchronously
2 parents 61b4c18 + 063492a commit 41ae707

File tree

9 files changed

+166
-88
lines changed

9 files changed

+166
-88
lines changed

src/main/scala/scala/async/internal/AsyncTransform.scala

+7-22
Original file line numberDiff line numberDiff line change
@@ -24,31 +24,24 @@ trait AsyncTransform {
2424

2525
val anfTree = futureSystemOps.postAnfTransform(anfTree0)
2626

27-
val resumeFunTreeDummyBody = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Literal(Constant(())))
28-
2927
val applyDefDefDummyBody: DefDef = {
3028
val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(futureSystemOps.tryType[Any]), EmptyTree)))
31-
DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), Literal(Constant(())))
29+
DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), literalUnit)
3230
}
3331

3432
// Create `ClassDef` of state machine with empty method bodies for `resume` and `apply`.
3533
val stateMachine: ClassDef = {
3634
val body: List[Tree] = {
37-
val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0)))
35+
val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(StateAssigner.Initial)))
3836
val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T](uncheckedBoundsResultTag)), futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree)
3937
val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext)
4038

4139
val apply0DefDef: DefDef = {
4240
// We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
43-
// See SI-1247 for the the optimization that avoids creatio
44-
DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil))
45-
}
46-
val extraValDef: ValDef = {
47-
// We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
48-
// See SI-1247 for the the optimization that avoids creatio
49-
ValDef(NoMods, newTermName("extra"), TypeTree(definitions.UnitTpe), Literal(Constant(())))
41+
// See SI-1247 for the the optimization that avoids creation.
42+
DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.apply), literalNull :: Nil))
5043
}
51-
List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef, extraValDef)
44+
List(emptyConstructor, stateVar, result, execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef)
5245
}
5346

5447
val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit])
@@ -90,8 +83,7 @@ trait AsyncTransform {
9083
val stateMachineSpliced: Tree = spliceMethodBodies(
9184
liftedFields,
9285
stateMachine,
93-
atMacroPos(asyncBlock.onCompleteHandler[T]),
94-
atMacroPos(asyncBlock.resumeFunTree[T].rhs)
86+
atMacroPos(asyncBlock.onCompleteHandler[T])
9587
)
9688

9789
def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection)
@@ -131,10 +123,9 @@ trait AsyncTransform {
131123
* @param liftables trees of definitions that are lifted to fields of the state machine class
132124
* @param tree `ClassDef` tree of the state machine class
133125
* @param applyBody tree of onComplete handler (`apply` method)
134-
* @param resumeBody RHS of definition tree of `resume` method
135126
* @return transformed `ClassDef` tree of the state machine class
136127
*/
137-
def spliceMethodBodies(liftables: List[Tree], tree: ClassDef, applyBody: Tree, resumeBody: Tree): Tree = {
128+
def spliceMethodBodies(liftables: List[Tree], tree: ClassDef, applyBody: Tree): Tree = {
138129
val liftedSyms = liftables.map(_.symbol).toSet
139130
val stateMachineClass = tree.symbol
140131
liftedSyms.foreach {
@@ -211,12 +202,6 @@ trait AsyncTransform {
211202
(ctx: analyzer.Context) =>
212203
val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx)
213204
typedTree
214-
215-
case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass =>
216-
(ctx: analyzer.Context) =>
217-
val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol)
218-
val res = fixup(dd, changed, ctx)
219-
res
220205
}
221206
result
222207
}

src/main/scala/scala/async/internal/ExprBuilder.scala

+71-54
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ trait ExprBuilder {
2727

2828
def nextStates: List[Int]
2929

30-
def mkHandlerCaseForState: CaseDef
30+
def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef
3131

3232
def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None
3333

@@ -51,8 +51,8 @@ trait ExprBuilder {
5151
def nextStates: List[Int] =
5252
List(nextState)
5353

54-
def mkHandlerCaseForState: CaseDef =
55-
mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup))
54+
def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef =
55+
mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup))
5656

5757
override val toString: String =
5858
s"AsyncState #$state, next = $nextState"
@@ -62,7 +62,7 @@ trait ExprBuilder {
6262
* a branch of an `if` or a `match`.
6363
*/
6464
final class AsyncStateWithoutAwait(var stats: List[Tree], val state: Int, val nextStates: List[Int]) extends AsyncState {
65-
override def mkHandlerCaseForState: CaseDef =
65+
override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef =
6666
mkHandlerCase(state, stats)
6767

6868
override val toString: String =
@@ -72,45 +72,54 @@ trait ExprBuilder {
7272
/** A sequence of statements that concludes with an `await` call. The `onComplete`
7373
* handler will unconditionally transition to `nextState`.
7474
*/
75-
final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, nextState: Int,
75+
final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, onCompleteState: Int, nextState: Int,
7676
val awaitable: Awaitable, symLookup: SymLookup)
7777
extends AsyncState {
7878

7979
def nextStates: List[Int] =
8080
List(nextState)
8181

82-
override def mkHandlerCaseForState: CaseDef = {
83-
val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr),
84-
Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree
85-
mkHandlerCase(state, stats :+ callOnComplete)
82+
override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = {
83+
val fun = This(tpnme.EMPTY)
84+
val callOnComplete = futureSystemOps.onComplete[Any, Unit](Expr[futureSystem.Fut[Any]](awaitable.expr),
85+
Expr[futureSystem.Tryy[Any] => Unit](fun), Expr[futureSystem.ExecContext](Ident(name.execContext))).tree
86+
val tryGetOrCallOnComplete =
87+
if (futureSystemOps.continueCompletedFutureOnSameThread)
88+
If(futureSystemOps.isCompleted(Expr[futureSystem.Fut[_]](awaitable.expr)).tree,
89+
Block(ifIsFailureTree[T](futureSystemOps.getCompleted[Any](Expr[futureSystem.Fut[Any]](awaitable.expr)).tree) :: Nil, literalUnit),
90+
Block(callOnComplete :: Nil, Return(literalUnit)))
91+
else
92+
Block(callOnComplete :: Nil, Return(literalUnit))
93+
mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup), tryGetOrCallOnComplete))
8694
}
8795

96+
private def tryGetTree(tryReference: => Tree) =
97+
Assign(
98+
Ident(awaitable.resultName),
99+
TypeApply(Select(futureSystemOps.tryyGet[Any](Expr[futureSystem.Tryy[Any]](tryReference)).tree, newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType)))
100+
)
101+
102+
/* if (tr.isFailure)
103+
* result.complete(tr.asInstanceOf[Try[T]])
104+
* else {
105+
* <resultName> = tr.get.asInstanceOf[<resultType>]
106+
* <nextState>
107+
* <mkResumeApply>
108+
* }
109+
*/
110+
def ifIsFailureTree[T: WeakTypeTag](tryReference: => Tree) =
111+
If(futureSystemOps.tryyIsFailure(Expr[futureSystem.Tryy[T]](tryReference)).tree,
112+
Block(futureSystemOps.completeProm[T](
113+
Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)),
114+
Expr[futureSystem.Tryy[T]](
115+
TypeApply(Select(tryReference, newTermName("asInstanceOf")),
116+
List(TypeTree(futureSystemOps.tryType[T]))))).tree :: Nil,
117+
Return(literalUnit)),
118+
Block(List(tryGetTree(tryReference)), mkStateTree(nextState, symLookup))
119+
)
120+
88121
override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = {
89-
val tryGetTree =
90-
Assign(
91-
Ident(awaitable.resultName),
92-
TypeApply(Select(futureSystemOps.tryyGet[T](Expr[futureSystem.Tryy[T]](Ident(symLookup.applyTrParam))).tree, newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType)))
93-
)
94-
95-
/* if (tr.isFailure)
96-
* result.complete(tr.asInstanceOf[Try[T]])
97-
* else {
98-
* <resultName> = tr.get.asInstanceOf[<resultType>]
99-
* <nextState>
100-
* <mkResumeApply>
101-
* }
102-
*/
103-
val ifIsFailureTree =
104-
If(futureSystemOps.tryyIsFailure(Expr[futureSystem.Tryy[T]](Ident(symLookup.applyTrParam))).tree,
105-
futureSystemOps.completeProm[T](
106-
Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)),
107-
Expr[futureSystem.Tryy[T]](
108-
TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")),
109-
List(TypeTree(futureSystemOps.tryType[T]))))).tree,
110-
Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup))
111-
)
112-
113-
Some(mkHandlerCase(state, List(ifIsFailureTree)))
122+
Some(mkHandlerCase(onCompleteState, List(ifIsFailureTree[T](Ident(symLookup.applyTrParam)))))
114123
}
115124

116125
override val toString: String =
@@ -146,9 +155,10 @@ trait ExprBuilder {
146155
}
147156

148157
def resultWithAwait(awaitable: Awaitable,
158+
onCompleteState: Int,
149159
nextState: Int): AsyncState = {
150160
val effectiveNextState = nextJumpState.getOrElse(nextState)
151-
new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup)
161+
new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState, awaitable, symLookup)
152162
}
153163

154164
def resultSimple(nextState: Int): AsyncState = {
@@ -157,7 +167,7 @@ trait ExprBuilder {
157167
}
158168

159169
def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = {
160-
def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup))
170+
def mkBranch(state: Int) = mkStateTree(state, symLookup)
161171
this += If(condTree, mkBranch(thenState), mkBranch(elseState))
162172
new AsyncStateWithoutAwait(stats.toList, state, List(thenState, elseState))
163173
}
@@ -177,15 +187,15 @@ trait ExprBuilder {
177187
val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
178188
case CaseDef(pat, guard, rhs) =>
179189
val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal)
180-
CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup)))
190+
CaseDef(pat, guard, Block(bindAssigns, mkStateTree(caseStates(num), symLookup)))
181191
}
182192
// 2. insert changed match tree at the end of the current state
183193
this += Match(scrutTree, newCases)
184194
new AsyncStateWithoutAwait(stats.toList, state, caseStates)
185195
}
186196

187197
def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = {
188-
this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup))
198+
this += mkStateTree(startLabelState, symLookup)
189199
new AsyncStateWithoutAwait(stats.toList, state, List(startLabelState))
190200
}
191201

@@ -226,9 +236,10 @@ trait ExprBuilder {
226236
for (stat <- stats) stat match {
227237
// the val name = await(..) pattern
228238
case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
239+
val onCompleteState = nextState()
229240
val afterAwaitState = nextState()
230241
val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd)
231-
asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await
242+
asyncStates += stateBuilder.resultWithAwait(awaitable, onCompleteState, afterAwaitState) // complete with await
232243
currState = afterAwaitState
233244
stateBuilder = new AsyncStateBuilder(currState, symLookup)
234245

@@ -296,8 +307,6 @@ trait ExprBuilder {
296307
def asyncStates: List[AsyncState]
297308

298309
def onCompleteHandler[T: WeakTypeTag]: Tree
299-
300-
def resumeFunTree[T: WeakTypeTag]: DefDef
301310
}
302311

303312
case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) {
@@ -330,13 +339,13 @@ trait ExprBuilder {
330339
val lastStateBody = Expr[T](lastState.body)
331340
val rhs = futureSystemOps.completeProm(
332341
Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryySuccess[T](lastStateBody))
333-
mkHandlerCase(lastState.state, rhs.tree)
342+
mkHandlerCase(lastState.state, Block(rhs.tree, Return(literalUnit)))
334343
}
335344
asyncStates.toList match {
336345
case s :: Nil =>
337346
List(caseForLastState)
338347
case _ =>
339-
val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState
348+
val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState[T]
340349
initCases :+ caseForLastState
341350
}
342351
}
@@ -362,18 +371,23 @@ trait ExprBuilder {
362371
* }
363372
* }
364373
*/
365-
def resumeFunTree[T: WeakTypeTag]: DefDef =
366-
DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass),
374+
private def resumeFunTree[T: WeakTypeTag]: Tree =
367375
Try(
368-
Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]),
376+
Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ),
369377
List(
370378
CaseDef(
371379
Bind(name.t, Ident(nme.WILDCARD)),
372380
Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), {
373381
val t = Expr[Throwable](Ident(name.t))
374-
futureSystemOps.completeProm[T](
382+
val complete = futureSystemOps.completeProm[T](
375383
Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryyFailure[T](t)).tree
376-
})), EmptyTree))
384+
Block(complete :: Nil, Return(literalUnit))
385+
})), EmptyTree)
386+
387+
def forever(t: Tree): Tree = {
388+
val labelName = name.fresh("while$")
389+
LabelDef(labelName, Nil, Block(t :: Nil, Apply(Ident(labelName), Nil)))
390+
}
377391

378392
/**
379393
* Builds a `match` expression used as an onComplete handler.
@@ -387,8 +401,12 @@ trait ExprBuilder {
387401
* resume()
388402
* }
389403
*/
390-
def onCompleteHandler[T: WeakTypeTag]: Tree =
391-
Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList)
404+
def onCompleteHandler[T: WeakTypeTag]: Tree = {
405+
val onCompletes = initStates.flatMap(_.mkOnCompleteHandler[T]).toList
406+
forever {
407+
Block(resumeFunTree :: Nil, literalUnit)
408+
}
409+
}
392410
}
393411
}
394412

@@ -399,9 +417,6 @@ trait ExprBuilder {
399417

400418
case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef)
401419

402-
private def mkResumeApply(symLookup: SymLookup) =
403-
Apply(symLookup.memberRef(name.resume), Nil)
404-
405420
private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree =
406421
Assign(symLookup.memberRef(name.state), Literal(Constant(nextState)))
407422

@@ -411,5 +426,7 @@ trait ExprBuilder {
411426
private def mkHandlerCase(num: Int, rhs: Tree): CaseDef =
412427
CaseDef(Literal(Constant(num)), EmptyTree, rhs)
413428

414-
private def literalUnit = Literal(Constant(()))
429+
def literalUnit = Literal(Constant(()))
430+
431+
def literalNull = Literal(Constant(null))
415432
}

src/main/scala/scala/async/internal/FutureSystem.scala

+15
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ trait FutureSystem {
4949
def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[Tryy[A] => U],
5050
execContext: Expr[ExecContext]): Expr[Unit]
5151

52+
def continueCompletedFutureOnSameThread = false
53+
def isCompleted(future: Expr[Fut[_]]): Expr[Boolean] =
54+
throw new UnsupportedOperationException("isCompleted not supported by this FutureSystem")
55+
def getCompleted[A: WeakTypeTag](future: Expr[Fut[A]]): Expr[Tryy[A]] =
56+
throw new UnsupportedOperationException("getCompleted not supported by this FutureSystem")
57+
5258
/** Complete a promise with a value */
5359
def completeProm[A](prom: Expr[Prom[A]], value: Expr[Tryy[A]]): Expr[Unit]
5460

@@ -103,6 +109,15 @@ object ScalaConcurrentFutureSystem extends FutureSystem {
103109
future.splice.onComplete(fun.splice)(execContext.splice)
104110
}
105111

112+
override def continueCompletedFutureOnSameThread: Boolean = true
113+
114+
override def isCompleted(future: Expr[Fut[_]]): Expr[Boolean] = reify {
115+
future.splice.isCompleted
116+
}
117+
override def getCompleted[A: WeakTypeTag](future: Expr[Fut[A]]): Expr[Tryy[A]] = reify {
118+
future.splice.value.get
119+
}
120+
106121
def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify {
107122
prom.splice.complete(value.splice)
108123
Expr[Unit](Literal(Constant(()))).splice

src/main/scala/scala/async/internal/StateAssigner.scala

+7-5
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
package scala.async.internal
66

77
private[async] final class StateAssigner {
8-
private var current = -1
8+
private var current = StateAssigner.Initial
99

10-
def nextState(): Int = {
11-
current += 1
12-
current
13-
}
10+
def nextState(): Int =
11+
try current finally current += 1
1412
}
13+
14+
object StateAssigner {
15+
final val Initial = 0
16+
}

src/main/scala/scala/async/internal/TransformUtils.scala

+1-4
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ private[async] trait TransformUtils {
4949

5050
private def isByName(fun: Tree): ((Int, Int) => Boolean) = {
5151
if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true
52+
else if (fun.tpe == null) (x, y) => false
5253
else {
5354
val paramss = fun.tpe.paramss
5455
val byNamess = paramss.map(_.map(_.isByNameParam))
@@ -72,10 +73,6 @@ private[async] trait TransformUtils {
7273
self.splice.contains(elem.splice)
7374
}
7475

75-
def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify {
76-
self.splice.apply(arg.splice)
77-
}
78-
7976
def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify {
8077
self.splice == other.splice
8178
}

0 commit comments

Comments
 (0)