Skip to content

Commit 5ed779c

Browse files
committed
Make sure parameters are not used again after they are consumed
1 parent c45bb88 commit 5ed779c

File tree

4 files changed

+239
-19
lines changed

4 files changed

+239
-19
lines changed

compiler/src/dotty/tools/dotc/cc/SepCheck.scala

Lines changed: 162 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import config.Printers.capt
1212
import StdNames.nme
1313
import util.{SimpleIdentitySet, EqHashMap, SrcPos}
1414
import tpd.*
15+
import reflect.ClassTag
1516

1617
object SepChecker:
1718

@@ -39,14 +40,92 @@ object SepChecker:
3940
case _ => NoSymbol
4041
end TypeKind
4142

43+
/** A class for segmented sets of consumed references.
44+
* References are associated with the source positions where they first appeared.
45+
* References are compared with `eq`.
46+
*/
47+
abstract class ConsumedSet:
48+
/** The references in the set. The array should be treated as immutable in client code */
49+
def refs: Array[CaptureRef]
50+
51+
/** The associated source positoons. The array should be treated as immutable in client code */
52+
def locs: Array[SrcPos]
53+
54+
/** The number of references in the set */
55+
def size: Int
56+
57+
def toMap: Map[CaptureRef, SrcPos] = refs.take(size).zip(locs).toMap
58+
59+
def show(using Context) =
60+
s"[${toMap.map((ref, loc) => i"$ref -> $loc").toList}]"
61+
end ConsumedSet
62+
63+
/** A fixed consumed set consisting of the given references `refs` and
64+
* associated source positions `locs`
65+
*/
66+
class ConstConsumedSet(val refs: Array[CaptureRef], val locs: Array[SrcPos]) extends ConsumedSet:
67+
def size = refs.size
68+
69+
/** A mutable consumed set, which is initially empty */
70+
class MutConsumedSet extends ConsumedSet:
71+
var refs: Array[CaptureRef] = new Array(4)
72+
var locs: Array[SrcPos] = new Array(4)
73+
var size = 0
74+
75+
private def double[T <: AnyRef : ClassTag](xs: Array[T]): Array[T] =
76+
val xs1 = new Array[T](xs.length * 2)
77+
xs.copyToArray(xs1)
78+
xs1
79+
80+
private def ensureCapacity(added: Int): Unit =
81+
if size + added > refs.length then
82+
refs = double(refs)
83+
locs = double(locs)
84+
85+
/** If `ref` is in the set, its associated source position, otherwise `null` */
86+
def get(ref: CaptureRef): SrcPos | Null =
87+
var i = 0
88+
while i < size && (refs(i) ne ref) do i += 1
89+
if i < size then locs(i) else null
90+
91+
/** If `ref` is not yet in the set, add it with given source position */
92+
def put(ref: CaptureRef, loc: SrcPos): Unit =
93+
if get(ref) == null then
94+
ensureCapacity(1)
95+
refs(size) = ref
96+
locs(size) = loc
97+
size += 1
98+
99+
/** Add all references with their associated positions from `that` which
100+
* are not yet in the set.
101+
*/
102+
def ++= (that: ConsumedSet): Unit =
103+
for i <- 0 until that.size do put(that.refs(i), that.locs(i))
104+
105+
/** Run `op` and return any new references it created in a separate `ConsumedSet`.
106+
* The current mutable set is reset to its state before `op` was run.
107+
*/
108+
def segment(op: => Unit): ConsumedSet =
109+
val start = size
110+
try
111+
op
112+
if size == start then EmptyConsumedSet
113+
else ConstConsumedSet(refs.slice(start, size), locs.slice(start, size))
114+
finally
115+
size = start
116+
117+
end MutConsumedSet
118+
119+
val EmptyConsumedSet = ConstConsumedSet(Array(), Array())
120+
42121
class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
43122
import checker.*
44123
import SepChecker.*
45124

46125
/** The set of capabilities that are hidden by a polymorphic result type
47126
* of some previous definition.
48127
*/
49-
private var defsShadow: Refs = SimpleIdentitySet.empty
128+
private var defsShadow: Refs = emptySet
50129

51130
/** A map from definitions to their internal result types.
52131
* Populated during separation checking traversal.
@@ -58,6 +137,16 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
58137
*/
59138
private var previousDefs: List[mutable.ListBuffer[ValOrDefDef]] = Nil
60139

140+
private var consumed: MutConsumedSet = MutConsumedSet()
141+
142+
private def withFreshConsumed(op: => Unit): Unit =
143+
val saved = consumed
144+
consumed = MutConsumedSet()
145+
op
146+
consumed = saved
147+
148+
private var openLabeled: List[(Name, mutable.ListBuffer[ConsumedSet])] = Nil
149+
61150
extension (refs: Refs)
62151
private def footprint(using Context): Refs =
63152
def recur(elems: Refs, newElems: List[CaptureRef]): Refs = newElems match
@@ -198,6 +287,19 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
198287
tree.srcPos)
199288
end sepUseError
200289

