Skip to content

Commit 001817c

Browse files
authored
Merge pull request #9322 from dotty-staging/fix-gadt-approximation
2 parents 5cdfd31 + c434949 commit 001817c

File tree

7 files changed

+188
-55
lines changed

7 files changed

+188
-55
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
751751
case _ => false
752752
}
753753
case _ => false
754-
comparePaths || isNewSubType(tp1.underlying.widenExpr)
754+
comparePaths || isSubType(tp1.underlying.widenExpr, tp2, approx.addLow)
755755
case tp1: RefinedType =>
756756
isNewSubType(tp1.parent)
757757
case tp1: RecType =>

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

+36-26
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import NameKinds.UniqueName
1111
import util.Spans._
1212
import util.{Stats, SimpleIdentityMap}
1313
import Decorators._
14-
import config.Printers.{gadts, typr}
14+
import config.Printers.{gadts, typr, debug}
1515
import annotation.tailrec
1616
import reporting._
1717
import collection.mutable
@@ -171,37 +171,47 @@ object Inferencing {
171171
res
172172
}
173173

174-
/** This class is mostly based on IsFullyDefinedAccumulator.
175-
* It tries to approximate the given type based on the available GADT constraints.
176-
*/
174+
/** Approximates a type to get rid of as many GADT-constrained abstract types as possible. */
177175
private class ApproximateGadtAccumulator(using Context) extends TypeMap {
178176

179177
var failed = false
180178

181-
private def instantiate(tvar: TypeVar, fromBelow: Boolean): Type = {
182-
val inst = tvar.instantiate(fromBelow)
183-
typr.println(i"forced instantiation of ${tvar.origin} = $inst")
184-
inst
185-
}
186-
187-
private def instDirection2(sym: Symbol)(using Context): Int = {
188-
val constrained = ctx.gadt.fullBounds(sym)
189-
val original = sym.info.bounds
190-
val cmp = ctx.typeComparer
191-
val approxBelow =
192-
if (!cmp.isSubTypeWhenFrozen(constrained.lo, original.lo)) 1 else 0
193-
val approxAbove =
194-
if (!cmp.isSubTypeWhenFrozen(original.hi, constrained.hi)) 1 else 0
195-
approxAbove - approxBelow
196-
}
197-
198-
private[this] var toMaximize: Boolean = false
199-
179+
/** GADT approximation proceeds differently from type variable approximation.
180+
*
181+
* Essentially, what we're doing is we're inferring a type ascription that
182+
* will remove as many GADT-constrained types as possible. This means that
183+
* we want to approximate type T to type S in such a way that no matter how
184+
* GADT-constrained types are instantiated, T <: S. In other words, the
185+
* relationship _necessarily_ must hold.
186+
*
187+
* We accomplish that by:
188+
* - replacing covariant occurences with upper GADT bound
189+
* - replacing contravariant occurences with lower GADT bound
190+
* - leaving invariant occurences alone
191+
*
192+
* Examples:
193+
* - If we have GADT cstr A <: Int, then for all A <: Int, Option[A] <: Option[Int].
194+
* Therefore, we can approximate Option[A] ~~ Option[Int].
195+
* - If we have A >: S <: T, then for all such A, A => A <: S => T. This
196+
* illustrates that it's fine to differently approximate different
197+
* occurences of same type.
198+
* - If we have A <: Int and F <: [A] => Option[A] (note the invariance),
199+
* then we should approximate F[A] ~~ Option[A]. That is, we should
200+
* respect the invariance of the type constructor.
201+
* - If we have A <: Option[B] and B <: Int, we approximate A ~~
202+
* Option[B]. That is, we don't recurse into already approximated
203+
* types. Since GADT approximation is (for now) only used for member
204+
* selection, this behaviour is expected, as nested types cannot affect
205+
* member selection (note that given/extension lookup doesn't need GADT
206+
* approx, see gadt-approximation-interaction.scala).
207+
*/
200208
def apply(tp: Type): Type = tp.dealias match {
201-
case tp @ TypeRef(qual, nme) if (qual eq NoPrefix) && ctx.gadt.contains(tp.symbol) =>
209+
case tp @ TypeRef(qual, nme) if (qual eq NoPrefix)
210+
&& variance != 0
211+
&& ctx.gadt.contains(tp.symbol)
212+
=>
202213
val sym = tp.symbol
203-
val res =
204-
ctx.gadt.approximation(sym, fromBelow = variance < 0)
214+
val res = ctx.gadt.approximation(sym, fromBelow = variance < 0)
205215
gadts.println(i"approximated $tp ~~ $res")
206216
res
207217

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

+1-7
Original file line numberDiff line numberDiff line change
@@ -687,14 +687,8 @@ object ProtoTypes {
687687

688688
/** Dummy tree to be used as an argument of a FunProto or ViewProto type */
689689
object dummyTreeOfType {
690-
/*
691-
* A property indicating that the given tree was created with dummyTreeOfType.
692-
* It is sometimes necessary to detect the dummy trees to avoid unwanted readaptations on them.
693-
*/
694-
val IsDummyTree = new Property.Key[Unit]
695-
696690
def apply(tp: Type)(implicit src: SourceFile): Tree =
697-
(untpd.Literal(Constant(null)) withTypeUnchecked tp).withAttachment(IsDummyTree, ())
691+
untpd.Literal(Constant(null)) withTypeUnchecked tp
698692
def unapply(tree: untpd.Tree): Option[Type] = tree match {
699693
case Literal(Constant(null)) => Some(tree.typeOpt)
700694
case _ => None

compiler/src/dotty/tools/dotc/typer/Typer.scala

+14-21
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import collection.mutable
3333
import annotation.tailrec
3434
import Implicits._
3535
import util.Stats.record
36-
import config.Printers.{gadts, typr}
36+
import config.Printers.{gadts, typr, debug}
3737
import config.Feature._
3838
import config.SourceVersion._
3939
import rewrites.Rewrites.patch
@@ -3407,25 +3407,18 @@ class Typer extends Namer
34073407
case _ =>
34083408
}
34093409

3410-
val approximation = Inferencing.approximateGADT(wtp)
3411-
gadts.println(
3412-
i"""GADT approximation {
3413-
approximation = $approximation
3414-
pt.isInstanceOf[SelectionProto] = ${pt.isInstanceOf[SelectionProto]}
3415-
ctx.gadt.nonEmpty = ${ctx.gadt.nonEmpty}
3416-
ctx.gadt = ${ctx.gadt.debugBoundsDescription}
3417-
pt.isMatchedBy = ${
3418-
if (pt.isInstanceOf[SelectionProto])
3419-
pt.asInstanceOf[SelectionProto].isMatchedBy(approximation).toString
3420-
else
3421-
"<not a SelectionProto>"
3422-
}
3423-
}
3424-
"""
3425-
)
3410+
// try GADT approximation, but only if we're trying to select a member
3411+
// Member lookup cannot take GADTs into account b/c of cache, so we
3412+
// approximate types based on GADT constraints instead. For an example,
3413+
// see MemberHealing in gadt-approximation-interaction.scala.
34263414
pt match {
3427-
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(approximation) =>
3428-
return tpd.Typed(tree, TypeTree(approximation))
3415+
case pt: SelectionProto if ctx.gadt.nonEmpty =>
3416+
gadts.println(i"Trying to heal member selection by GADT-approximating $wtp")
3417+
val gadtApprox = Inferencing.approximateGADT(wtp)
3418+
gadts.println(i"GADT-approximated $wtp ~~ $gadtApprox")
3419+
if pt.isMatchedBy(gadtApprox) then
3420+
gadts.println(i"Member selection healed by GADT approximation")
3421+
return tpd.Typed(tree, TypeTree(gadtApprox))
34293422
case _ => ;
34303423
}
34313424

@@ -3459,6 +3452,7 @@ class Typer extends Namer
34593452
if (isFullyDefined(wtp, force = ForceDegree.all) &&
34603453
ctx.typerState.constraint.ne(prevConstraint)) readapt(tree)
34613454
else err.typeMismatch(tree, pt, failure)
3455+
34623456
if ctx.mode.is(Mode.ImplicitsEnabled) && tree.typeOpt.isValueType then
34633457
if pt.isRef(defn.AnyValClass) || pt.isRef(defn.ObjectClass) then
34643458
ctx.error(em"the result of an implicit conversion must be more specific than $pt", tree.sourcePos)
@@ -3469,14 +3463,13 @@ class Typer extends Namer
34693463
checkImplicitConversionUseOK(found.symbol, tree.posd)
34703464
readapt(found)(using ctx.retractMode(Mode.ImplicitsEnabled))
34713465
case failure: SearchFailure =>
3472-
if (pt.isInstanceOf[ProtoType] && !failure.isAmbiguous) {
3466+
if (pt.isInstanceOf[ProtoType] && !failure.isAmbiguous) then
34733467
// don't report the failure but return the tree unchanged. This
34743468
// will cause a failure at the next level out, which usually gives
34753469
// a better error message. To compensate, store the encountered failure
34763470
// as an attachment, so that it can be reported later as an addendum.
34773471
rememberSearchFailure(tree, failure)
34783472
tree
3479-
}
34803473
else recover(failure.reason)
34813474
}
34823475
else recover(NoMatchingImplicits)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
object MemberHealing {
2+
enum SUB[-A, +B]:
3+
case Refl[S]() extends SUB[S, S]
4+
5+
def foo[T](t: T, ev: T SUB Int) =
6+
ev match { case SUB.Refl() =>
7+
t + 2
8+
}
9+
}
10+
11+
object ImplicitLookup {
12+
enum SUB[-A, +B]:
13+
case Refl[S]() extends SUB[S, S]
14+
15+
class Tag[T]
16+
17+
implicit val ti: Tag[Int] = Tag()
18+
19+
def foo[T](t: T, ev: T SUB Int) =
20+
ev match { case SUB.Refl() =>
21+
implicitly[Tag[Int]]
22+
}
23+
}
24+
25+
object GivenLookup {
26+
enum SUB[-A, +B]:
27+
case Refl[S]() extends SUB[S, S]
28+
29+
class Tag[T]
30+
31+
given ti as Tag[Int]
32+
33+
def foo[T](t: T, ev: T SUB Int) =
34+
ev match { case SUB.Refl() =>
35+
summon[Tag[Int]]
36+
}
37+
}
38+
39+
object ImplicitConversion {
40+
enum SUB[-A, +B]:
41+
case Refl[S]() extends SUB[S, S]
42+
43+
class Pow(self: Int):
44+
def **(other: Int): Int = math.pow(self, other).toInt
45+
46+
implicit def pow(i: Int): Pow = Pow(i)
47+
48+
def foo[T](t: T, ev: T SUB Int) =
49+
ev match { case SUB.Refl() =>
50+
t ** 2 // error // implementation limitation
51+
}
52+
53+
def bar[T](t: T, ev: T SUB Int) =
54+
ev match { case SUB.Refl() =>
55+
(t: Int) ** 2 // sanity check
56+
}
57+
}
58+
59+
object GivenConversion {
60+
enum SUB[-A, +B]:
61+
case Refl[S]() extends SUB[S, S]
62+
63+
class Pow(self: Int):
64+
def **(other: Int): Int = math.pow(self, other).toInt
65+
66+
given as Conversion[Int, Pow] = (i: Int) => Pow(i)
67+
68+
def foo[T](t: T, ev: T SUB Int) =
69+
ev match { case SUB.Refl() =>
70+
t ** 2 // error (implementation limitation)
71+
}
72+
73+
def bar[T](t: T, ev: T SUB Int) =
74+
ev match { case SUB.Refl() =>
75+
(t: Int) ** 2 // sanity check
76+
}
77+
}
78+
79+
object ExtensionMethod {
80+
enum SUB[-A, +B]:
81+
case Refl[S]() extends SUB[S, S]
82+
83+
extension (x: Int):
84+
def **(y: Int) = math.pow(x, y).toInt
85+
86+
def foo[T](t: T, ev: T SUB Int) =
87+
ev match { case SUB.Refl() =>
88+
t ** 2
89+
}
90+
}
91+
92+
object HKFun {
93+
enum SUB[-A, +B]:
94+
case Refl[S]() extends SUB[S, S]
95+
96+
enum HKSUB[-F[_], +G[_]]:
97+
case Refl[H[_]]() extends HKSUB[H, H]
98+
99+
def foo[F[_], T](ft: F[T], hkev: F HKSUB Option, ev: T SUB Int) =
100+
hkev match { case HKSUB.Refl() =>
101+
ev match { case SUB.Refl() =>
102+
// both should typecheck - we should respect invariance of F
103+
// (and not approximate its argument)
104+
// but also T <: Int b/c of ev
105+
val x: T = ft.get
106+
val y: Int = ft.get
107+
}
108+
}
109+
110+
enum COVHKSUB[-F[+_], +G[+_]]:
111+
case Refl[H[_]]() extends COVHKSUB[H, H]
112+
113+
def bar[F[+_], T](ft: F[T], hkev: F COVHKSUB Option, ev: T SUB Int) =
114+
hkev match { case COVHKSUB.Refl() =>
115+
ev match { case SUB.Refl() =>
116+
// Sanity check for `foo`
117+
// this is an error only because we blindly approximate covariant type arguments
118+
// if it stops being an error, `foo` should be re-thought
119+
val x: T = ft.get // error
120+
val y: Int = ft.get
121+
}
122+
}
123+
}
124+
125+
object NestedConstrained {
126+
enum SUB[-A, +B]:
127+
case Refl[S]() extends SUB[S, S]
128+
129+
def foo[A, B](a: A, ev1: A SUB Option[B], ev2: B SUB Int) =
130+
ev1 match { case SUB.Refl() =>
131+
ev2 match { case SUB.Refl() =>
132+
1 + "a"
133+
a.get : Int
134+
}
135+
}
136+
}
File renamed without changes.

0 commit comments

Comments
 (0)