Skip to content

Commit c8ddbc0

Browse files
authored
Merge pull request #15646 from Linyxus/fix-if-union-b
Fix GADT casting when typing if expressions
2 parents 7d43b44 + 2255fdb commit c8ddbc0

File tree

5 files changed

+110
-15
lines changed

5 files changed

+110
-15
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

+54-15
Original file line numberDiff line numberDiff line change
@@ -1135,6 +1135,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
11351135
case elsep: untpd.If => isIncomplete(elsep)
11361136
case _ => false
11371137

1138+
// Insert a GADT cast if the type of the branch does not conform
1139+
// to the type assigned to the whole if tree.
1140+
// This happens when the computation of the type of the if tree
1141+
// uses GADT constraints. See #15646.
1142+
def gadtAdaptBranch(tree: Tree, branchPt: Type): Tree =
1143+
TypeComparer.testSubType(tree.tpe.widenExpr, branchPt) match {
1144+
case CompareResult.OKwithGADTUsed =>
1145+
insertGadtCast(tree, tree.tpe.widen, branchPt)
1146+
case _ => tree
1147+
}
1148+
11381149
val branchPt = if isIncomplete(tree) then defn.UnitType else pt.dropIfProto
11391150

11401151
val result =
@@ -1148,7 +1159,16 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
11481159
val elsep0 = typed(tree.elsep, branchPt)(using cond1.nullableContextIf(false))
11491160
thenp0 :: elsep0 :: Nil
11501161
}: @unchecked
1151-
assignType(cpy.If(tree)(cond1, thenp1, elsep1), thenp1, elsep1)
1162+
1163+
val resType = thenp1.tpe | elsep1.tpe
1164+
val thenp2 :: elsep2 :: Nil =
1165+
(thenp1 :: elsep1 :: Nil) map { t =>
1166+
// Adapt each branch to ensure that their types conforms to the
1167+
// type assigned to the if tree by inserting GADT casts.
1168+
gadtAdaptBranch(t, resType)
1169+
}: @unchecked
1170+
1171+
cpy.If(tree)(cond1, thenp2, elsep2).withType(resType)
11521172

11531173
def thenPathInfo = cond1.notNullInfoIf(true).seq(result.thenp.notNullInfo)
11541174
def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo)
@@ -3763,20 +3783,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
37633783
gadts.println(i"unnecessary GADTused for $tree: ${tree.tpe.widenExpr} vs $pt in ${ctx.source}")
37643784
res
37653785
} =>
3766-
// Insert an explicit cast, so that -Ycheck in later phases succeeds.
3767-
// The check "safeToInstantiate" in `maximizeType` works to prevent unsound GADT casts.
3768-
val target =
3769-
if tree.tpe.isSingleton then
3770-
val conj = AndType(tree.tpe, pt)
3771-
if tree.tpe.isStable && !conj.isStable then
3772-
// this is needed for -Ycheck. Without the annotation Ycheck will
3773-
// skolemize the result type which will lead to different types before
3774-
// and after checking. See i11955.scala.
3775-
AnnotatedType(conj, Annotation(defn.UncheckedStableAnnot))
3776-
else conj
3777-
else pt
3778-
gadts.println(i"insert GADT cast from $tree to $target")
3779-
tree.cast(target)
3786+
insertGadtCast(tree, wtp, pt)
37803787
case _ =>
37813788
//typr.println(i"OK ${tree.tpe}\n${TypeComparer.explained(_.isSubType(tree.tpe, pt))}") // uncomment for unexpected successes
37823789
tree
@@ -4207,4 +4214,36 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
42074214
EmptyTree
42084215
else typedExpr(call, defn.AnyType)
42094216

4217+
/** Insert GADT cast to target type `pt` on the `tree`
4218+
* so that -Ycheck in later phases succeeds.
4219+
* The check "safeToInstantiate" in `maximizeType` works to prevent unsound GADT casts.
4220+
*/
4221+
private def insertGadtCast(tree: Tree, wtp: Type, pt: Type)(using Context): Tree =
4222+
val target =
4223+
if tree.tpe.isSingleton then
4224+
// In the target type, when the singleton type is intersected, we also intersect
4225+
// the GADT-approximated type of the singleton to avoid the loss of
4226+
// information. See #15646.
4227+
val gadtApprox = Inferencing.approximateGADT(wtp)
4228+
gadts.println(i"gadt approx $wtp ~~~ $gadtApprox")
4229+
val conj =
4230+
TypeComparer.testSubType(gadtApprox, pt) match {
4231+
case CompareResult.OK =>
4232+
// GADT approximation of the tree type is a subtype of expected type under empty GADT
4233+
// constraints, so it is enough to only have the GADT approximation.
4234+
AndType(tree.tpe, gadtApprox)
4235+
case _ =>
4236+
// In other cases, we intersect both the approximated type and the expected type.
4237+
AndType(AndType(tree.tpe, gadtApprox), pt)
4238+
}
4239+
if tree.tpe.isStable && !conj.isStable then
4240+
// this is needed for -Ycheck. Without the annotation Ycheck will
4241+
// skolemize the result type which will lead to different types before
4242+
// and after checking. See i11955.scala.
4243+
AnnotatedType(conj, Annotation(defn.UncheckedStableAnnot))
4244+
else conj
4245+
else pt
4246+
gadts.println(i"insert GADT cast from $tree to $target")
4247+
tree.cast(target)
4248+
end insertGadtCast
42104249
}

tests/pos/gadt-cast-if.scala

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
trait Expr[T]
2+
case class IntExpr() extends Expr[Int]
3+
4+
def flag: Boolean = ???
5+
6+
def foo[T](ev: Expr[T]): Int | T = ev match
7+
case IntExpr() =>
8+
if flag then
9+
val i: T = ???
10+
i
11+
else
12+
(??? : Int)

tests/pos/gadt-cast-singleton.scala

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
enum SUB[-A, +B]:
2+
case Refl[S]() extends SUB[S, S]
3+
4+
trait R {
5+
type Data
6+
}
7+
trait L extends R
8+
9+
def f(x: L): x.Data = ???
10+
11+
def g[T <: R](x: T, ev: T SUB L): x.Data = ev match
12+
case SUB.Refl() =>
13+
f(x)

tests/pos/i14776-patmat.scala

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
trait T1
2+
trait T2 extends T1
3+
4+
trait Expr[T] { val data: T = ??? }
5+
case class Tag2() extends Expr[T2]
6+
7+
def flag: Boolean = ???
8+
9+
def foo[T](e: Expr[T]): T1 = e match {
10+
case Tag2() =>
11+
flag match
12+
case true => new T2 {}
13+
case false => e.data
14+
}
15+

tests/pos/i14776.scala

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
trait T1
2+
trait T2 extends T1
3+
4+
trait Expr[T] { val data: T = ??? }
5+
case class Tag2() extends Expr[T2]
6+
7+
def flag: Boolean = ???
8+
9+
def foo[T](e: Expr[T]): T1 = e match {
10+
case Tag2() =>
11+
if flag then
12+
new T2 {}
13+
else
14+
e.data
15+
}
16+

0 commit comments

Comments
 (0)