@@ -12,6 +12,7 @@ import config.Printers.capt
12
12
import StdNames .nme
13
13
import util .{SimpleIdentitySet , EqHashMap , SrcPos }
14
14
import tpd .*
15
+ import reflect .ClassTag
15
16
16
17
object SepChecker :
17
18
@@ -39,14 +40,92 @@ object SepChecker:
39
40
case _ => NoSymbol
40
41
end TypeKind
41
42
43
+ /** A class for segmented sets of consumed references.
44
+ * References are associated with the source positions where they first appeared.
45
+ * References are compared with `eq`.
46
+ */
47
+ abstract class ConsumedSet :
48
+ /** The references in the set. The array should be treated as immutable in client code */
49
+ def refs : Array [CaptureRef ]
50
+
51
+ /** The associated source positoons. The array should be treated as immutable in client code */
52
+ def locs : Array [SrcPos ]
53
+
54
+ /** The number of references in the set */
55
+ def size : Int
56
+
57
+ def toMap : Map [CaptureRef , SrcPos ] = refs.take(size).zip(locs).toMap
58
+
59
+ def show (using Context ) =
60
+ s " [ ${toMap.map((ref, loc) => i " $ref -> $loc" ).toList}] "
61
+ end ConsumedSet
62
+
63
+ /** A fixed consumed set consisting of the given references `refs` and
64
+ * associated source positions `locs`
65
+ */
66
+ class ConstConsumedSet (val refs : Array [CaptureRef ], val locs : Array [SrcPos ]) extends ConsumedSet :
67
+ def size = refs.size
68
+
69
+ /** A mutable consumed set, which is initially empty */
70
+ class MutConsumedSet extends ConsumedSet :
71
+ var refs : Array [CaptureRef ] = new Array (4 )
72
+ var locs : Array [SrcPos ] = new Array (4 )
73
+ var size = 0
74
+
75
+ private def double [T <: AnyRef : ClassTag ](xs : Array [T ]): Array [T ] =
76
+ val xs1 = new Array [T ](xs.length * 2 )
77
+ xs.copyToArray(xs1)
78
+ xs1
79
+
80
+ private def ensureCapacity (added : Int ): Unit =
81
+ if size + added > refs.length then
82
+ refs = double(refs)
83
+ locs = double(locs)
84
+
85
+ /** If `ref` is in the set, its associated source position, otherwise `null` */
86
+ def get (ref : CaptureRef ): SrcPos | Null =
87
+ var i = 0
88
+ while i < size && (refs(i) ne ref) do i += 1
89
+ if i < size then locs(i) else null
90
+
91
+ /** If `ref` is not yet in the set, add it with given source position */
92
+ def put (ref : CaptureRef , loc : SrcPos ): Unit =
93
+ if get(ref) == null then
94
+ ensureCapacity(1 )
95
+ refs(size) = ref
96
+ locs(size) = loc
97
+ size += 1
98
+
99
+ /** Add all references with their associated positions from `that` which
100
+ * are not yet in the set.
101
+ */
102
+ def ++= (that : ConsumedSet ): Unit =
103
+ for i <- 0 until that.size do put(that.refs(i), that.locs(i))
104
+
105
+ /** Run `op` and return any new references it created in a separate `ConsumedSet`.
106
+ * The current mutable set is reset to its state before `op` was run.
107
+ */
108
+ def segment (op : => Unit ): ConsumedSet =
109
+ val start = size
110
+ try
111
+ op
112
+ if size == start then EmptyConsumedSet
113
+ else ConstConsumedSet (refs.slice(start, size), locs.slice(start, size))
114
+ finally
115
+ size = start
116
+
117
+ end MutConsumedSet
118
+
119
+ val EmptyConsumedSet = ConstConsumedSet (Array (), Array ())
120
+
42
121
class SepChecker (checker : CheckCaptures .CheckerAPI ) extends tpd.TreeTraverser :
43
122
import checker .*
44
123
import SepChecker .*
45
124
46
125
/** The set of capabilities that are hidden by a polymorphic result type
47
126
* of some previous definition.
48
127
*/
49
- private var defsShadow : Refs = SimpleIdentitySet .empty
128
+ private var defsShadow : Refs = emptySet
50
129
51
130
/** A map from definitions to their internal result types.
52
131
* Populated during separation checking traversal.
@@ -58,6 +137,16 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
58
137
*/
59
138
private var previousDefs : List [mutable.ListBuffer [ValOrDefDef ]] = Nil
60
139
140
+ private var consumed : MutConsumedSet = MutConsumedSet ()
141
+
142
+ private def withFreshConsumed (op : => Unit ): Unit =
143
+ val saved = consumed
144
+ consumed = MutConsumedSet ()
145
+ op
146
+ consumed = saved
147
+
148
+ private var openLabeled : List [(Name , mutable.ListBuffer [ConsumedSet ])] = Nil
149
+
61
150
extension (refs : Refs )
62
151
private def footprint (using Context ): Refs =
63
152
def recur (elems : Refs , newElems : List [CaptureRef ]): Refs = newElems match
@@ -198,6 +287,19 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
198
287
tree.srcPos)
199
288
end sepUseError
200
289
290
+ def consumeError (ref : CaptureRef , loc : SrcPos , pos : SrcPos )(using Context ): Unit =
291
+ report.error(
292
+ em """ Separation failure: Illegal access to $ref,
293
+ |which was passed to a @consume parameter on line ${loc.line + 1 }
294
+ |and therefore is no longer available. """ ,
295
+ pos)
296
+
297
+ def consumeInLoopError (ref : CaptureRef , pos : SrcPos )(using Context ): Unit =
298
+ report.error(
299
+ em """ Separation failure: $ref appears in a loop,
300
+ |therefore it cannot be passed to a @consume parameter. """ ,
301
+ pos)
302
+
201
303
private def checkApply (fn : Tree , args : List [Tree ], deps : collection.Map [Tree , List [Tree ]])(using Context ): Unit =
202
304
val fnCaptures = methPart(fn) match
203
305
case Select (qual, _) => qual.nuType.captureSet
@@ -240,6 +342,9 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
240
342
val overlap = defUseOverlap(defsShadow, usedFootprint, tree.symbol)
241
343
if ! overlap.isEmpty then
242
344
sepUseError(tree, usedFootprint, overlap)
345
+ for ref <- used.elems do
346
+ val pos = consumed.get(ref)
347
+ if pos != null then consumeError(ref, pos, tree.srcPos)
243
348
244
349
def checkType (tpt : Tree , sym : Symbol )(using Context ): Unit =
245
350
checkType(tpt.nuType, tpt.srcPos,
@@ -383,10 +488,11 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
383
488
checkRefs(toCheck, i " $typeDescr type $tpe hides " )
384
489
case TypeKind .Argument (arg) =>
385
490
if tpe.hasAnnotation(defn.ConsumeAnnot ) then
386
- val capts = captures(arg)
387
- def descr (verb : String ) = i " argument to @consume parameter with type ${arg.nuType} $verb"
388
- checkRefs(capts.footprint, descr(" refers to" ))
389
- checkRefs(capts.hidden.footprint, descr(" hides" ))
491
+ val capts = captures(arg).footprint
492
+ checkRefs(capts, i " argument to @consume parameter with type ${arg.nuType} refers to " )
493
+ for ref <- capts do
494
+ if ! ref.derivesFrom(defn.Caps_SharedCapability ) then
495
+ consumed.put(ref, arg.srcPos)
390
496
391
497
if ! tpe.hasAnnotation(defn.UntrackedCapturesAnnot ) then
392
498
traverse(Captures .None , tpe)
@@ -435,35 +541,72 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
435
541
case tree : Apply => tree.symbol == defn.Caps_unsafeAssumeSeparate
436
542
case _ => false
437
543
544
+ def checkValOrDefDef (tree : ValOrDefDef )(using Context ): Unit =
545
+ if ! tree.symbol.isOneOf(TermParamOrAccessor ) && ! isUnsafeAssumeSeparate(tree.rhs) then
546
+ checkType(tree.tpt, tree.symbol)
547
+ if previousDefs.nonEmpty then
548
+ capt.println(i " sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}" )
549
+ defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
550
+ resultType(tree.symbol) = tree.tpt.nuType
551
+ previousDefs.head += tree
552
+
438
553
def traverse (tree : Tree )(using Context ): Unit =
439
554
if isUnsafeAssumeSeparate(tree) then return
440
555
checkUse(tree)
441
556
tree match
442
557
case tree : GenericApply =>
558
+ traverseChildren(tree)
443
559
tree.tpe match
444
560
case _ : MethodOrPoly =>
445
561
case _ => traverseApply(tree, Nil )
446
- traverseChildren(tree)
447
562
case tree : Block =>
448
563
val saved = defsShadow
449
564
previousDefs = mutable.ListBuffer () :: previousDefs
450
565
try traverseChildren(tree)
451
566
finally
452
567
previousDefs = previousDefs.tail
453
568
defsShadow = saved
454
- case tree : ValOrDefDef =>
569
+ case tree : ValDef =>
455
570
traverseChildren(tree)
456
- if ! tree.symbol.isOneOf(TermParamOrAccessor ) && ! isUnsafeAssumeSeparate(tree.rhs) then
457
- checkType(tree.tpt, tree.symbol)
458
- if previousDefs.nonEmpty then
459
- capt.println(i " sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}" )
460
- defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
461
- resultType(tree.symbol) = tree.tpt.nuType
462
- previousDefs.head += tree
571
+ checkValOrDefDef(tree)
572
+ case tree : DefDef =>
573
+ withFreshConsumed :
574
+ traverseChildren(tree)
575
+ checkValOrDefDef(tree)
576
+ case If (cond, thenp, elsep) =>
577
+ traverse(cond)
578
+ val thenConsumed = consumed.segment(traverse(thenp))
579
+ val elseConsumed = consumed.segment(traverse(elsep))
580
+ consumed ++= thenConsumed
581
+ consumed ++= elseConsumed
582
+ case tree @ Labeled (bind, expr) =>
583
+ val consumedBuf = mutable.ListBuffer [ConsumedSet ]()
584
+ openLabeled = (bind.name, consumedBuf) :: openLabeled
585
+ traverse(expr)
586
+ for cs <- consumedBuf do consumed ++= cs
587
+ openLabeled = openLabeled.tail
588
+ case Return (expr, from) =>
589
+ val retConsumed = consumed.segment(traverse(expr))
590
+ from match
591
+ case Ident (name) =>
592
+ for (lbl, consumedBuf) <- openLabeled do
593
+ if lbl == name then
594
+ consumedBuf += retConsumed
595
+ case _ =>
596
+ case Match (sel, cases) =>
597
+ // Matches without returns might still be kept after pattern matching to
598
+ // encode table switches.
599
+ traverse(sel)
600
+ val caseConsumed = for cas <- cases yield consumed.segment(traverse(cas))
601
+ caseConsumed.foreach(consumed ++= _)
602
+ case tree : TypeDef if tree.symbol.isClass =>
603
+ withFreshConsumed :
604
+ traverseChildren(tree)
605
+ case tree : WhileDo =>
606
+ val loopConsumed = consumed.segment(traverseChildren(tree))
607
+ if loopConsumed.size != 0 then
608
+ val (ref, pos) = loopConsumed.toMap.head
609
+ consumeInLoopError(ref, pos)
463
610
case _ =>
464
611
traverseChildren(tree)
465
- end SepChecker
466
-
467
-
468
-
469
-
612
+ end SepChecker
0 commit comments