Skip to content

Commit b4338a8

Browse files
authored
Merge pull request #8867 from dotty-staging/fix-#8861
Fix #8861: Avoid parameters when instantiating closure results
2 parents 1896c2b + e97b278 commit b4338a8

File tree

8 files changed

+145
-30
lines changed

8 files changed

+145
-30
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,10 @@ trait ConstraintHandling[AbstractContext] {
296296

297297
/** Widen inferred type `inst` with upper `bound`, according to the following rules:
298298
* 1. If `inst` is a singleton type, or a union containing some singleton types,
299-
* widen (all) the singleton type(s), provied the result is a subtype of `bound`
299+
* widen (all) the singleton type(s), provided the result is a subtype of `bound`
300300
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
301301
* 2. If `inst` is a union type, approximate the union type from above by an intersection
302-
* of all common base types, provied the result is a subtype of `bound`.
302+
* of all common base types, provided the result is a subtype of `bound`.
303303
*
304304
* Don't do these widenings if `bound` is a subtype of `scala.Singleton`.
305305
* Also, if the result of these widenings is a TypeRef to a module class,
@@ -312,15 +312,17 @@ trait ConstraintHandling[AbstractContext] {
312312
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type = {
313313
def widenOr(tp: Type) = {
314314
val tpw = tp.widenUnion
315-
if ((tpw ne tp) && tpw <:< bound) tpw else tp
315+
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
316316
}
317317
def widenSingle(tp: Type) = {
318318
val tpw = tp.widenSingletons
319-
if ((tpw ne tp) && tpw <:< bound) tpw else tp
319+
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
320320
}
321+
def isSingleton(tp: Type): Boolean = tp match
322+
case WildcardType(optBounds) => optBounds.exists && isSingleton(optBounds.bounds.hi)
323+
case _ => isSubTypeWhenFrozen(tp, defn.SingletonType)
321324
val wideInst =
322-
if (isSubTypeWhenFrozen(bound, defn.SingletonType)) inst
323-
else widenOr(widenSingle(inst))
325+
if isSingleton(bound) then inst else widenOr(widenSingle(inst))
324326
wideInst match
325327
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
326328
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ final class ProperGadtConstraint private(
116116
)
117117

118118
val tvars = params.lazyZip(poly1.paramRefs).map { (sym, paramRef) =>
119-
val tv = new TypeVar(paramRef, creatorState = null)
119+
val tv = TypeVar(paramRef, creatorState = null)
120120
mapping = mapping.updated(sym, tv)
121121
reverseMapping = reverseMapping.updated(tv.origin, sym)
122122
tv

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,6 +1459,14 @@ object SymDenotations {
14591459
else if is(Contravariant) then Contravariant
14601460
else EmptyFlags
14611461

1462+
/** The length of the owner chain of this symbol. 1 for _root_, 0 for NoSymbol */
1463+
def nestingLevel(using Context): Int =
1464+
@tailrec def recur(d: SymDenotation, n: Int): Int = d match
1465+
case NoDenotation => n
1466+
case d: ClassDenotation => d.nestingLevel + n // profit from the cache in ClassDenotation
1467+
case _ => recur(d.owner, n + 1)
1468+
recur(this, 0)
1469+
14621470
/** The flags to be used for a type parameter owned by this symbol.
14631471
* Overridden by ClassDenotation.
14641472
*/
@@ -2160,6 +2168,12 @@ object SymDenotations {
21602168

21612169
override def registeredCompanion(implicit ctx: Context) = { ensureCompleted(); myCompanion }
21622170
override def registeredCompanion_=(c: Symbol) = { myCompanion = c }
2171+
2172+
private var myNestingLevel = -1
2173+
2174+
override def nestingLevel(using Context) =
2175+
if myNestingLevel == -1 then myNestingLevel = owner.nestingLevel + 1
2176+
myNestingLevel
21632177
}
21642178

21652179
/** The denotation of a package class.

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4110,18 +4110,17 @@ object Types {
41104110
*
41114111
* @param origin The parameter that's tracked by the type variable.
41124112
* @param creatorState The typer state in which the variable was created.
4113-
*
4114-
* `owningTree` and `owner` are used to determine whether a type-variable can be instantiated
4115-
* at some given point. See `Inferencing#interpolateUndetVars`.
41164113
*/
4117-
final class TypeVar(private var _origin: TypeParamRef, creatorState: TyperState) extends CachedProxyType with ValueType {
4114+
final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState, nestingLevel: Int) extends CachedProxyType with ValueType {
4115+
4116+
private var currentOrigin = initOrigin
41184117

4119-
def origin: TypeParamRef = _origin
4118+
def origin: TypeParamRef = currentOrigin
41204119

41214120
/** Set origin to new parameter. Called if we merge two conflicting constraints.
41224121
* See OrderingConstraint#merge, OrderingConstraint#rename
41234122
*/
4124-
def setOrigin(p: TypeParamRef) = _origin = p
4123+
def setOrigin(p: TypeParamRef) = currentOrigin = p
41254124

41264125
/** The permanent instance type of the variable, or NoType is none is given yet */
41274126
private var myInst: Type = NoType
@@ -4150,6 +4149,36 @@ object Types {
41504149
/** Is the variable already instantiated? */
41514150
def isInstantiated(implicit ctx: Context): Boolean = instanceOpt.exists
41524151

4152+
/** Avoid term references in `tp` to parameters or local variables that
4153+
* are nested more deeply than the type variable itself.
4154+
*/
4155+
private def avoidCaptures(tp: Type)(using Context): Type =
4156+
val problemSyms = new TypeAccumulator[Set[Symbol]]:
4157+
def apply(syms: Set[Symbol], t: Type): Set[Symbol] = t match
4158+
case ref @ TermRef(NoPrefix, _)
4159+
// AVOIDANCE TODO: Are there other problematic kinds of references?
4160+
// Our current tests only give us these, but we might need to generalize this.
4161+
if ref.symbol.maybeOwner.nestingLevel > nestingLevel =>
4162+
syms + ref.symbol
4163+
case _ =>
4164+
foldOver(syms, t)
4165+
val problems = problemSyms(Set.empty, tp)
4166+
if problems.isEmpty then tp
4167+
else
4168+
val atp = ctx.typer.avoid(tp, problems.toList)
4169+
def msg = i"Inaccessible variables captured in instantation of type variable $this.\n$tp was fixed to $atp"
4170+
typr.println(msg)
4171+
val bound = ctx.typeComparer.fullUpperBound(origin)
4172+
if !(atp <:< bound) then
4173+
throw new TypeError(s"$msg,\nbut the latter type does not conform to the upper bound $bound")
4174+
atp
4175+
// AVOIDANCE TODO: This really works well only if variables are instantiated from below
4176+
// If we hit a problematic symbol while instantiating from above, then avoidance
4177+
// will widen the instance type further. This could yield an alias, which would be OK.
4178+
// But it also could yield a true super type which would then fail the bounds check
4179+
// and throw a TypeError. The right thing to do instead would be to avoid "downwards".
4180+
// To do this, we need first test cases for that situation.
4181+
41534182
/** Instantiate variable with given type */
41544183
def instantiateWith(tp: Type)(implicit ctx: Context): Type = {
41554184
assert(tp ne this, s"self instantiation of ${tp.show}, constraint = ${ctx.typerState.constraint.show}")
@@ -4168,7 +4197,7 @@ object Types {
41684197
* is also a singleton type.
41694198
*/
41704199
def instantiate(fromBelow: Boolean)(implicit ctx: Context): Type =
4171-
instantiateWith(ctx.typeComparer.instanceType(origin, fromBelow))
4200+
instantiateWith(avoidCaptures(ctx.typeComparer.instanceType(origin, fromBelow)))
41724201

41734202
/** For uninstantiated type variables: Is the lower bound different from Nothing? */
41744203
def hasLowerBound(implicit ctx: Context): Boolean =
@@ -4200,6 +4229,9 @@ object Types {
42004229
s"TypeVar($origin$instStr)"
42014230
}
42024231
}
4232+
object TypeVar:
4233+
def apply(initOrigin: TypeParamRef, creatorState: TyperState)(using Context) =
4234+
new TypeVar(initOrigin, creatorState, ctx.owner.nestingLevel)
42034235

42044236
type TypeVars = SimpleIdentitySet[TypeVar]
42054237

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,11 @@ class Namer { typer: Typer =>
196196

197197
import untpd._
198198

199-
val TypedAhead: Property.Key[tpd.Tree] = new Property.Key
200-
val ExpandedTree: Property.Key[untpd.Tree] = new Property.Key
199+
val TypedAhead : Property.Key[tpd.Tree] = new Property.Key
200+
val ExpandedTree : Property.Key[untpd.Tree] = new Property.Key
201201
val ExportForwarders: Property.Key[List[tpd.MemberDef]] = new Property.Key
202-
val SymOfTree: Property.Key[Symbol] = new Property.Key
203-
val Deriver: Property.Key[typer.Deriver] = new Property.Key
202+
val SymOfTree : Property.Key[Symbol] = new Property.Key
203+
val Deriver : Property.Key[typer.Deriver] = new Property.Key
204204

205205
/** A partial map from unexpanded member and pattern defs and to their expansions.
206206
* Populated during enterSyms, emptied during typer.
@@ -1440,13 +1440,10 @@ class Namer { typer: Typer =>
14401440
// instead of widening to the underlying module class types.
14411441
// We also drop the @Repeated annotation here to avoid leaking it in method result types
14421442
// (see run/inferred-repeated-result).
1443-
def widenRhs(tp: Type): Type = {
1444-
val tp1 = tp.widenTermRefExpr.simplified match
1443+
def widenRhs(tp: Type): Type =
1444+
tp.widenTermRefExpr.simplified match
14451445
case ctp: ConstantType if isInlineVal => ctp
1446-
case ref: TypeRef if ref.symbol.is(ModuleClass) => tp
1447-
case tp => tp.widenUnion
1448-
tp1.dropRepeatedAnnot
1449-
}
1446+
case tp => ctx.typeComparer.widenInferred(tp, rhsProto)
14501447

14511448
// Replace aliases to Unit by Unit itself. If we leave the alias in
14521449
// it would be erased to BoxedUnit.
@@ -1498,9 +1495,21 @@ class Namer { typer: Typer =>
14981495
if (isFullyDefined(tpe, ForceDegree.none)) tpe
14991496
else typedAheadExpr(mdef.rhs, tpe).tpe
15001497
case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) =>
1501-
val rhsType = typedAheadExpr(mdef.rhs, tpt.tpe).tpe
15021498
mdef match {
15031499
case mdef: DefDef if mdef.name == nme.ANON_FUN =>
1500+
// This case applies if the closure result type contains uninstantiated
1501+
// type variables. In this case, constrain the closure result from below
1502+
// by the parameter-capture-avoiding type of the body.
1503+
val rhsType = typedAheadExpr(mdef.rhs, tpt.tpe).tpe
1504+
1505+
// The following part is important since otherwise we might instantiate
1506+
// the closure result type with a plain functon type that refers
1507+
// to local parameters. An example where this happens in `dependent-closures.scala`
1508+
// If the code after `val rhsType` is commented out, this file fails pickling tests.
1509+
// AVOIDANCE TODO: Follow up why this happens, and whether there
1510+
// are better ways to achieve this. It would be good if we could get rid of this code.
1511+
// It seems at least partially redundant with the nesting level checking on TypeVar
1512+
// instantiation.
15041513
val hygienicType = avoid(rhsType, paramss.flatten)
15051514
if (!hygienicType.isValueType || !(hygienicType <:< tpt.tpe))
15061515
ctx.error(i"return type ${tpt.tpe} of lambda cannot be made hygienic;\n" +
@@ -1513,10 +1522,10 @@ class Namer { typer: Typer =>
15131522
case _ =>
15141523
WildcardType
15151524
}
1516-
val memTpe = paramFn(checkSimpleKinded(typedAheadType(mdef.tpt, tptProto)).tpe)
1525+
val mbrTpe = paramFn(checkSimpleKinded(typedAheadType(mdef.tpt, tptProto)).tpe)
15171526
if (ctx.explicitNulls && mdef.mods.is(JavaDefined))
1518-
JavaNullInterop.nullifyMember(sym, memTpe, mdef.mods.isAllOf(JavaEnumValue))
1519-
else memTpe
1527+
JavaNullInterop.nullifyMember(sym, mbrTpe, mdef.mods.isAllOf(JavaEnumValue))
1528+
else mbrTpe
15201529
}
15211530

15221531
/** The type signature of a DefDef with given symbol */

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -501,8 +501,8 @@ object ProtoTypes {
501501
def newTypeVars(tl: TypeLambda): List[TypeTree] =
502502
for (paramRef <- tl.paramRefs)
503503
yield {
504-
val tt = new TypeVarBinder().withSpan(owningTree.span)
505-
val tvar = new TypeVar(paramRef, state)
504+
val tt = TypeVarBinder().withSpan(owningTree.span)
505+
val tvar = TypeVar(paramRef, state)
506506
state.ownedVars += tvar
507507
tt.withType(tvar)
508508
}

tests/neg/i8861.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
object Test {
2+
sealed trait Container { s =>
3+
type A
4+
def visit[R](int: IntV & s.type => R, str: StrV & s.type => R): R
5+
}
6+
final class IntV extends Container { s =>
7+
type A = Int
8+
val i: Int = 42
9+
def visit[R](int: IntV & s.type => R, str: StrV & s.type => R): R = int(this)
10+
}
11+
final class StrV extends Container { s =>
12+
type A = String
13+
val t: String = "hello"
14+
def visit[R](int: IntV & s.type => R, str: StrV & s.type => R): R = str(this)
15+
}
16+
17+
def minimalOk[R](c: Container { type A = R }): R = c.visit[R](
18+
int = vi => vi.i : vi.A,
19+
str = vs => vs.t : vs.A
20+
)
21+
def minimalFail[M](c: Container { type A = M }): M = c.visit(
22+
int = vi => vi.i : vi.A,
23+
str = vs => vs.t : vs.A // error
24+
)
25+
26+
def main(args: Array[String]): Unit = {
27+
val e: Container { type A = String } = new StrV
28+
println(minimalOk(e)) // this one prints "hello"
29+
println(minimalFail(e)) // this one fails with ClassCastException: class java.lang.String cannot be cast to class java.lang.Integer
30+
}
31+
}

tests/pos/dependent-closures.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
trait S { type N; def n: N }
2+
3+
def newS[X](n: X): S { type N = X } = ???
4+
5+
def test =
6+
val ss: List[S] = ???
7+
val cl1 = (s: S) => newS(s.n)
8+
val cl2: (s: S) => S { type N = s.N } = cl1
9+
def f[R](cl: (s: S) => R) = cl
10+
val x = f(s => newS(s.n))
11+
val x1: (s: S) => S = x
12+
// If the code in `tptProto` of Namer that refers to this
13+
// file is commented out, we see:
14+
// pickling difference for the result type of the closure argument
15+
// before pickling: S => S { type N = s.N }
16+
// after pickling : (s: S) => S { type N = s.N }
17+
18+
ss.map(s => newS(s.n))
19+
// If the code in `tptProto` of Namer that refers to this
20+
// file is commented out, we see a pickling difference like the one above.
21+
22+
def g[R](cl: (s: S) => (S { type N = s.N }, R)) = ???
23+
g(s => (newS(s.n), identity(1)))
24+
25+
def h(cl: (s: S) => S { type N = s.N }) = ???
26+
h(s => newS(s.n))
27+

0 commit comments

Comments
 (0)