Skip to content

Commit b4e037c

Browse files
authored
Merge pull request #8232 from dotty-staging/fix-#8111
Fix #8111: Use better algorithms to infer parameter types
2 parents 5374d91 + 9f10d77 commit b4e037c

File tree

6 files changed

+78
-41
lines changed

6 files changed

+78
-41
lines changed

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import typer.Applications._
1919
import typer.ProtoTypes._
2020
import typer.ForceDegree
2121
import typer.Inferencing.isFullyDefined
22+
import typer.IfBottom
2223

2324
import scala.annotation.internal.sharable
2425

@@ -644,7 +645,7 @@ trait TypeOps { this: Context => // TODO: Make standalone object.
644645
tvar =>
645646
!(ctx.typerState.constraint.entry(tvar.origin) `eq` tvar.origin.underlying) ||
646647
(tvar `eq` removeThisType.prefixTVar),
647-
allowBottom = false
648+
IfBottom.flip
648649
)
649650

650651
// If parent contains a reference to an abstract type, then we should

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1865,7 +1865,7 @@ trait Applications extends Compatibility {
18651865
if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, WildcardType)
18661866
else defn.FunctionOf(commonParamTypes, WildcardType)
18671867
overload.println(i"pretype arg $arg with expected type $commonFormal")
1868-
if (commonParamTypes.forall(isFullyDefined(_, ForceDegree.noBottom)))
1868+
if (commonParamTypes.forall(isFullyDefined(_, ForceDegree.flipBottom)))
18691869
pt.typedArg(arg, commonFormal)(ctx.addMode(Mode.ImplicitsEnabled))
18701870
}
18711871
case None =>

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ object Inferencing {
5555
def instantiateSelected(tp: Type, tvars: List[Type])(implicit ctx: Context): Unit =
5656
if (tvars.nonEmpty)
5757
IsFullyDefinedAccumulator(
58-
ForceDegree.Value(tvars.contains, allowBottom = false), minimizeSelected = true
58+
ForceDegree.Value(tvars.contains, IfBottom.flip), minimizeSelected = true
5959
).process(tp)
6060

6161
/** Instantiate any type variables in `tp` whose bounds contain a reference to
@@ -98,7 +98,7 @@ object Inferencing {
9898
9999
* If (1) and (2) do not apply, and minimizeSelected is not set:
100100
* 6: T is maximized if it appears only contravariantly in the given type,
101-
* or if forceDegree is `noBottom` and T has no lower bound different from Nothing.
101+
* or if forceDegree is `flipBottom` and T has no lower bound different from Nothing.
102102
* 7. Otherwise, T is minimized.
103103
*
104104
* The instantiation for (6) and (7) is done in two phases:
@@ -132,8 +132,10 @@ object Inferencing {
132132
if tvar.hasLowerBound then instantiate(tvar, fromBelow = true)
133133
else if tvar.hasUpperBound then instantiate(tvar, fromBelow = false)
134134
else () // hold off instantiating unbounded unconstrained variables
135-
else if variance >= 0 && (force.allowBottom || tvar.hasLowerBound) then
135+
else if variance >= 0 && (force.ifBottom == IfBottom.ok || tvar.hasLowerBound) then
136136
instantiate(tvar, fromBelow = true)
137+
else if variance >= 0 && force.ifBottom == IfBottom.fail then
138+
return false
137139
else
138140
toMaximize = tvar :: toMaximize
139141
foldOver(x, tvar)
@@ -150,9 +152,14 @@ object Inferencing {
150152
if !tvar.isInstantiated then
151153
instantiate(tvar, fromBelow = false)
152154
case nil =>
153-
val res = apply(true, tp)
154-
if res then maximize(toMaximize)
155-
res
155+
apply(true, tp)
156+
&& (
157+
toMaximize.isEmpty
158+
|| { maximize(toMaximize)
159+
toMaximize = Nil // Do another round since the maximixing instances
160+
process(tp) // might have type uninstantiated variables themselves.
161+
}
162+
)
156163
}
157164

158165
/** For all type parameters occurring in `tp`:
@@ -509,9 +516,13 @@ trait Inferencing { this: Typer =>
509516

510517
/** An enumeration controlling the degree of forcing in "is-dully-defined" checks. */
511518
@sharable object ForceDegree {
512-
class Value(val appliesTo: TypeVar => Boolean, val allowBottom: Boolean)
513-
val none: Value = new Value(_ => false, allowBottom = true)
514-
val all: Value = new Value(_ => true, allowBottom = true)
515-
val noBottom: Value = new Value(_ => true, allowBottom = false)
519+
class Value(val appliesTo: TypeVar => Boolean, val ifBottom: IfBottom)
520+
val none: Value = new Value(_ => false, IfBottom.ok)
521+
val all: Value = new Value(_ => true, IfBottom.ok)
522+
val failBottom: Value = new Value(_ => true, IfBottom.fail)
523+
val flipBottom: Value = new Value(_ => true, IfBottom.flip)
516524
}
517525

526+
enum IfBottom:
527+
case ok, fail, flip
528+

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
276276
// Make sure all type arguments to the call are fully determined,
277277
// but continue if that's not achievable (or else i7459.scala would crash).
278278
for arg <- callTypeArgs do
279-
isFullyDefined(arg.tpe, ForceDegree.noBottom)
279+
isFullyDefined(arg.tpe, ForceDegree.flipBottom)
280280

281281
/** A map from parameter names of the inlineable method to references of the actual arguments.
282282
* For a type argument this is the full argument type.

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ class Typer extends Namer
612612
var templ1 = templ
613613
def isEligible(tp: Type) = tp.exists && !tp.typeSymbol.is(Final) && !tp.isRef(defn.AnyClass)
614614
if (templ1.parents.isEmpty &&
615-
isFullyDefined(pt, ForceDegree.noBottom) &&
615+
isFullyDefined(pt, ForceDegree.flipBottom) &&
616616
isSkolemFree(pt) &&
617617
isEligible(pt.underlyingClassRef(refinementOK = false)))
618618
templ1 = cpy.Template(templ)(parents = untpd.TypeTree(pt) :: Nil)
@@ -1009,16 +1009,20 @@ class Typer extends Namer
10091009
yield param.name -> idx
10101010
}.toMap
10111011
if (paramIndex.size == params.length)
1012-
expr match {
1012+
expr match
10131013
case untpd.TypedSplice(expr1) =>
10141014
expr1.tpe
10151015
case _ =>
1016+
given nestedCtx as Context = ctx.fresh.setNewTyperState()
10161017
val protoArgs = args map (_ withType WildcardType)
10171018
val callProto = FunProto(protoArgs, WildcardType)(this, app.isGivenApply)
10181019
val expr1 = typedExpr(expr, callProto)
1019-
fnBody = cpy.Apply(fnBody)(untpd.TypedSplice(expr1), args)
1020-
expr1.tpe
1021-
}
1020+
if nestedCtx.reporter.hasErrors then NoType
1021+
else
1022+
given Context = ctx
1023+
nestedCtx.typerState.commit()
1024+
fnBody = cpy.Apply(fnBody)(untpd.TypedSplice(expr1), args)
1025+
expr1.tpe
10221026
else NoType
10231027
case _ =>
10241028
NoType
@@ -1030,42 +1034,53 @@ class Typer extends Namer
10301034
// try to instantiate `pt` if this is possible. If it does not
10311035
// work the error will be reported later in `inferredParam`,
10321036
// when we try to infer the parameter type.
1033-
isFullyDefined(pt, ForceDegree.noBottom)
1037+
isFullyDefined(pt, ForceDegree.flipBottom)
10341038
case _ =>
10351039
}
10361040

10371041
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length)
10381042

1039-
/** Two attempts: First, if expected type is fully defined pick this one.
1040-
* Second, if function is of the form
1041-
* (x1, ..., xN) => f(... x1, ..., XN, ...)
1042-
* where each `xi` occurs exactly once in the argument list of `f` (in
1043-
* any order), and f has a method type MT, pick the corresponding parameter
1044-
* type in MT, if this one is fully defined.
1045-
* If both attempts fail, issue a "missing parameter type" error.
1046-
*/
1047-
def inferredParamType(param: untpd.ValDef, formal: Type): Type = {
1048-
if (isFullyDefined(formal, ForceDegree.noBottom)) return formal
1049-
calleeType.widen match {
1043+
/** The inferred parameter type for a parameter in a lambda that does
1044+
* not have an explicit type given.
1045+
* An inferred parameter type I has two possible sources:
1046+
* - the type S known from the context
1047+
* - the "target type" T known from the callee `f` if the lambda is of a form like `x => f(x)`
1048+
* If `T` exists, we know that `S <: I <: T`.
1049+
*
1050+
* The inference makes three attempts:
1051+
*
1052+
* 1. If the expected type `S` is already fully defined under ForceDegree.failBottom
1053+
* pick this one.
1054+
* 2. Compute the target type `T` and make it known that `S <: T`.
1055+
* If the expected type `S` can be fully defined under ForceDegree.flipBottom,
1056+
* pick this one (this might use the fact that S <: T for an upper approximation).
1057+
* 3. Otherwise, if the target type `T` can be fully defined under ForceDegree.flipBottom,
1058+
* pick this one.
1059+
*
1060+
* If all attempts fail, issue a "missing parameter type" error.
1061+
*/
1062+
def inferredParamType(param: untpd.ValDef, formal: Type): Type =
1063+
if isFullyDefined(formal, ForceDegree.failBottom) then return formal
1064+
val target = calleeType.widen match
10501065
case mtpe: MethodType =>
10511066
val pos = paramIndex(param.name)
1052-
if (pos < mtpe.paramInfos.length) {
1067+
if pos < mtpe.paramInfos.length then
10531068
val ptype = mtpe.paramInfos(pos)
1054-
if (isFullyDefined(ptype, ForceDegree.noBottom) && !ptype.isRepeatedParam)
1055-
return ptype
1056-
}
1057-
case _ =>
1058-
}
1059-
errorType(AnonymousFunctionMissingParamType(param, params, tree, formal), param.sourcePos)
1060-
}
1069+
if ptype.isRepeatedParam then NoType else ptype
1070+
else NoType
1071+
case _ => NoType
1072+
if target.exists then formal <:< target
1073+
if isFullyDefined(formal, ForceDegree.flipBottom) then formal
1074+
else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target
1075+
else errorType(AnonymousFunctionMissingParamType(param, params, tree, formal), param.sourcePos)
10611076

10621077
def protoFormal(i: Int): Type =
10631078
if (protoFormals.length == params.length) protoFormals(i)
10641079
else errorType(WrongNumberOfParameters(protoFormals.length), tree.sourcePos)
10651080

10661081
/** Is `formal` a product type which is elementwise compatible with `params`? */
10671082
def ptIsCorrectProduct(formal: Type) =
1068-
isFullyDefined(formal, ForceDegree.noBottom) &&
1083+
isFullyDefined(formal, ForceDegree.flipBottom) &&
10691084
(defn.isProductSubType(formal) || formal.derivesFrom(defn.PairClass)) &&
10701085
productSelectorTypes(formal, tree.sourcePos).corresponds(params) {
10711086
(argType, param) =>
@@ -1379,7 +1394,7 @@ class Typer extends Namer
13791394
}
13801395
case _ =>
13811396
tree.withType(
1382-
if (isFullyDefined(pt, ForceDegree.noBottom)) pt
1397+
if (isFullyDefined(pt, ForceDegree.flipBottom)) pt
13831398
else if (ctx.reporter.errorsReported) UnspecifiedErrorType
13841399
else errorType(i"cannot infer type; expected type $pt is not fully defined", tree.sourcePos))
13851400
}
@@ -3054,7 +3069,7 @@ class Typer extends Namer
30543069
pt match {
30553070
case SAMType(sam)
30563071
if wtp <:< sam.toFunctionType() =>
3057-
// was ... && isFullyDefined(pt, ForceDegree.noBottom)
3072+
// was ... && isFullyDefined(pt, ForceDegree.flipBottom)
30583073
// but this prevents case blocks from implementing polymorphic partial functions,
30593074
// since we do not know the result parameter a priori. Have to wait until the
30603075
// body is typechecked.

tests/pos/i8111.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object Example extends App {
2+
3+
def assertLazy[A, B](f: (A) => B): Boolean = ???
4+
5+
def fromEither[E, F](eea: Either[E, F]): Unit = ???
6+
7+
lazy val result = assertLazy(fromEither)
8+
9+
println("It compiles!")
10+
}

0 commit comments

Comments
 (0)