Skip to content

Commit 93e4416

Browse files
authored
Merge pull request #9400 from dotty-staging/simplify-typecomparer
Simplify ConstraintHandling
2 parents 6ada79c + d505ef3 commit 93e4416

File tree

5 files changed

+50
-58
lines changed

5 files changed

+50
-58
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

+29-33
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,12 @@ import reporting.trace
2222
* By comparison: Constraint handlers are parts of type comparers and can use their functionality.
2323
* Constraint handlers update the current constraint as a side effect.
2424
*/
25-
trait ConstraintHandling[AbstractContext] {
25+
trait ConstraintHandling {
2626

2727
def constr: config.Printers.Printer = config.Printers.constr
2828

29-
def comparerCtx(using AbstractContext): Context
30-
31-
given (using AbstractContext) as Context = comparerCtx
32-
33-
protected def isSubType(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean
34-
protected def isSameType(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean
29+
protected def isSub(tp1: Type, tp2: Type)(using Context): Boolean
30+
protected def isSame(tp1: Type, tp2: Type)(using Context): Boolean
3531

3632
protected def constraint: Constraint
3733
protected def constraint_=(c: Constraint): Unit
@@ -71,23 +67,23 @@ trait ConstraintHandling[AbstractContext] {
7167
case tp => tp
7268
}
7369

74-
def nonParamBounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds = constraint.nonParamBounds(param)
70+
def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds = constraint.nonParamBounds(param)
7571

76-
def fullLowerBound(param: TypeParamRef)(implicit actx: AbstractContext): Type =
72+
def fullLowerBound(param: TypeParamRef)(using Context): Type =
7773
constraint.minLower(param).foldLeft(nonParamBounds(param).lo)(_ | _)
7874

79-
def fullUpperBound(param: TypeParamRef)(implicit actx: AbstractContext): Type =
75+
def fullUpperBound(param: TypeParamRef)(using Context): Type =
8076
constraint.minUpper(param).foldLeft(nonParamBounds(param).hi)(_ & _)
8177

8278
/** Full bounds of `param`, including other lower/upper params.
8379
*
8480
* Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds`
8581
* of some param when comparing types might lead to infinite recursion. Consider `bounds` instead.
8682
*/
87-
def fullBounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds =
83+
def fullBounds(param: TypeParamRef)(using Context): TypeBounds =
8884
nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param))
8985

90-
protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(using AbstractContext): Boolean =
86+
protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Boolean =
9187
if !constraint.contains(param) then true
9288
else if !isUpper && param.occursIn(bound)
9389
// We don't allow recursive lower bounds when defining a type,
@@ -121,11 +117,11 @@ trait ConstraintHandling[AbstractContext] {
121117
|| {
122118
constraint = c1
123119
val TypeBounds(lo, hi) = constraint.entry(param)
124-
isSubType(lo, hi)
120+
isSub(lo, hi)
125121
}
126122
end addOneBound
127123

128-
protected def addBoundTransitively(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(implicit actx: AbstractContext): Boolean =
124+
protected def addBoundTransitively(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Boolean =
129125

130126
/** Adjust the bound `tp` in the following ways:
131127
*
@@ -172,7 +168,7 @@ trait ConstraintHandling[AbstractContext] {
172168
.reporting(i"added $description = $result$location", constr)
173169
end addBoundTransitively
174170

175-
protected def addLess(p1: TypeParamRef, p2: TypeParamRef)(implicit actx: AbstractContext): Boolean = {
171+
protected def addLess(p1: TypeParamRef, p2: TypeParamRef)(using Context): Boolean = {
176172
def description = i"ordering $p1 <: $p2 to\n$constraint"
177173
val res =
178174
if (constraint.isLess(p2, p1)) unify(p2, p1)
@@ -195,7 +191,7 @@ trait ConstraintHandling[AbstractContext] {
195191
/** Make p2 = p1, transfer all bounds of p2 to p1
196192
* @pre less(p1)(p2)
197193
*/
198-
private def unify(p1: TypeParamRef, p2: TypeParamRef)(implicit actx: AbstractContext): Boolean = {
194+
private def unify(p1: TypeParamRef, p2: TypeParamRef)(using Context): Boolean = {
199195
constr.println(s"unifying $p1 $p2")
200196
assert(constraint.isLess(p1, p2))
201197
val down = constraint.exclusiveLower(p2, p1)
@@ -204,16 +200,16 @@ trait ConstraintHandling[AbstractContext] {
204200
val bounds = constraint.nonParamBounds(p1)
205201
val lo = bounds.lo
206202
val hi = bounds.hi
207-
isSubType(lo, hi) &&
203+
isSub(lo, hi) &&
208204
down.forall(addOneBound(_, hi, isUpper = true)) &&
209205
up.forall(addOneBound(_, lo, isUpper = false))
210206
}
211207

212-
protected def isSubType(tp1: Type, tp2: Type, whenFrozen: Boolean)(implicit actx: AbstractContext): Boolean =
208+
protected def isSubType(tp1: Type, tp2: Type, whenFrozen: Boolean)(using Context): Boolean =
213209
if (whenFrozen)
214210
isSubTypeWhenFrozen(tp1, tp2)
215211
else
216-
isSubType(tp1, tp2)
212+
isSub(tp1, tp2)
217213

218214
inline final def inFrozenConstraint[T](op: => T): T = {
219215
val savedFrozen = frozenConstraint
@@ -227,16 +223,16 @@ trait ConstraintHandling[AbstractContext] {
227223
}
228224
}
229225

230-
final def isSubTypeWhenFrozen(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean = inFrozenConstraint(isSubType(tp1, tp2))
231-
final def isSameTypeWhenFrozen(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean = inFrozenConstraint(isSameType(tp1, tp2))
226+
final def isSubTypeWhenFrozen(tp1: Type, tp2: Type)(using Context): Boolean = inFrozenConstraint(isSub(tp1, tp2))
227+
final def isSameTypeWhenFrozen(tp1: Type, tp2: Type)(using Context): Boolean = inFrozenConstraint(isSame(tp1, tp2))
232228

233229
/** Test whether the lower bounds of all parameters in this
234230
* constraint are a solution to the constraint.
235231
*/
236-
protected final def isSatisfiable(implicit actx: AbstractContext): Boolean =
232+
protected final def isSatisfiable(using Context): Boolean =
237233
constraint.forallParams { param =>
238234
val TypeBounds(lo, hi) = constraint.entry(param)
239-
isSubType(lo, hi) || {
235+
isSub(lo, hi) || {
240236
report.log(i"sub fail $lo <:< $hi")
241237
false
242238
}
@@ -253,7 +249,7 @@ trait ConstraintHandling[AbstractContext] {
253249
* @return the instantiating type
254250
* @pre `param` is in the constraint's domain.
255251
*/
256-
final def approximation(param: TypeParamRef, fromBelow: Boolean)(implicit actx: AbstractContext): Type = {
252+
final def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type = {
257253
val replaceWildcards = new TypeMap {
258254
override def stopAtStatic = true
259255
def apply(tp: Type) = mapOver {
@@ -317,7 +313,7 @@ trait ConstraintHandling[AbstractContext] {
317313
* At this point we also drop the @Repeated annotation to avoid inferring type arguments with it,
318314
* as those could leak the annotation to users (see run/inferred-repeated-result).
319315
*/
320-
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type =
316+
def widenInferred(inst: Type, bound: Type)(using Context): Type =
321317

322318
def dropSuperTraits(tp: Type): Type =
323319
var kept: Set[Type] = Set() // types to keep since otherwise bound would not fit
@@ -380,7 +376,7 @@ trait ConstraintHandling[AbstractContext] {
380376
* a lower bound instantiation can be a singleton type only if the upper bound
381377
* is also a singleton type.
382378
*/
383-
def instanceType(param: TypeParamRef, fromBelow: Boolean)(implicit actx: AbstractContext): Type = {
379+
def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type = {
384380
val approx = approximation(param, fromBelow).simplified
385381
if (fromBelow)
386382
val widened = widenInferred(approx, param)
@@ -408,7 +404,7 @@ trait ConstraintHandling[AbstractContext] {
408404
* Both `c1` and `c2` are required to derive from constraint `pre`, without adding
409405
* any new type variables but possibly narrowing already registered ones with further bounds.
410406
*/
411-
protected final def subsumes(c1: Constraint, c2: Constraint, pre: Constraint)(implicit actx: AbstractContext): Boolean =
407+
protected final def subsumes(c1: Constraint, c2: Constraint, pre: Constraint)(using Context): Boolean =
412408
if (c2 eq pre) true
413409
else if (c1 eq pre) false
414410
else {
@@ -427,7 +423,7 @@ trait ConstraintHandling[AbstractContext] {
427423
}
428424

429425
/** The current bounds of type parameter `param` */
430-
def bounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds = {
426+
def bounds(param: TypeParamRef)(using Context): TypeBounds = {
431427
val e = constraint.entry(param)
432428
if (e.exists) e.bounds
433429
else {
@@ -441,7 +437,7 @@ trait ConstraintHandling[AbstractContext] {
441437
* and propagate all bounds.
442438
* @param tvars See Constraint#add
443439
*/
444-
def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(implicit actx: AbstractContext): Boolean =
440+
def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(using Context): Boolean =
445441
checkPropagated(i"initialized $tl") {
446442
constraint = constraint.add(tl, tvars)
447443
tl.paramRefs.forall { param =>
@@ -470,7 +466,7 @@ trait ConstraintHandling[AbstractContext] {
470466
* This holds if `TypeVarsMissContext` is set unless `param` is a part
471467
* of a MatchType that is currently normalized.
472468
*/
473-
final def assumedTrue(param: TypeParamRef)(implicit actx: AbstractContext): Boolean =
469+
final def assumedTrue(param: TypeParamRef)(using Context): Boolean =
474470
ctx.mode.is(Mode.TypevarsMissContext) && (caseLambda `ne` param.binder)
475471

476472
/** Add constraint `param <: bound` if `fromBelow` is false, `param >: bound` otherwise.
@@ -480,7 +476,7 @@ trait ConstraintHandling[AbstractContext] {
480476
* not be AndTypes and lower bounds may not be OrTypes. This is assured by the
481477
* way isSubType is organized.
482478
*/
483-
protected def addConstraint(param: TypeParamRef, bound: Type, fromBelow: Boolean)(implicit actx: AbstractContext): Boolean =
479+
protected def addConstraint(param: TypeParamRef, bound: Type, fromBelow: Boolean)(using Context): Boolean =
484480

485481
/** When comparing lambdas we might get constraints such as
486482
* `A <: X0` or `A = List[X0]` where `A` is a constrained parameter
@@ -514,7 +510,7 @@ trait ConstraintHandling[AbstractContext] {
514510
case _: TypeBounds =>
515511
if (fromBelow) addLess(bound, param) else addLess(param, bound)
516512
case tp =>
517-
if (fromBelow) isSubType(bound, tp) else isSubType(tp, bound)
513+
if (fromBelow) isSub(bound, tp) else isSub(tp, bound)
518514
}
519515

520516
def kindCompatible(tp1: Type, tp2: Type): Boolean =
@@ -541,7 +537,7 @@ trait ConstraintHandling[AbstractContext] {
541537
end addConstraint
542538

543539
/** Check that constraint is fully propagated. See comment in Config.checkConstraintsPropagated */
544-
def checkPropagated(msg: => String)(result: Boolean)(implicit actx: AbstractContext): Boolean = {
540+
def checkPropagated(msg: => String)(result: Boolean)(using Context): Boolean = {
545541
if (Config.checkConstraintsPropagated && result && addConstraintInvocations == 0)
546542
inFrozenConstraint {
547543
for (p <- constraint.domainParams) {

compiler/src/dotty/tools/dotc/core/Contexts.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ object Contexts {
735735
store = initialStore
736736
.updated(settingsStateLoc, settingsGroup.defaultState)
737737
.updated(notNullInfosLoc, Nil)
738-
typeComparer = new TypeComparer(this)
738+
typeComparer = new TypeComparer(using this)
739739
searchHistory = new SearchRoot
740740
gadt = EmptyGadtConstraint
741741
}

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

+4-6
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ final class ProperGadtConstraint private(
6161
private var myConstraint: Constraint,
6262
private var mapping: SimpleIdentityMap[Symbol, TypeVar],
6363
private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol],
64-
) extends GadtConstraint with ConstraintHandling[Context] {
64+
) extends GadtConstraint with ConstraintHandling {
6565
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}
6666

6767
def this() = this(
@@ -140,7 +140,7 @@ final class ProperGadtConstraint private(
140140
case tv: TypeVar => tv
141141
case inst =>
142142
gadts.println(i"instantiated: $sym -> $inst")
143-
return if (isUpper) isSubType(inst , bound) else isSubType(bound, inst)
143+
return if (isUpper) isSub(inst, bound) else isSub(bound, inst)
144144
}
145145

146146
val internalizedBound = bound match {
@@ -217,13 +217,11 @@ final class ProperGadtConstraint private(
217217

218218
// ---- Protected/internal -----------------------------------------------
219219

220-
override def comparerCtx(using Context): Context = ctx
221-
222220
override protected def constraint = myConstraint
223221
override protected def constraint_=(c: Constraint) = myConstraint = c
224222

225-
override def isSubType(tp1: Type, tp2: Type)(using Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2)
226-
override def isSameType(tp1: Type, tp2: Type)(using Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2)
223+
override protected def isSub(tp1: Type, tp2: Type)(using Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2)
224+
override protected def isSame(tp1: Type, tp2: Type)(using Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2)
227225

228226
override def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds =
229227
constraint.nonParamBounds(param) match {

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+15-17
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,10 @@ import typer.Applications.productSelectorTypes
2424
import reporting.trace
2525
import NullOpsDecorator.NullOps
2626

27-
final class AbsentContext
28-
object AbsentContext {
29-
implicit val absentContext: AbsentContext = new AbsentContext
30-
}
31-
3227
/** Provides methods to compare types.
3328
*/
34-
class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] with PatternTypeConstrainer {
29+
class TypeComparer(using val comparerCtx: Context) extends ConstraintHandling with PatternTypeConstrainer {
3530
import TypeComparer._
36-
def comparerCtx(using AbsentContext): Context = initctx
3731

3832
val state = ctx.typerState
3933
def constraint: Constraint = state.constraint
@@ -175,7 +169,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
175169
}
176170
}
177171

178-
def isSubType(tp1: Type, tp2: Type)(implicit nc: AbsentContext): Boolean = isSubType(tp1, tp2, FreshApprox)
172+
def isSubType(tp1: Type, tp2: Type): Boolean = isSubType(tp1, tp2, FreshApprox)
173+
174+
override protected def isSub(tp1: Type, tp2: Type)(using Context): Boolean = isSubType(tp1, tp2)
179175

180176
/** The inner loop of the isSubType comparison.
181177
* Recursive calls from recur should go to recur directly if the two types
@@ -1769,11 +1765,13 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
17691765
// Type equality =:=
17701766

17711767
/** Two types are the same if are mutual subtypes of each other */
1772-
def isSameType(tp1: Type, tp2: Type)(implicit nc: AbsentContext): Boolean =
1768+
def isSameType(tp1: Type, tp2: Type): Boolean =
17731769
if (tp1 eq NoType) false
17741770
else if (tp1 eq tp2) true
17751771
else isSubType(tp1, tp2) && isSubType(tp2, tp1)
17761772

1773+
override protected def isSame(tp1: Type, tp2: Type)(using Context): Boolean = isSameType(tp1, tp2)
1774+
17771775
/** Same as `isSameType` but also can be applied to overloaded TermRefs, where
17781776
* two overloaded refs are the same if they have pairwise equal alternatives
17791777
*/
@@ -2215,7 +2213,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
22152213
}
22162214

22172215
/** A new type comparer of the same type as this one, using the given context. */
2218-
def copyIn(ctx: Context): TypeComparer = new TypeComparer(ctx)
2216+
def copyIn(ctx: Context): TypeComparer = new TypeComparer(using ctx)
22192217

22202218
// ----------- Diagnostics --------------------------------------------------
22212219

@@ -2469,7 +2467,7 @@ object TypeComparer {
24692467

24702468
/** Show trace of comparison operations when performing `op` */
24712469
def explaining[T](say: String => Unit)(op: Context ?=> T)(using Context): T = {
2472-
val nestedCtx = ctx.fresh.setTypeComparerFn(new ExplainingTypeComparer(_))
2470+
val nestedCtx = ctx.fresh.setTypeComparerFn(new ExplainingTypeComparer(using _))
24732471
val res = try { op(using nestedCtx) } finally { say(nestedCtx.typeComparer.lastTrace()) }
24742472
res
24752473
}
@@ -2482,17 +2480,17 @@ object TypeComparer {
24822480
}
24832481
}
24842482

2485-
class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
2483+
class TrackingTypeComparer(using Context) extends TypeComparer {
24862484
import state.constraint
24872485

24882486
val footprint: mutable.Set[Type] = mutable.Set[Type]()
24892487

2490-
override def bounds(param: TypeParamRef)(implicit nc: AbsentContext): TypeBounds = {
2488+
override def bounds(param: TypeParamRef)(using Context): TypeBounds = {
24912489
if (param.binder `ne` caseLambda) footprint += param
24922490
super.bounds(param)
24932491
}
24942492

2495-
override def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(implicit nc: AbsentContext): Boolean = {
2493+
override def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Boolean = {
24962494
if (param.binder `ne` caseLambda) footprint += param
24972495
super.addOneBound(param, bound, isUpper)
24982496
}
@@ -2630,7 +2628,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
26302628
}
26312629

26322630
/** A type comparer that can record traces of subtype operations */
2633-
class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
2631+
class ExplainingTypeComparer(using Context) extends TypeComparer {
26342632
import TypeComparer._
26352633

26362634
private var indent = 0
@@ -2678,12 +2676,12 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
26782676
super.glb(tp1, tp2)
26792677
}
26802678

2681-
override def addConstraint(param: TypeParamRef, bound: Type, fromBelow: Boolean)(implicit nc: AbsentContext): Boolean =
2679+
override def addConstraint(param: TypeParamRef, bound: Type, fromBelow: Boolean)(using Context): Boolean =
26822680
traceIndented(i"add constraint $param ${if (fromBelow) ">:" else "<:"} $bound $frozenConstraint, constraint = ${ctx.typerState.constraint}") {
26832681
super.addConstraint(param, bound, fromBelow)
26842682
}
26852683

2686-
override def copyIn(ctx: Context): ExplainingTypeComparer = new ExplainingTypeComparer(ctx)
2684+
override def copyIn(using Context): ExplainingTypeComparer = new ExplainingTypeComparer
26872685

26882686
override def lastTrace(): String = "Subtype trace:" + { try b.toString finally b.clear() }
26892687
}

compiler/src/dotty/tools/dotc/core/Types.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -4283,7 +4283,7 @@ object Types {
42834283
override def tryNormalize(using Context): Type = reduced.normalized
42844284

42854285
def reduced(using Context): Type = {
4286-
val trackingCtx = ctx.fresh.setTypeComparerFn(new TrackingTypeComparer(_))
4286+
val trackingCtx = ctx.fresh.setTypeComparerFn(new TrackingTypeComparer(using _))
42874287
val typeComparer = trackingCtx.typeComparer.asInstanceOf[TrackingTypeComparer]
42884288

42894289
def contextInfo(tp: Type): Type = tp match {

0 commit comments

Comments
 (0)