290+
def consumeError(ref: CaptureRef, loc: SrcPos, pos: SrcPos)(using Context): Unit =
291+
report.error(
292+
em"""Separation failure: Illegal access to $ref,
293+
|which was passed to a @consume parameter on line ${loc.line + 1}
294+
|and therefore is no longer available.""",
295+
pos)
296+
297+
def consumeInLoopError(ref: CaptureRef, pos: SrcPos)(using Context): Unit =
298+
report.error(
299+
em"""Separation failure: $ref appears in a loop,
300+
|therefore it cannot be passed to a @consume parameter.""",
301+
pos)
302+
201303
private def checkApply(fn: Tree, args: List[Tree], deps: collection.Map[Tree, List[Tree]])(using Context): Unit =
202304
val fnCaptures = methPart(fn) match
203305
case Select(qual, _) => qual.nuType.captureSet
@@ -240,6 +342,9 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
240342
val overlap = defUseOverlap(defsShadow, usedFootprint, tree.symbol)
241343
if !overlap.isEmpty then
242344
sepUseError(tree, usedFootprint, overlap)
345+
for ref <- used.elems do
346+
val pos = consumed.get(ref)
347+
if pos != null then consumeError(ref, pos, tree.srcPos)
243348

244349
def checkType(tpt: Tree, sym: Symbol)(using Context): Unit =
245350
checkType(tpt.nuType, tpt.srcPos,
@@ -383,10 +488,11 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
383488
checkRefs(toCheck, i"$typeDescr type $tpe hides")
384489
case TypeKind.Argument(arg) =>
385490
if tpe.hasAnnotation(defn.ConsumeAnnot) then
386-
val capts = captures(arg)
387-
def descr(verb: String) = i"argument to @consume parameter with type ${arg.nuType} $verb"
388-
checkRefs(capts.footprint, descr("refers to"))
389-
checkRefs(capts.hidden.footprint, descr("hides"))
491+
val capts = captures(arg).footprint
492+
checkRefs(capts, i"argument to @consume parameter with type ${arg.nuType} refers to")
493+
for ref <- capts do
494+
if !ref.derivesFrom(defn.Caps_SharedCapability) then
495+
consumed.put(ref, arg.srcPos)
390496

391497
if !tpe.hasAnnotation(defn.UntrackedCapturesAnnot) then
392498
traverse(Captures.None, tpe)
@@ -435,35 +541,72 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
435541
case tree: Apply => tree.symbol == defn.Caps_unsafeAssumeSeparate
436542
case _ => false
437543

544+
def checkValOrDefDef(tree: ValOrDefDef)(using Context): Unit =
545+
if !tree.symbol.isOneOf(TermParamOrAccessor) && !isUnsafeAssumeSeparate(tree.rhs) then
546+
checkType(tree.tpt, tree.symbol)
547+
if previousDefs.nonEmpty then
548+
capt.println(i"sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}")
549+
defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
550+
resultType(tree.symbol) = tree.tpt.nuType
551+
previousDefs.head += tree
552+
438553
def traverse(tree: Tree)(using Context): Unit =
439554
if isUnsafeAssumeSeparate(tree) then return
440555
checkUse(tree)
441556
tree match
442557
case tree: GenericApply =>
558+
traverseChildren(tree)
443559
tree.tpe match
444560
case _: MethodOrPoly =>
445561
case _ => traverseApply(tree, Nil)
446-
traverseChildren(tree)
447562
case tree: Block =>
448563
val saved = defsShadow
449564
previousDefs = mutable.ListBuffer() :: previousDefs
450565
try traverseChildren(tree)
451566
finally
452567
previousDefs = previousDefs.tail
453568
defsShadow = saved
454-
case tree: ValOrDefDef =>
569+
case tree: ValDef =>
455570
traverseChildren(tree)
456-
if !tree.symbol.isOneOf(TermParamOrAccessor) && !isUnsafeAssumeSeparate(tree.rhs) then
457-
checkType(tree.tpt, tree.symbol)
458-
if previousDefs.nonEmpty then
459-
capt.println(i"sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}")
460-
defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
461-
resultType(tree.symbol) = tree.tpt.nuType
462-
previousDefs.head += tree
571+
checkValOrDefDef(tree)
572+
case tree: DefDef =>
573+
withFreshConsumed:
574+
traverseChildren(tree)
575+
checkValOrDefDef(tree)
576+
case If(cond, thenp, elsep) =>
577+
traverse(cond)
578+
val thenConsumed = consumed.segment(traverse(thenp))
579+
val elseConsumed = consumed.segment(traverse(elsep))
580+
consumed ++= thenConsumed
581+
consumed ++= elseConsumed
582+
case tree @ Labeled(bind, expr) =>
583+
val consumedBuf = mutable.ListBuffer[ConsumedSet]()
584+
openLabeled = (bind.name, consumedBuf) :: openLabeled
585+
traverse(expr)
586+
for cs <- consumedBuf do consumed ++= cs
587+
openLabeled = openLabeled.tail
588+
case Return(expr, from) =>
589+
val retConsumed = consumed.segment(traverse(expr))
590+
from match
591+
case Ident(name) =>
592+
for (lbl, consumedBuf) <- openLabeled do
593+
if lbl == name then
594+
consumedBuf += retConsumed
595+
case _ =>
596+
case Match(sel, cases) =>
597+
// Matches without returns might still be kept after pattern matching to
598+
// encode table switches.
599+
traverse(sel)
600+
val caseConsumed = for cas <- cases yield consumed.segment(traverse(cas))
601+
caseConsumed.foreach(consumed ++= _)
602+
case tree: TypeDef if tree.symbol.isClass =>
603+
withFreshConsumed:
604+
traverseChildren(tree)
605+
case tree: WhileDo =>
606+
val loopConsumed = consumed.segment(traverseChildren(tree))
607+
if loopConsumed.size != 0 then
608+
val (ref, pos) = loopConsumed.toMap.head
609+
consumeInLoopError(ref, pos)
463610
case _ =>
464611
traverseChildren(tree)
465-
end SepChecker
466-
467-
468-
469-
612+
end SepChecker
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
-- Error: tests/neg-custom-args/captures/linear-buffer.scala:13:17 -----------------------------------------------------
2+
13 | val buf3 = app(buf, 3) // error
3+
| ^^^
4+
| Separation failure: Illegal access to (buf : Buffer[Int]^),
5+
| which was passed to a @consume parameter on line 11
6+
| and therefore is no longer available.
7+
-- Error: tests/neg-custom-args/captures/linear-buffer.scala:20:17 -----------------------------------------------------
8+
20 | val buf3 = app(buf1, 4) // error
9+
| ^^^^
10+
| Separation failure: Illegal access to (buf1 : Buffer[Int]^),
11+
| which was passed to a @consume parameter on line 18
12+
| and therefore is no longer available.
13+
-- Error: tests/neg-custom-args/captures/linear-buffer.scala:28:17 -----------------------------------------------------
14+
28 | val buf3 = app(buf1, 4) // error
15+
| ^^^^
16+
| Separation failure: Illegal access to (buf1 : Buffer[Int]^),
17+
| which was passed to a @consume parameter on line 25
18+
| and therefore is no longer available.
19+
-- Error: tests/neg-custom-args/captures/linear-buffer.scala:38:17 -----------------------------------------------------
20+
38 | val buf3 = app(buf1, 4) // error
21+
| ^^^^
22+
| Separation failure: Illegal access to (buf1 : Buffer[Int]^),
23+
| which was passed to a @consume parameter on line 33
24+
| and therefore is no longer available.
25+
-- Error: tests/neg-custom-args/captures/linear-buffer.scala:42:8 ------------------------------------------------------
26+
42 | app(buf, 1) // error
27+
| ^^^
28+
| Separation failure: (buf : Buffer[Int]^) appears in a loop,
29+
| therefore it cannot be passed to a @consume parameter.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import caps.{cap, consume, Mutable}
2+
import language.experimental.captureChecking
3+
4+
class Buffer[T] extends Mutable:
5+
mut def append(x: T): Buffer[T]^ = ???
6+
7+
def app[T](@consume buf: Buffer[T]^, elem: T): Buffer[T]^ =
8+
buf.append(elem)
9+
10+
def Test(@consume buf: Buffer[Int]^) =
11+
val buf1: Buffer[Int]^ = app(buf, 1)
12+
val buf2 = app(buf1, 2) // OK
13+
val buf3 = app(buf, 3) // error
14+
15+
def Test2(@consume buf: Buffer[Int]^) =
16+
val buf1: Buffer[Int]^ = app(buf, 1)
17+
val buf2 =
18+
if ??? then app(buf1, 2) // OK
19+
else app(buf1, 3) // OK
20+
val buf3 = app(buf1, 4) // error
21+
22+
def Test3(@consume buf: Buffer[Int]^) =
23+
val buf1: Buffer[Int]^ = app(buf, 1)
24+
val buf2 = (??? : Int) match
25+
case 1 => app(buf1, 2) // OK
26+
case 2 => app(buf1, 2)
27+
case _ => app(buf1, 3)
28+
val buf3 = app(buf1, 4) // error
29+
30+
def Test4(@consume buf: Buffer[Int]^) =
31+
val buf1: Buffer[Int]^ = app(buf, 1)
32+
val buf2 = (??? : Int) match
33+
case 1 => app(buf1, 2) // OK
34+
case 2 => app(buf1, 2)
35+
case 3 => app(buf1, 3)
36+
case 4 => app(buf1, 4)
37+
case 5 => app(buf1, 5)
38+
val buf3 = app(buf1, 4) // error
39+
40+
def Test5(@consume buf: Buffer[Int]^) =
41+
while true do
42+
app(buf, 1) // error

tests/neg-custom-args/captures/non-local-consume.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,9 @@ def f4(@consume buf: Buffer^): Buffer^ =
2121
def g(): Buffer^ = buf1 // error
2222
g()
2323

24+
def f5(@consume buf: Buffer^): Unit =
25+
val buf1: Buffer^ = buf
26+
def g(): Unit = cc(buf1) // error
27+
g()
28+
29+
def cc(@consume buf: Buffer^): Unit = ()

0 commit comments

Comments
 (0)