Skip to content

Commit 42b622c

Browse files
committed
Take HKT injectivity into account when inferring GADT constraints
When going into arguments of a HKT, we only infer GADT constraints if we are certain that tycon is injective.
1 parent e09cd83 commit 42b622c

File tree

5 files changed

+70
-38
lines changed

5 files changed

+70
-38
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+21-6
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
141141
*/
142142
private [this] var leftRoot: Type = _
143143

144+
/** Are we forbidden from recording GADT constraints? */
145+
private[this] var frozenGadt = false
146+
144147
protected def isSubType(tp1: Type, tp2: Type, a: ApproxState): Boolean = {
145148
val savedApprox = approx
146149
val savedLeftRoot = leftRoot
@@ -840,8 +843,11 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
840843
gadtBoundsContain(tycon1sym, tycon2) ||
841844
gadtBoundsContain(tycon2sym, tycon1)
842845
) &&
843-
isSubType(tycon1.prefix, tycon2.prefix) &&
844-
isSubArgs(args1, args2, tp1, tparams)
846+
isSubType(tycon1.prefix, tycon2.prefix) && {
847+
// check both tycons to deal with the case when they are equal b/c of GADT constraint
848+
val tyconIsInjective = tycon1sym.isClass || tycon2sym.isClass
849+
isSubArgs(args1, args2, tp1, tparams, inferGadtBounds = tyconIsInjective)
850+
}
845851
if (res && touchedGADTs) GADTused = true
846852
res
847853
case _ =>
@@ -1097,7 +1103,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
10971103
* @param tp1 The applied type containing `args1`
10981104
* @param tparams2 The type parameters of the type constructor applied to `args2`
10991105
*/
1100-
def isSubArgs(args1: List[Type], args2: List[Type], tp1: Type, tparams2: List[ParamInfo]): Boolean = {
1106+
def isSubArgs(args1: List[Type], args2: List[Type], tp1: Type, tparams2: List[ParamInfo], inferGadtBounds: Boolean = false): Boolean = {
11011107
/** The bounds of parameter `tparam`, where all references to type paramneters
11021108
* are replaced by corresponding arguments (or their approximations in the case of
11031109
* wildcard arguments).
@@ -1161,8 +1167,17 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
11611167
case arg1: TypeBounds =>
11621168
compareCaptured(arg1, arg2)
11631169
case _ =>
1164-
(v > 0 || isSubType(arg2, arg1)) &&
1165-
(v < 0 || isSubType(arg1, arg2))
1170+
def isSub(tp: Type, pt: Type): Boolean = {
1171+
if (inferGadtBounds) isSubType(tp, pt)
1172+
else {
1173+
val savedFrozenGadt = frozenGadt
1174+
frozenGadt = true
1175+
try isSubType(tp, pt) finally frozenGadt = savedFrozenGadt
1176+
}
1177+
}
1178+
1179+
(v > 0 || isSub(arg2, arg1)) &&
1180+
(v < 0 || isSub(arg1, arg2))
11661181
}
11671182
}
11681183

@@ -1476,7 +1491,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
14761491
*/
14771492
private def narrowGADTBounds(tr: NamedType, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = {
14781493
val boundImprecise = approx.high || approx.low
1479-
ctx.mode.is(Mode.GADTflexible) && !frozenConstraint && !boundImprecise && {
1494+
ctx.mode.is(Mode.GADTflexible) && !frozenGadt && !frozenConstraint && !boundImprecise && {
14801495
val tparam = tr.symbol
14811496
gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}")
14821497
if (bound.isRef(tparam)) false
+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
object Test {
2+
enum EQ[A, B] {
3+
case Refl[T]() extends EQ[T, T]
4+
}
5+
import EQ._
6+
7+
object A {
8+
type Foo[+X] = (X, X)
9+
def foo[X, Y](x: X, eq: EQ[Foo[X], Foo[Y]]): Y = eq match {
10+
case Refl() => x
11+
}
12+
}
13+
14+
object B {
15+
type Foo[X] = (X, X)
16+
def foo[X, Y](x: X, eq: EQ[Foo[X], Foo[Y]]): Y = eq match {
17+
case Refl() => x
18+
}
19+
}
20+
21+
object C {
22+
type Foo[+X] = Int | (X, X)
23+
def foo[X, Y](x: X, eq: EQ[Foo[X], Foo[Y]]): Y = eq match {
24+
case Refl() => x
25+
}
26+
}
27+
28+
object D {
29+
type Foo[+X] = (Int, Int)
30+
def foo[X, Y](x: X, eq: EQ[Foo[X], Foo[Y]]): Y = eq match {
31+
case Refl() => x // error
32+
}
33+
}
34+
35+
trait E {
36+
type Foo[+X] <: Int | (X, X)
37+
def foo[X, Y](x: X, eq: EQ[Foo[X], Foo[Y]]): Y = eq match {
38+
case Refl() => x // error
39+
}
40+
}
41+
42+
trait F {
43+
type Foo[X] >: Int | (X, X)
44+
def foo[X, Y](x: X, eq: EQ[Foo[X], Foo[Y]]): Y = eq match {
45+
case Refl() => x // error
46+
}
47+
}
48+
}

tests/neg/gadt-uninjectivity.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ object uninjectivity {
44

55
def absurd1[F[_], X, Y](eq: EQ[F[X], F[Y]], x: X): Y = eq match {
66
case Refl() =>
7-
x // should be an error
7+
x // error
88
}
99

1010
def absurd2[F[_], G[_]](eq: EQ[F[Int], G[Int]], fi: F[Int], fs: F[String]): G[Int] = eq match {

tests/pos/gadt-EQK.scala

-12
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,4 @@ object EQK {
1818
fa : G[Int]
1919
}
2020
}
21-
22-
def m2[F[_], G[_], A](fa: F[A], a: A, eq: EQ[F[A], G[Int]], eqk: EQK[F, G]): Int =
23-
eqk match {
24-
case ReflK() => eq match {
25-
case Refl() =>
26-
val r1: F[Int] = fa
27-
val r2: G[A] = fa
28-
val r3: F[Int] = r2
29-
a
30-
}
31-
}
32-
3321
}

tests/run/gadt-injectivity-unsoundness.scala

-19
This file was deleted.

0 commit comments

Comments
 (0)