Skip to content

Take HKT injectivity into account when inferring constraints #6461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions compiler/src/dotty/tools/dotc/core/Constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ abstract class Constraint extends Showable {
/** The parameters that are known to be greater wrt <: than `param` */
def upper(param: TypeParamRef): List[TypeParamRef]

/** The lower dominator set.
*
* This is like `lower`, except that each parameter returned is no smaller than every other returned parameter.
*/
def minLower(param: TypeParamRef): List[TypeParamRef]

/** The upper dominator set.
*
* This is like `upper`, except that each parameter returned is no greater than every other returned parameter.
*/
def minUpper(param: TypeParamRef): List[TypeParamRef]

/** lower(param) \ lower(butNot) */
def exclusiveLower(param: TypeParamRef, butNot: TypeParamRef): List[TypeParamRef]

Expand All @@ -58,15 +70,6 @@ abstract class Constraint extends Showable {
*/
def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds

/** The lower bound of `param` including all known-to-be-smaller parameters */
def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type

/** The upper bound of `param` including all known-to-be-greater parameters */
def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type

/** The bounds of `param` including all known-to-be-smaller and -greater parameters */
def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds

/** A new constraint which is derived from this constraint by adding
* entries for all type parameters of `poly`.
* @param tvars A list of type variables associated with the params,
Expand Down
23 changes: 21 additions & 2 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package dotty.tools
package dotc
package core

import Types._, Contexts._, Symbols._
import Types._
import Contexts._
import Symbols._
import Decorators._
import config.Config
import config.Printers.{constr, typr}
import dotty.tools.dotc.reporting.trace

/** Methods for adding constraints and solving them.
*
Expand Down Expand Up @@ -66,6 +69,22 @@ trait ConstraintHandling[AbstractContext] {
case tp => tp
}

def nonParamBounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds = constraint.nonParamBounds(param)

def fullLowerBound(param: TypeParamRef)(implicit actx: AbstractContext): Type =
(nonParamBounds(param).lo /: constraint.minLower(param))(_ | _)

def fullUpperBound(param: TypeParamRef)(implicit actx: AbstractContext): Type =
(nonParamBounds(param).hi /: constraint.minUpper(param))(_ & _)

/** Full bounds of `param`, including other lower/upper params.
*
* Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds`
* of some param when comparing types might lead to infinite recursion. Consider `bounds` instead.
*/
def fullBounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds =
nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param))

protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(implicit actx: AbstractContext): Boolean =
!constraint.contains(param) || {
def occursIn(bound: Type): Boolean = {
Expand Down Expand Up @@ -262,7 +281,7 @@ trait ConstraintHandling[AbstractContext] {
}
constraint.entry(param) match {
case _: TypeBounds =>
val bound = if (fromBelow) constraint.fullLowerBound(param) else constraint.fullUpperBound(param)
val bound = if (fromBelow) fullLowerBound(param) else fullUpperBound(param)
val inst = avoidParam(bound)
typr_println(s"approx ${param.show}, from below = $fromBelow, bound = ${bound.show}, inst = ${inst.show}")
inst
Expand Down
239 changes: 5 additions & 234 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ object Contexts {
final def importInfo: ImportInfo = _importInfo

/** The current bounds in force for type parameters appearing in a GADT */
private[this] var _gadt: GADTMap = _
protected def gadt_=(gadt: GADTMap): Unit = _gadt = gadt
final def gadt: GADTMap = _gadt
private[this] var _gadt: GadtConstraint = _
protected def gadt_=(gadt: GadtConstraint): Unit = _gadt = gadt
final def gadt: GadtConstraint = _gadt

/** The history of implicit searches that are currently active */
private[this] var _searchHistory: SearchHistory = null
Expand Down Expand Up @@ -534,7 +534,7 @@ object Contexts {
def setTypeAssigner(typeAssigner: TypeAssigner): this.type = { this.typeAssigner = typeAssigner; this }
def setTyper(typer: Typer): this.type = { this.scope = typer.scope; setTypeAssigner(typer) }
def setImportInfo(importInfo: ImportInfo): this.type = { this.importInfo = importInfo; this }
def setGadt(gadt: GADTMap): this.type = { this.gadt = gadt; this }
def setGadt(gadt: GadtConstraint): this.type = { this.gadt = gadt; this }
def setFreshGADTBounds: this.type = setGadt(gadt.fresh)
def setSearchHistory(searchHistory: SearchHistory): this.type = { this.searchHistory = searchHistory; this }
def setSource(source: SourceFile): this.type = { this.source = source; this }
Expand Down Expand Up @@ -617,7 +617,7 @@ object Contexts {
store = initialStore.updated(settingsStateLoc, settingsGroup.defaultState)
typeComparer = new TypeComparer(this)
searchHistory = new SearchRoot
gadt = EmptyGADTMap
gadt = EmptyGadtConstraint
}

@sharable object NoContext extends Context(null) {
Expand Down Expand Up @@ -774,233 +774,4 @@ object Contexts {
if (thread == null) thread = Thread.currentThread()
else assert(thread == Thread.currentThread(), "illegal multithreaded access to ContextBase")
}

sealed abstract class GADTMap {
def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean
def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds
def contains(sym: Symbol)(implicit ctx: Context): Boolean
def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type
def debugBoundsDescription(implicit ctx: Context): String
def fresh: GADTMap
def restore(other: GADTMap): Unit
def isEmpty: Boolean
}

final class SmartGADTMap private (
private var myConstraint: Constraint,
private var mapping: SimpleIdentityMap[Symbol, TypeVar],
private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol],
private var boundCache: SimpleIdentityMap[Symbol, TypeBounds]
) extends GADTMap with ConstraintHandling[Context] {
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}

def this() = this(
myConstraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty),
mapping = SimpleIdentityMap.Empty,
reverseMapping = SimpleIdentityMap.Empty,
boundCache = SimpleIdentityMap.Empty
)

implicit override def ctx(implicit ctx: Context): Context = ctx

override protected def constraint = myConstraint
override protected def constraint_=(c: Constraint) = myConstraint = c

override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2)
override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2)

override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym)

override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = try {
boundCache = SimpleIdentityMap.Empty
boundAdditionInProgress = true
@annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match {
case tv: TypeVar =>
val inst = instType(tv)
if (inst.exists) stripInternalTypeVar(inst) else tv
case _ => tp
}

def externalizedSubtype(tp1: Type, tp2: Type, isSubtype: Boolean): Boolean = {
val externalizedTp1 = removeTypeVars(tp1)
val externalizedTp2 = removeTypeVars(tp2)

(
if (isSubtype) externalizedTp1 frozen_<:< externalizedTp2
else externalizedTp2 frozen_<:< externalizedTp1
).reporting({ res =>
val descr = i"$externalizedTp1 frozen_${if (isSubtype) "<:<" else ">:>"} $externalizedTp2"
i"$descr = $res"
}, gadts)
}

val symTvar: TypeVar = stripInternalTypeVar(tvar(sym)) match {
case tv: TypeVar => tv
case inst =>
val externalizedInst = removeTypeVars(inst)
gadts.println(i"instantiated: $sym -> $externalizedInst")
return if (isUpper) isSubType(externalizedInst , bound) else isSubType(bound, externalizedInst)
}

val internalizedBound = insertTypeVars(bound)
(
stripInternalTypeVar(internalizedBound) match {
case boundTvar: TypeVar =>
if (boundTvar eq symTvar) true
else if (isUpper) addLess(symTvar.origin, boundTvar.origin)
else addLess(boundTvar.origin, symTvar.origin)
case bound =>
if (externalizedSubtype(symTvar, bound, isSubtype = !isUpper)) {
gadts.println(i"manually unifying $symTvar with $bound")
constraint = constraint.updateEntry(symTvar.origin, bound)
true
}
else if (isUpper) addUpperBound(symTvar.origin, bound)
else addLowerBound(symTvar.origin, bound)
}
).reporting({ res =>
val descr = if (isUpper) "upper" else "lower"
val op = if (isUpper) "<:" else ">:"
i"adding $descr bound $sym $op $bound = $res\t( $symTvar $op $internalizedBound )"
}, gadts)
} finally boundAdditionInProgress = false

override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = {
mapping(sym) match {
case null => null
case tv =>
def retrieveBounds: TypeBounds = {
val tb = constraint.fullBounds(tv.origin)
removeTypeVars(tb).asInstanceOf[TypeBounds]
}
(
if (boundAdditionInProgress || ctx.mode.is(Mode.GADTflexible)) retrieveBounds
else boundCache(sym) match {
case tb: TypeBounds => tb
case null =>
val bounds = retrieveBounds
boundCache = boundCache.updated(sym, bounds)
bounds
}
).reporting({ res =>
// i"gadt bounds $sym: $res"
""
}, gadts)
}
}

override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null

override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = {
val res = removeTypeVars(approximation(tvar(sym).origin, fromBelow = fromBelow))
gadts.println(i"approximating $sym ~> $res")
res
}

override def fresh: GADTMap = new SmartGADTMap(
myConstraint,
mapping,
reverseMapping,
boundCache
)

def restore(other: GADTMap): Unit = other match {
case other: SmartGADTMap =>
this.myConstraint = other.myConstraint
this.mapping = other.mapping
this.reverseMapping = other.reverseMapping
this.boundCache = other.boundCache
case _ => ;
}

override def isEmpty: Boolean = mapping.size == 0

// ---- Private ----------------------------------------------------------

private[this] def tvar(sym: Symbol)(implicit ctx: Context): TypeVar = {
mapping(sym) match {
case tv: TypeVar =>
tv
case null =>
val res = {
import NameKinds.DepParamName
// avoid registering the TypeVar with TyperState / TyperState#constraint
// - we don't want TyperState instantiating these TypeVars
// - we don't want TypeComparer constraining these TypeVars
val poly = PolyType(DepParamName.fresh(sym.name.toTypeName) :: Nil)(
pt => (sym.info match {
case tb @ TypeBounds(_, hi) if hi.isLambdaSub => tb
case _ => TypeBounds.empty
}) :: Nil,
pt => defn.AnyType)
new TypeVar(poly.paramRefs.head, creatorState = null)
}
gadts.println(i"GADTMap: created tvar $sym -> $res")
constraint = constraint.add(res.origin.binder, res :: Nil)
mapping = mapping.updated(sym, res)
reverseMapping = reverseMapping.updated(res.origin, sym)
res
}
}

private def insertTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match {
case tp: TypeRef =>
val sym = tp.typeSymbol
if (contains(sym)) tvar(sym) else tp
case _ =>
(if (map != null) map else new TypeVarInsertingMap()).mapOver(tp)
}
private final class TypeVarInsertingMap(implicit ctx: Context) extends TypeMap {
override def apply(tp: Type): Type = insertTypeVars(tp, this)
}

private def removeTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match {
case tpr: TypeParamRef =>
reverseMapping(tpr) match {
case null => tpr
case sym => sym.typeRef
}
case tv: TypeVar =>
reverseMapping(tv.origin) match {
case null => tv
case sym => sym.typeRef
}
case _ =>
(if (map != null) map else new TypeVarRemovingMap()).mapOver(tp)
}
private final class TypeVarRemovingMap(implicit ctx: Context) extends TypeMap {
override def apply(tp: Type): Type = removeTypeVars(tp, this)
}

private[this] var boundAdditionInProgress = false

// ---- Debug ------------------------------------------------------------

override def constr_println(msg: => String): Unit = gadtsConstr.println(msg)

override def debugBoundsDescription(implicit ctx: Context): String = {
val sb = new mutable.StringBuilder
sb ++= constraint.show
sb += '\n'
mapping.foreachBinding { case (sym, _) =>
sb ++= i"$sym: ${bounds(sym)}\n"
}
sb.result
}
}

@sharable object EmptyGADTMap extends GADTMap {
override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGADTMap.addEmptyBounds")
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.addBound")
override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null
override def contains(sym: Symbol)(implicit ctx: Context) = false
override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGADTMap.approximation")
override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGADTMap"
override def fresh = new SmartGADTMap
override def restore(other: GADTMap): Unit = {
if (!other.isEmpty) sys.error("cannot restore a non-empty GADTMap")
}
override def isEmpty: Boolean = true
}
}
Loading