Skip to content

Commit a83a4e0

Browse files
authored
[bp] Add GADT symbols when typing typing-ahead lambda bodies (#19771)
Backport of #19644
2 parents 59085f1 + e6359f5 commit a83a4e0

File tree

4 files changed

+75
-9
lines changed

4 files changed

+75
-9
lines changed

compiler/src/dotty/tools/dotc/typer/Namer.scala

+14-9
Original file line numberDiff line numberDiff line change
@@ -1734,8 +1734,9 @@ class Namer { typer: Typer =>
17341734
val tpe = (paramss: @unchecked) match
17351735
case TypeSymbols(tparams) :: TermSymbols(vparams) :: Nil => tpFun(tparams, vparams)
17361736
case TermSymbols(vparams) :: Nil => tpFun(Nil, vparams)
1737+
val rhsCtx = prepareRhsCtx(ctx.fresh, paramss)
17371738
if (isFullyDefined(tpe, ForceDegree.none)) tpe
1738-
else typedAheadExpr(mdef.rhs, tpe).tpe
1739+
else typedAheadExpr(mdef.rhs, tpe)(using rhsCtx).tpe
17391740

17401741
case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) =>
17411742
mdef match {
@@ -1933,14 +1934,7 @@ class Namer { typer: Typer =>
19331934
var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType)
19341935
if sym.isInlineMethod then rhsCtx = rhsCtx.addMode(Mode.InlineableBody)
19351936
if sym.is(ExtensionMethod) then rhsCtx = rhsCtx.addMode(Mode.InExtensionMethod)
1936-
val typeParams = paramss.collect { case TypeSymbols(tparams) => tparams }.flatten
1937-
if (typeParams.nonEmpty) {
1938-
// we'll be typing an expression from a polymorphic definition's body,
1939-
// so we must allow constraining its type parameters
1940-
// compare with typedDefDef, see tests/pos/gadt-inference.scala
1941-
rhsCtx.setFreshGADTBounds
1942-
rhsCtx.gadtState.addToConstraint(typeParams)
1943-
}
1937+
rhsCtx = prepareRhsCtx(rhsCtx, paramss)
19441938

19451939
def typedAheadRhs(pt: Type) =
19461940
PrepareInlineable.dropInlineIfError(sym,
@@ -1985,4 +1979,15 @@ class Namer { typer: Typer =>
19851979
lhsType orElse WildcardType
19861980
}
19871981
end inferredResultType
1982+
1983+
/** Prepare a GADT-aware context used to type the RHS of a ValOrDefDef. */
1984+
def prepareRhsCtx(rhsCtx: FreshContext, paramss: List[List[Symbol]])(using Context): FreshContext =
1985+
val typeParams = paramss.collect { case TypeSymbols(tparams) => tparams }.flatten
1986+
if typeParams.nonEmpty then
1987+
// we'll be typing an expression from a polymorphic definition's body,
1988+
// so we must allow constraining its type parameters
1989+
// compare with typedDefDef, see tests/pos/gadt-inference.scala
1990+
rhsCtx.setFreshGADTBounds
1991+
rhsCtx.gadtState.addToConstraint(typeParams)
1992+
rhsCtx
19881993
}

tests/pos/i19570.min1.scala

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
enum Op[A]:
2+
case Dup[T]() extends Op[(T, T)]
3+
4+
def foo[R](f: [A] => Op[A] => R): R = ???
5+
6+
def test =
7+
foo([A] => (o: Op[A]) => o match
8+
case o: Op.Dup[u] =>
9+
summon[A =:= (u, u)] // Error: Cannot prove that A =:= (u, u)
10+
()
11+
)
12+
foo[Unit]([A] => (o: Op[A]) => o match
13+
case o: Op.Dup[u] =>
14+
summon[A =:= (u, u)] // Ok
15+
()
16+
)
17+
foo({
18+
val f1 = [B] => (o: Op[B]) => o match
19+
case o: Op.Dup[u] =>
20+
summon[B =:= (u, u)] // Also ok
21+
()
22+
f1
23+
})

tests/pos/i19570.min2.scala

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
sealed trait Op[A, B] { def giveA: A; def giveB: B }
2+
final case class Dup[T](x: T) extends Op[T, (T, T)] { def giveA: T = x; def giveB: (T, T) = (x, x) }
3+
4+
class Test:
5+
def foo[R](f: [A, B] => (o: Op[A, B]) => R): R = ???
6+
7+
def m1: Unit =
8+
foo([A, B] => (o: Op[A, B]) => o match
9+
case o: Dup[t] =>
10+
var a1: t = o.giveA
11+
var a2: A = o.giveA
12+
a1 = a2
13+
a2 = a1
14+
15+
var b1: (t, t) = o.giveB
16+
var b2: B = o.giveB
17+
b1 = b2
18+
b2 = b1
19+
20+
summon[A =:= t] // ERROR: Cannot prove that A =:= t.
21+
summon[B =:= (t, t)] // ERROR: Cannot prove that B =:= (t, t).
22+
23+
()
24+
)

tests/pos/i19570.orig.scala

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
enum Op[A, B]:
2+
case Dup[T]() extends Op[T, (T, T)]
3+
4+
def foo[R](f: [A, B] => (o: Op[A, B]) => R): R =
5+
f(Op.Dup())
6+
7+
def test =
8+
foo([A, B] => (o: Op[A, B]) => {
9+
o match
10+
case o: Op.Dup[t] =>
11+
summon[A =:= t] // ERROR: Cannot prove that A =:= t.
12+
summon[B =:= (t, t)] // ERROR: Cannot prove that B =:= (t, t).
13+
42
14+
})

0 commit comments

Comments
 (0)