Skip to content

Commit 3e20051

Browse files
authored
Merge pull request #15746 from dotty-staging/fix-level-checking
Do level checking on instantiation
2 parents 8a7c84c + f01abfb commit 3e20051

19 files changed

+288
-65
lines changed

compiler/src/dotty/tools/dotc/config/Config.scala

+7-4
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,12 @@ object Config {
227227
*/
228228
inline val reuseSymDenotations = true
229229

230-
/** If true, check levels of type variables and create fresh ones as needed.
231-
* This is necessary for soundness (see 3ab18a9), but also causes several
232-
* regressions that should be fixed before turning this on.
230+
/** If `checkLevelsOnConstraints` is true, check levels of type variables
231+
* and create fresh ones as needed when bounds are first entered intot he constraint.
232+
* If `checkLevelsOnInstantiation` is true, allow level-incorrect constraints but
233+
* fix levels on type variable instantiation.
233234
*/
234-
inline val checkLevels = false
235+
inline val checkLevelsOnConstraints = false
236+
inline val checkLevelsOnInstantiation = true
237+
235238
}

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

+127-16
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import config.Printers.typr
1212
import typer.ProtoTypes.{newTypeVar, representedParamRef}
1313
import UnificationDirection.*
1414
import NameKinds.AvoidNameKind
15+
import util.SimpleIdentitySet
1516

1617
/** Methods for adding constraints and solving them.
1718
*
@@ -74,7 +75,43 @@ trait ConstraintHandling {
7475
protected def necessaryConstraintsOnly(using Context): Boolean =
7576
ctx.mode.is(Mode.GadtConstraintInference) || myNecessaryConstraintsOnly
7677

77-
protected var trustBounds = true
78+
/** If `trustBounds = false` we perform comparisons in a pessimistic way as follows:
79+
* Given an abstract type `A >: L <: H`, a subtype comparison of any type
80+
* with `A` will compare against both `L` and `H`. E.g.
81+
*
82+
* T <:< A if T <:< L and T <:< H
83+
* A <:< T if L <:< T and H <:< T
84+
*
85+
* This restricted form makes sure we don't "forget" types when forming
86+
* unions and intersections with abstract types that have bad bounds. E.g.
87+
* the following example from neg/i8900.scala that @smarter came up with:
88+
* We have a type variable X with constraints
89+
*
90+
* X >: 1, X >: x.M
91+
*
92+
* where `x` is a locally nested variable and `x.M` has bad bounds
93+
*
94+
* x.M >: Int | String <: Int & String
95+
*
96+
* If we trust bounds, then the lower bound of `X` is `x.M` since `x.M >: 1`.
97+
* Then even if we correct levels on instantiation to eliminate the local `x`,
98+
* it is alreay too late, we'd get `Int & String` as instance, which does not
99+
* satisfy the original constraint `X >: 1`.
100+
*
101+
* But if `trustBounds` is false, we do not conclude the `x.M >: 1` since
102+
* we compare both bounds and the upper bound `Int & String` is not a supertype
103+
* of `1`. So the lower bound is `1 | x.M` and when we level-avoid that we
104+
* get `1 | Int & String`, which simplifies to `Int`.
105+
*/
106+
private var myTrustBounds = true
107+
108+
inline def withUntrustedBounds(op: => Type): Type =
109+
val saved = myTrustBounds
110+
myTrustBounds = false
111+
try op finally myTrustBounds = saved
112+
113+
def trustBounds: Boolean =
114+
!Config.checkLevelsOnInstantiation || myTrustBounds
78115

79116
def checkReset() =
80117
assert(addConstraintInvocations == 0)
@@ -97,7 +134,7 @@ trait ConstraintHandling {
97134
level <= maxLevel
98135
|| ctx.isAfterTyper || !ctx.typerState.isCommittable // Leaks in these cases shouldn't break soundness
99136
|| level == Int.MaxValue // See `nestingLevel` above.
100-
|| !Config.checkLevels
137+
|| !Config.checkLevelsOnConstraints
101138

102139
/** If `param` is nested deeper than `maxLevel`, try to instantiate it to a
103140
* fresh type variable of level `maxLevel` and return the new variable.
@@ -262,16 +299,14 @@ trait ConstraintHandling {
262299
// If `isUpper` is true, ensure that `param <: `bound`, otherwise ensure
263300
// that `param >: bound`.
264301
val narrowedBounds =
265-
val savedHomogenizeArgs = homogenizeArgs
266-
val savedTrustBounds = trustBounds
302+
val saved = homogenizeArgs
267303
homogenizeArgs = Config.alignArgsInAnd
268304
try
269-
trustBounds = false
270-
if isUpper then oldBounds.derivedTypeBounds(lo, hi & bound)
271-
else oldBounds.derivedTypeBounds(lo | bound, hi)
305+
withUntrustedBounds(
306+
if isUpper then oldBounds.derivedTypeBounds(lo, hi & bound)
307+
else oldBounds.derivedTypeBounds(lo | bound, hi))
272308
finally
273-
homogenizeArgs = savedHomogenizeArgs
274-
trustBounds = savedTrustBounds
309+
homogenizeArgs = saved
275310
//println(i"narrow bounds for $param from $oldBounds to $narrowedBounds")
276311
val c1 = constraint.updateEntry(param, narrowedBounds)
277312
(c1 eq constraint)
@@ -431,24 +466,98 @@ trait ConstraintHandling {
431466
}
432467
}
433468

469+
/** Fix instance type `tp` by avoidance so that it does not contain references
470+
* to types at level > `maxLevel`.
471+
* @param tp the type to be fixed
472+
* @param fromBelow whether type was obtained from lower bound
473+
* @param maxLevel the maximum level of references allowed
474+
* @param param the parameter that was instantiated
475+
*/
476+
private def fixLevels(tp: Type, fromBelow: Boolean, maxLevel: Int, param: TypeParamRef)(using Context) =
477+
478+
def needsFix(tp: NamedType) =
479+
(tp.prefix eq NoPrefix) && tp.symbol.nestingLevel > maxLevel
480+
481+
/** An accumulator that determines whether levels need to be fixed
482+
* and computes on the side sets of nested type variables that need
483+
* to be instantiated.
484+
*/
485+
class NeedsLeveling extends TypeAccumulator[Boolean]:
486+
if !fromBelow then variance = -1
487+
488+
/** Nested type variables that should be instiated to theor lower (respoctively
489+
* upper) bounds.
490+
*/
491+
var nestedVarsLo, nestedVarsHi: SimpleIdentitySet[TypeVar] = SimpleIdentitySet.empty
492+
493+
def apply(need: Boolean, tp: Type) =
494+
need || tp.match
495+
case tp: NamedType =>
496+
needsFix(tp)
497+
|| !stopBecauseStaticOrLocal(tp) && apply(need, tp.prefix)
498+
case tp: TypeVar =>
499+
val inst = tp.instanceOpt
500+
if inst.exists then apply(need, inst)
501+
else if tp.nestingLevel > maxLevel then
502+
if variance > 0 then nestedVarsLo += tp
503+
else if variance < 0 then nestedVarsHi += tp
504+
else
505+
// For invariant type variables, we use a different strategy.
506+
// Rather than instantiating to a bound and then propagating in an
507+
// AvoidMap, change the nesting level of an invariant type
508+
// variable to `maxLevel`. This means that the type variable will be
509+
// instantiated later to a less nested type. If there are other references
510+
// to the same type variable that do not come from the type undergoing
511+
// `fixLevels`, this could lead to coarser types. But it has the potential
512+
// to give a better approximation for the current type, since it avoids forming
513+
// a Range in invariant position, which can lead to very coarse types further out.
514+
constr.println(i"widening nesting level of type variable $tp from ${tp.nestingLevel} to $maxLevel")
515+
ctx.typerState.setNestingLevel(tp, maxLevel)
516+
true
517+
else false
518+
case _ =>
519+
foldOver(need, tp)
520+
end NeedsLeveling
521+
522+
class LevelAvoidMap extends TypeOps.AvoidMap:
523+
if !fromBelow then variance = -1
524+
def toAvoid(tp: NamedType) = needsFix(tp)
525+
526+
if !Config.checkLevelsOnInstantiation || ctx.isAfterTyper then tp
527+
else
528+
val needsLeveling = NeedsLeveling()
529+
if needsLeveling(false, tp) then
530+
typr.println(i"instance $tp for $param needs leveling to $maxLevel, nested = ${needsLeveling.nestedVarsLo.toList} | ${needsLeveling.nestedVarsHi.toList}")
531+
needsLeveling.nestedVarsLo.foreach(_.instantiate(fromBelow = true))
532+
needsLeveling.nestedVarsHi.foreach(_.instantiate(fromBelow = false))
533+
LevelAvoidMap()(tp)
534+
else tp
535+
end fixLevels
536+
434537
/** Solve constraint set for given type parameter `param`.
435538
* If `fromBelow` is true the parameter is approximated by its lower bound,
436539
* otherwise it is approximated by its upper bound, unless the upper bound
437540
* contains a reference to the parameter itself (such occurrences can arise
438541
* for F-bounded types, `addOneBound` ensures that they never occur in the
439542
* lower bound).
543+
* The solved type is not allowed to contain references to types nested deeper
544+
* than `maxLevel`.
440545
* Wildcard types in bounds are approximated by their upper or lower bounds.
441546
* The constraint is left unchanged.
442547
* @return the instantiating type
443548
* @pre `param` is in the constraint's domain.
444549
*/
445-
final def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
550+
final def approximation(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int)(using Context): Type =
446551
constraint.entry(param) match
447552
case entry: TypeBounds =>
448553
val useLowerBound = fromBelow || param.occursIn(entry.hi)
449-
val inst = if useLowerBound then fullLowerBound(param) else fullUpperBound(param)
450-
typr.println(s"approx ${param.show}, from below = $fromBelow, inst = ${inst.show}")
451-
inst
554+
val rawInst = withUntrustedBounds(
555+
if useLowerBound then fullLowerBound(param) else fullUpperBound(param))
556+
val levelInst = fixLevels(rawInst, fromBelow, maxLevel, param)
557+
if levelInst ne rawInst then
558+
typr.println(i"level avoid for $maxLevel: $rawInst --> $levelInst")
559+
typr.println(i"approx $param, from below = $fromBelow, inst = $levelInst")
560+
levelInst
452561
case inst =>
453562
assert(inst.exists, i"param = $param\nconstraint = $constraint")
454563
inst
@@ -560,9 +669,11 @@ trait ConstraintHandling {
560669
* lower bounds; otherwise it is the glb of its upper bounds. However,
561670
* a lower bound instantiation can be a singleton type only if the upper bound
562671
* is also a singleton type.
672+
* The instance type is not allowed to contain references to types nested deeper
673+
* than `maxLevel`.
563674
*/
564-
def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type = {
565-
val approx = approximation(param, fromBelow).simplified
675+
def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int)(using Context): Type = {
676+
val approx = approximation(param, fromBelow, maxLevel).simplified
566677
if fromBelow then
567678
val widened = widenInferred(approx, param)
568679
// Widening can add extra constraints, in particular the widened type might
@@ -572,7 +683,7 @@ trait ConstraintHandling {
572683
// (we do not check for non-toplevel occurences: those should never occur
573684
// since `addOneBound` disallows recursive lower bounds).
574685
if constraint.occursAtToplevel(param, widened) then
575-
instanceType(param, fromBelow)
686+
instanceType(param, fromBelow, maxLevel)
576687
else
577688
widened
578689
else

compiler/src/dotty/tools/dotc/core/Contexts.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ object Contexts {
165165
protected def scope_=(scope: Scope): Unit = _scope = scope
166166
final def scope: Scope = _scope
167167

168-
/** The current type comparer */
168+
/** The current typerstate */
169169
private var _typerState: TyperState = _
170170
protected def typerState_=(typerState: TyperState): Unit = _typerState = typerState
171171
final def typerState: TyperState = _typerState

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ sealed abstract class GadtConstraint extends Showable {
4747
def isNarrowing: Boolean
4848

4949
/** See [[ConstraintHandling.approximation]] */
50-
def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type
50+
def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type
5151

5252
def symbols: List[Symbol]
5353

@@ -205,9 +205,9 @@ final class ProperGadtConstraint private(
205205

206206
def isNarrowing: Boolean = wasConstrained
207207

208-
override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = {
208+
override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = {
209209
val res =
210-
approximation(tvarOrError(sym).origin, fromBelow = fromBelow) match
210+
approximation(tvarOrError(sym).origin, fromBelow, maxLevel) match
211211
case tpr: TypeParamRef =>
212212
// Here we do externalization when the returned type is a TypeParamRef,
213213
// b/c ConstraintHandling.approximation may return internal types when
@@ -317,7 +317,7 @@ final class ProperGadtConstraint private(
317317
override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")
318318
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound")
319319

320-
override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")
320+
override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")
321321

322322
override def symbols: List[Symbol] = Nil
323323

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+5-5
Original file line numberDiff line numberDiff line change
@@ -2869,11 +2869,11 @@ object TypeComparer {
28692869
def subtypeCheckInProgress(using Context): Boolean =
28702870
comparing(_.subtypeCheckInProgress)
28712871

2872-
def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
2873-
comparing(_.instanceType(param, fromBelow))
2872+
def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
2873+
comparing(_.instanceType(param, fromBelow, maxLevel))
28742874

2875-
def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
2876-
comparing(_.approximation(param, fromBelow))
2875+
def approximation(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
2876+
comparing(_.approximation(param, fromBelow, maxLevel))
28772877

28782878
def bounds(param: TypeParamRef)(using Context): TypeBounds =
28792879
comparing(_.bounds(param))
@@ -2953,7 +2953,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
29532953
case param @ TypeParamRef(b, n) if b eq caseLambda =>
29542954
insts(n) =
29552955
if canApprox then
2956-
approximation(param, fromBelow = variance >= 0).simplified
2956+
approximation(param, fromBelow = variance >= 0, Int.MaxValue).simplified
29572957
else constraint.entry(param) match
29582958
case entry: TypeBounds =>
29592959
val lo = fullLowerBound(param)

compiler/src/dotty/tools/dotc/core/TyperState.scala

+36-11
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import config.Config
1010
import config.Printers.constr
1111
import collection.mutable
1212
import java.lang.ref.WeakReference
13-
import util.Stats
13+
import util.{Stats, SimpleIdentityMap}
1414
import Decorators._
1515

1616
import scala.annotation.internal.sharable
@@ -23,24 +23,26 @@ object TyperState {
2323
.setReporter(new ConsoleReporter())
2424
.setCommittable(true)
2525

26-
opaque type Snapshot = (Constraint, TypeVars, TypeVars)
26+
type LevelMap = SimpleIdentityMap[TypeVar, Integer]
27+
28+
opaque type Snapshot = (Constraint, TypeVars, LevelMap)
2729

2830
extension (ts: TyperState)
2931
def snapshot()(using Context): Snapshot =
30-
var previouslyInstantiated: TypeVars = SimpleIdentitySet.empty
31-
for tv <- ts.ownedVars do if tv.inst.exists then previouslyInstantiated += tv
32-
(ts.constraint, ts.ownedVars, previouslyInstantiated)
32+
(ts.constraint, ts.ownedVars, ts.upLevels)
3333

3434
def resetTo(state: Snapshot)(using Context): Unit =
35-
val (c, tvs, previouslyInstantiated) = state
36-
for tv <- tvs do
37-
if tv.inst.exists && !previouslyInstantiated.contains(tv) then
35+
val (constraint, ownedVars, upLevels) = state
36+
for tv <- ownedVars do
37+
if !ts.ownedVars.contains(tv) then // tv has been instantiated
3838
tv.resetInst(ts)
39-
ts.ownedVars = tvs
40-
ts.constraint = c
39+
ts.constraint = constraint
40+
ts.ownedVars = ownedVars
41+
ts.upLevels = upLevels
4142
}
4243

4344
class TyperState() {
45+
import TyperState.LevelMap
4446

4547
private var myId: Int = _
4648
def id: Int = myId
@@ -89,6 +91,8 @@ class TyperState() {
8991
def ownedVars: TypeVars = myOwnedVars
9092
def ownedVars_=(vs: TypeVars): Unit = myOwnedVars = vs
9193

94+
private var upLevels: LevelMap = _
95+
9296
/** Initializes all fields except reporter, isCommittable, which need to be
9397
* set separately.
9498
*/
@@ -99,20 +103,35 @@ class TyperState() {
99103
this.myConstraint = constraint
100104
this.previousConstraint = constraint
101105
this.myOwnedVars = SimpleIdentitySet.empty
106+
this.upLevels = SimpleIdentityMap.empty
102107
this.isCommitted = false
103108
this
104109

105110
/** A fresh typer state with the same constraint as this one. */
106111
def fresh(reporter: Reporter = StoreReporter(this.reporter, fromTyperState = true),
107112
committable: Boolean = this.isCommittable): TyperState =
108113
util.Stats.record("TyperState.fresh")
109-
TyperState().init(this, this.constraint)
114+
val ts = TyperState().init(this, this.constraint)
110115
.setReporter(reporter)
111116
.setCommittable(committable)
117+
ts.upLevels = upLevels
118+
ts
112119

113120
/** The uninstantiated variables */
114121
def uninstVars: collection.Seq[TypeVar] = constraint.uninstVars
115122

123+
/** The nestingLevel of `tv` in this typer state */
124+
def nestingLevel(tv: TypeVar): Int =
125+
val own = upLevels(tv)
126+
if own == null then tv.initNestingLevel else own.intValue()
127+
128+
/** Set the nestingLevel of `tv` in this typer state
129+
* @pre this level must be smaller than `tv.initNestingLevel`
130+
*/
131+
def setNestingLevel(tv: TypeVar, level: Int) =
132+
assert(level < tv.initNestingLevel)
133+
upLevels = upLevels.updated(tv, level)
134+
116135
/** The closest ancestor of this typer state (including possibly this typer state itself)
117136
* which is not yet committed, or which does not have a parent.
118137
*/
@@ -164,6 +183,12 @@ class TyperState() {
164183
if !ownedVars.isEmpty then ownedVars.foreach(targetState.includeVar)
165184
else
166185
targetState.mergeConstraintWith(this)
186+
187+
upLevels.foreachBinding { (tv, level) =>
188+
if level < targetState.nestingLevel(tv) then
189+
targetState.setNestingLevel(tv, level)
190+
}
191+
167192
targetState.gc()
168193
isCommitted = true
169194
ownedVars = SimpleIdentitySet.empty

0 commit comments

Comments
 (0)