diff --git a/compiler/src/dotty/tools/dotc/core/Constraint.scala b/compiler/src/dotty/tools/dotc/core/Constraint.scala index 91bedf35948b..b40b806c85bb 100644 --- a/compiler/src/dotty/tools/dotc/core/Constraint.scala +++ b/compiler/src/dotty/tools/dotc/core/Constraint.scala @@ -45,6 +45,18 @@ abstract class Constraint extends Showable { /** The parameters that are known to be greater wrt <: than `param` */ def upper(param: TypeParamRef): List[TypeParamRef] + /** The lower dominator set. + * + * This is like `lower`, except that each parameter returned is no smaller than every other returned parameter. + */ + def minLower(param: TypeParamRef): List[TypeParamRef] + + /** The upper dominator set. + * + * This is like `upper`, except that each parameter returned is no greater than every other returned parameter. + */ + def minUpper(param: TypeParamRef): List[TypeParamRef] + /** lower(param) \ lower(butNot) */ def exclusiveLower(param: TypeParamRef, butNot: TypeParamRef): List[TypeParamRef] @@ -58,15 +70,6 @@ abstract class Constraint extends Showable { */ def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds - /** The lower bound of `param` including all known-to-be-smaller parameters */ - def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type - - /** The upper bound of `param` including all known-to-be-greater parameters */ - def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type - - /** The bounds of `param` including all known-to-be-smaller and -greater parameters */ - def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds - /** A new constraint which is derived from this constraint by adding * entries for all type parameters of `poly`. * @param tvars A list of type variables associated with the params, diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index 0560866a3e6e..4afd55efdefb 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -2,10 +2,13 @@ package dotty.tools package dotc package core -import Types._, Contexts._, Symbols._ +import Types._ +import Contexts._ +import Symbols._ import Decorators._ import config.Config import config.Printers.{constr, typr} +import dotty.tools.dotc.reporting.trace /** Methods for adding constraints and solving them. * @@ -66,6 +69,22 @@ trait ConstraintHandling[AbstractContext] { case tp => tp } + def nonParamBounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds = constraint.nonParamBounds(param) + + def fullLowerBound(param: TypeParamRef)(implicit actx: AbstractContext): Type = + (nonParamBounds(param).lo /: constraint.minLower(param))(_ | _) + + def fullUpperBound(param: TypeParamRef)(implicit actx: AbstractContext): Type = + (nonParamBounds(param).hi /: constraint.minUpper(param))(_ & _) + + /** Full bounds of `param`, including other lower/upper params. + * + * Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds` + * of some param when comparing types might lead to infinite recursion. Consider `bounds` instead. + */ + def fullBounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds = + nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param)) + protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(implicit actx: AbstractContext): Boolean = !constraint.contains(param) || { def occursIn(bound: Type): Boolean = { @@ -262,7 +281,7 @@ trait ConstraintHandling[AbstractContext] { } constraint.entry(param) match { case _: TypeBounds => - val bound = if (fromBelow) constraint.fullLowerBound(param) else constraint.fullUpperBound(param) + val bound = if (fromBelow) fullLowerBound(param) else fullUpperBound(param) val inst = avoidParam(bound) typr_println(s"approx ${param.show}, from below = $fromBelow, bound = ${bound.show}, inst = ${inst.show}") inst diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 6e91741ae444..cbc3a12f0f8b 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -139,9 +139,9 @@ object Contexts { final def importInfo: ImportInfo = _importInfo /** The current bounds in force for type parameters appearing in a GADT */ - private[this] var _gadt: GADTMap = _ - protected def gadt_=(gadt: GADTMap): Unit = _gadt = gadt - final def gadt: GADTMap = _gadt + private[this] var _gadt: GadtConstraint = _ + protected def gadt_=(gadt: GadtConstraint): Unit = _gadt = gadt + final def gadt: GadtConstraint = _gadt /** The history of implicit searches that are currently active */ private[this] var _searchHistory: SearchHistory = null @@ -534,7 +534,7 @@ object Contexts { def setTypeAssigner(typeAssigner: TypeAssigner): this.type = { this.typeAssigner = typeAssigner; this } def setTyper(typer: Typer): this.type = { this.scope = typer.scope; setTypeAssigner(typer) } def setImportInfo(importInfo: ImportInfo): this.type = { this.importInfo = importInfo; this } - def setGadt(gadt: GADTMap): this.type = { this.gadt = gadt; this } + def setGadt(gadt: GadtConstraint): this.type = { this.gadt = gadt; this } def setFreshGADTBounds: this.type = setGadt(gadt.fresh) def setSearchHistory(searchHistory: SearchHistory): this.type = { this.searchHistory = searchHistory; this } def setSource(source: SourceFile): this.type = { this.source = source; this } @@ -617,7 +617,7 @@ object Contexts { store = initialStore.updated(settingsStateLoc, settingsGroup.defaultState) typeComparer = new TypeComparer(this) searchHistory = new SearchRoot - gadt = EmptyGADTMap + gadt = EmptyGadtConstraint } @sharable object NoContext extends Context(null) { @@ -774,233 +774,4 @@ object Contexts { if (thread == null) thread = Thread.currentThread() else assert(thread == Thread.currentThread(), "illegal multithreaded access to ContextBase") } - - sealed abstract class GADTMap { - def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit - def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean - def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds - def contains(sym: Symbol)(implicit ctx: Context): Boolean - def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type - def debugBoundsDescription(implicit ctx: Context): String - def fresh: GADTMap - def restore(other: GADTMap): Unit - def isEmpty: Boolean - } - - final class SmartGADTMap private ( - private var myConstraint: Constraint, - private var mapping: SimpleIdentityMap[Symbol, TypeVar], - private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], - private var boundCache: SimpleIdentityMap[Symbol, TypeBounds] - ) extends GADTMap with ConstraintHandling[Context] { - import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} - - def this() = this( - myConstraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty), - mapping = SimpleIdentityMap.Empty, - reverseMapping = SimpleIdentityMap.Empty, - boundCache = SimpleIdentityMap.Empty - ) - - implicit override def ctx(implicit ctx: Context): Context = ctx - - override protected def constraint = myConstraint - override protected def constraint_=(c: Constraint) = myConstraint = c - - override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2) - override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2) - - override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym) - - override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = try { - boundCache = SimpleIdentityMap.Empty - boundAdditionInProgress = true - @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { - case tv: TypeVar => - val inst = instType(tv) - if (inst.exists) stripInternalTypeVar(inst) else tv - case _ => tp - } - - def externalizedSubtype(tp1: Type, tp2: Type, isSubtype: Boolean): Boolean = { - val externalizedTp1 = removeTypeVars(tp1) - val externalizedTp2 = removeTypeVars(tp2) - - ( - if (isSubtype) externalizedTp1 frozen_<:< externalizedTp2 - else externalizedTp2 frozen_<:< externalizedTp1 - ).reporting({ res => - val descr = i"$externalizedTp1 frozen_${if (isSubtype) "<:<" else ">:>"} $externalizedTp2" - i"$descr = $res" - }, gadts) - } - - val symTvar: TypeVar = stripInternalTypeVar(tvar(sym)) match { - case tv: TypeVar => tv - case inst => - val externalizedInst = removeTypeVars(inst) - gadts.println(i"instantiated: $sym -> $externalizedInst") - return if (isUpper) isSubType(externalizedInst , bound) else isSubType(bound, externalizedInst) - } - - val internalizedBound = insertTypeVars(bound) - ( - stripInternalTypeVar(internalizedBound) match { - case boundTvar: TypeVar => - if (boundTvar eq symTvar) true - else if (isUpper) addLess(symTvar.origin, boundTvar.origin) - else addLess(boundTvar.origin, symTvar.origin) - case bound => - if (externalizedSubtype(symTvar, bound, isSubtype = !isUpper)) { - gadts.println(i"manually unifying $symTvar with $bound") - constraint = constraint.updateEntry(symTvar.origin, bound) - true - } - else if (isUpper) addUpperBound(symTvar.origin, bound) - else addLowerBound(symTvar.origin, bound) - } - ).reporting({ res => - val descr = if (isUpper) "upper" else "lower" - val op = if (isUpper) "<:" else ">:" - i"adding $descr bound $sym $op $bound = $res\t( $symTvar $op $internalizedBound )" - }, gadts) - } finally boundAdditionInProgress = false - - override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = { - mapping(sym) match { - case null => null - case tv => - def retrieveBounds: TypeBounds = { - val tb = constraint.fullBounds(tv.origin) - removeTypeVars(tb).asInstanceOf[TypeBounds] - } - ( - if (boundAdditionInProgress || ctx.mode.is(Mode.GADTflexible)) retrieveBounds - else boundCache(sym) match { - case tb: TypeBounds => tb - case null => - val bounds = retrieveBounds - boundCache = boundCache.updated(sym, bounds) - bounds - } - ).reporting({ res => - // i"gadt bounds $sym: $res" - "" - }, gadts) - } - } - - override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null - - override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = { - val res = removeTypeVars(approximation(tvar(sym).origin, fromBelow = fromBelow)) - gadts.println(i"approximating $sym ~> $res") - res - } - - override def fresh: GADTMap = new SmartGADTMap( - myConstraint, - mapping, - reverseMapping, - boundCache - ) - - def restore(other: GADTMap): Unit = other match { - case other: SmartGADTMap => - this.myConstraint = other.myConstraint - this.mapping = other.mapping - this.reverseMapping = other.reverseMapping - this.boundCache = other.boundCache - case _ => ; - } - - override def isEmpty: Boolean = mapping.size == 0 - - // ---- Private ---------------------------------------------------------- - - private[this] def tvar(sym: Symbol)(implicit ctx: Context): TypeVar = { - mapping(sym) match { - case tv: TypeVar => - tv - case null => - val res = { - import NameKinds.DepParamName - // avoid registering the TypeVar with TyperState / TyperState#constraint - // - we don't want TyperState instantiating these TypeVars - // - we don't want TypeComparer constraining these TypeVars - val poly = PolyType(DepParamName.fresh(sym.name.toTypeName) :: Nil)( - pt => (sym.info match { - case tb @ TypeBounds(_, hi) if hi.isLambdaSub => tb - case _ => TypeBounds.empty - }) :: Nil, - pt => defn.AnyType) - new TypeVar(poly.paramRefs.head, creatorState = null) - } - gadts.println(i"GADTMap: created tvar $sym -> $res") - constraint = constraint.add(res.origin.binder, res :: Nil) - mapping = mapping.updated(sym, res) - reverseMapping = reverseMapping.updated(res.origin, sym) - res - } - } - - private def insertTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match { - case tp: TypeRef => - val sym = tp.typeSymbol - if (contains(sym)) tvar(sym) else tp - case _ => - (if (map != null) map else new TypeVarInsertingMap()).mapOver(tp) - } - private final class TypeVarInsertingMap(implicit ctx: Context) extends TypeMap { - override def apply(tp: Type): Type = insertTypeVars(tp, this) - } - - private def removeTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match { - case tpr: TypeParamRef => - reverseMapping(tpr) match { - case null => tpr - case sym => sym.typeRef - } - case tv: TypeVar => - reverseMapping(tv.origin) match { - case null => tv - case sym => sym.typeRef - } - case _ => - (if (map != null) map else new TypeVarRemovingMap()).mapOver(tp) - } - private final class TypeVarRemovingMap(implicit ctx: Context) extends TypeMap { - override def apply(tp: Type): Type = removeTypeVars(tp, this) - } - - private[this] var boundAdditionInProgress = false - - // ---- Debug ------------------------------------------------------------ - - override def constr_println(msg: => String): Unit = gadtsConstr.println(msg) - - override def debugBoundsDescription(implicit ctx: Context): String = { - val sb = new mutable.StringBuilder - sb ++= constraint.show - sb += '\n' - mapping.foreachBinding { case (sym, _) => - sb ++= i"$sym: ${bounds(sym)}\n" - } - sb.result - } - } - - @sharable object EmptyGADTMap extends GADTMap { - override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGADTMap.addEmptyBounds") - override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.addBound") - override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null - override def contains(sym: Symbol)(implicit ctx: Context) = false - override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGADTMap.approximation") - override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGADTMap" - override def fresh = new SmartGADTMap - override def restore(other: GADTMap): Unit = { - if (!other.isEmpty) sys.error("cannot restore a non-empty GADTMap") - } - override def isEmpty: Boolean = true - } } diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala new file mode 100644 index 000000000000..f4a30c8391a1 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -0,0 +1,323 @@ +package dotty.tools +package dotc +package core + +import Decorators._ +import Contexts._ +import Types._ +import Symbols._ +import util.SimpleIdentityMap +import collection.mutable +import printing._ + +import scala.annotation.internal.sharable + +/** Represents GADT constraints currently in scope */ +sealed abstract class GadtConstraint extends Showable { + /** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */ + def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds + + /** Full bounds of `sym`, including TypeRefs to other lower/upper symbols. + * + * Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds` + * of some symbol when comparing types might lead to infinite recursion. Consider `bounds` instead. + */ + def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds + + /** Is `sym1` ordered to be less than `sym2`? */ + def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean + + /** Add symbols to constraint, preserving the underlying bounds and handling inter-dependencies. */ + def addToConstraint(syms: List[Symbol])(implicit ctx: Context): Boolean + def addToConstraint(sym: Symbol)(implicit ctx: Context): Boolean = addToConstraint(sym :: Nil) + + /** Further constrain a symbol already present in the constraint. */ + def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean + + /** Is the symbol registered in the constraint? + * + * Note that this is returns `true` even if `sym` is already instantiated to some type, + * unlike [[Constraint.contains]]. + */ + def contains(sym: Symbol)(implicit ctx: Context): Boolean + + def isEmpty: Boolean + + /** See [[ConstraintHandling.approximation]] */ + def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type + + def fresh: GadtConstraint + + /** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */ + def restore(other: GadtConstraint): Unit + + def debugBoundsDescription(implicit ctx: Context): String +} + +final class ProperGadtConstraint private( + private var myConstraint: Constraint, + private var mapping: SimpleIdentityMap[Symbol, TypeVar], + private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], +) extends GadtConstraint with ConstraintHandling[Context] { + import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} + + def this() = this( + myConstraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty), + mapping = SimpleIdentityMap.Empty, + reverseMapping = SimpleIdentityMap.Empty + ) + + /** Exposes ConstraintHandling.subsumes */ + def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(implicit ctx: Context): Boolean = { + def extractConstraint(g: GadtConstraint) = g match { + case s: ProperGadtConstraint => s.constraint + case EmptyGadtConstraint => OrderingConstraint.empty + } + subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) + } + + override def addToConstraint(params: List[Symbol])(implicit ctx: Context): Boolean = { + import NameKinds.DepParamName + + val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) })( + pt => params.map { param => + // replace the symbols in bound type `tp` which are in dependent positions + // with their internal TypeParamRefs + def substDependentSyms(tp: Type, isUpper: Boolean)(implicit ctx: Context): Type = { + def loop(tp: Type) = substDependentSyms(tp, isUpper) + tp match { + case tp @ AndType(tp1, tp2) if !isUpper => + tp.derivedAndType(loop(tp1), loop(tp2)) + case tp @ OrType(tp1, tp2) if isUpper => + tp.derivedOrType(loop(tp1), loop(tp2)) + case tp: NamedType => + params.indexOf(tp.symbol) match { + case -1 => + mapping(tp.symbol) match { + case tv: TypeVar => tv.origin + case null => tp + } + case i => pt.paramRefs(i) + } + case tp => tp + } + } + + val tb = param.info.bounds + tb.derivedTypeBounds( + lo = substDependentSyms(tb.lo, isUpper = false), + hi = substDependentSyms(tb.hi, isUpper = true) + ) + }, + pt => defn.AnyType + ) + + val tvars = (params, poly1.paramRefs).zipped.map { (sym, paramRef) => + val tv = new TypeVar(paramRef, creatorState = null) + mapping = mapping.updated(sym, tv) + reverseMapping = reverseMapping.updated(tv.origin, sym) + tv + } + + // the replaced symbols will be stripped off the bounds by `addToConstraint` and used as orderings + addToConstraint(poly1, tvars).reporting({ _ => + i"added to constraint: $params%, %\n$debugBoundsDescription" + }, gadts) + } + + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = { + @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { + case tv: TypeVar => + val inst = instType(tv) + if (inst.exists) stripInternalTypeVar(inst) else tv + case _ => tp + } + + val symTvar: TypeVar = stripInternalTypeVar(tvarOrError(sym)) match { + case tv: TypeVar => tv + case inst => + gadts.println(i"instantiated: $sym -> $inst") + return if (isUpper) isSubType(inst , bound) else isSubType(bound, inst) + } + + val internalizedBound = bound match { + case nt: NamedType => + val ntTvar = mapping(nt.symbol) + if (ntTvar ne null) stripInternalTypeVar(ntTvar) else bound + case _ => bound + } + ( + internalizedBound match { + case boundTvar: TypeVar => + if (boundTvar eq symTvar) true + else if (isUpper) addLess(symTvar.origin, boundTvar.origin) + else addLess(boundTvar.origin, symTvar.origin) + case bound => + val oldUpperBound = bounds(symTvar.origin) + // If we have bounds: + // F >: [t] => List[t] <: [t] => Any + // and we want to record that: + // F <: [+A] => List[A] + // we need to adapt the variance and instead record that: + // F <: [A] => List[A] + // We cannot record the original bound, since it is false that: + // [t] => List[t] <: [+A] => List[A] + // + // Note that the following code is accepted: + // class Foo[F[t] >: List[t]] + // type T = Foo[List] + // precisely because Foo[List] is desugared to Foo[[A] => List[A]]. + val bound1 = bound.adaptHkVariances(oldUpperBound) + if (isUpper) addUpperBound(symTvar.origin, bound1) + else addLowerBound(symTvar.origin, bound1) + } + ).reporting({ res => + val descr = if (isUpper) "upper" else "lower" + val op = if (isUpper) "<:" else ">:" + i"adding $descr bound $sym $op $bound = $res" + }, gadts) + } + + override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = + constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin) + + override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = + mapping(sym) match { + case null => null + case tv => + fullBounds(tv.origin) + .ensuring(containsNoInternalTypes(_)) + } + + override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = { + mapping(sym) match { + case null => null + case tv => + def retrieveBounds: TypeBounds = + bounds(tv.origin) match { + case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) => + TypeAlias(reverseMapping(tpr).typeRef) + case tb => tb + } + retrieveBounds + //.reporting({ res => i"gadt bounds $sym: $res" }, gadts) + .ensuring(containsNoInternalTypes(_)) + } + } + + override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null + + override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = { + val res = approximation(tvarOrError(sym).origin, fromBelow = fromBelow) + gadts.println(i"approximating $sym ~> $res") + res + } + + override def fresh: GadtConstraint = new ProperGadtConstraint( + myConstraint, + mapping, + reverseMapping + ) + + def restore(other: GadtConstraint): Unit = other match { + case other: ProperGadtConstraint => + this.myConstraint = other.myConstraint + this.mapping = other.mapping + this.reverseMapping = other.reverseMapping + case _ => ; + } + + override def isEmpty: Boolean = mapping.size == 0 + + // ---- Protected/internal ----------------------------------------------- + + implicit override def ctx(implicit ctx: Context): Context = ctx + + override protected def constraint = myConstraint + override protected def constraint_=(c: Constraint) = myConstraint = c + + override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2) + override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2) + + override def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds = + constraint.nonParamBounds(param) match { + case TypeAlias(tpr: TypeParamRef) => TypeAlias(externalize(tpr)) + case tb => tb + } + + override def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type = + (nonParamBounds(param).lo /: constraint.minLower(param)) { + (t, u) => t | externalize(u) + } + + override def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type = + (nonParamBounds(param).hi /: constraint.minUpper(param)) { + (t, u) => t & externalize(u) + } + + // ---- Private ---------------------------------------------------------- + + private[this] def externalize(param: TypeParamRef)(implicit ctx: Context): Type = + reverseMapping(param) match { + case sym: Symbol => sym.typeRef + case null => param + } + + private[this] def tvarOrError(sym: Symbol)(implicit ctx: Context): TypeVar = + mapping(sym).ensuring(_ ne null, i"not a constrainable symbol: $sym") + + private[this] def containsNoInternalTypes( + tp: Type, + acc: TypeAccumulator[Boolean] = null + )(implicit ctx: Context): Boolean = tp match { + case tpr: TypeParamRef => !reverseMapping.contains(tpr) + case tv: TypeVar => !reverseMapping.contains(tv.origin) + case tp => + (if (acc ne null) acc else new ContainsNoInternalTypesAccumulator()).foldOver(true, tp) + } + + private[this] class ContainsNoInternalTypesAccumulator(implicit ctx: Context) extends TypeAccumulator[Boolean] { + override def apply(x: Boolean, tp: Type): Boolean = x && containsNoInternalTypes(tp) + } + + // ---- Debug ------------------------------------------------------------ + + override def constr_println(msg: => String): Unit = gadtsConstr.println(msg) + + override def toText(printer: Printer): Texts.Text = constraint.toText(printer) + + override def debugBoundsDescription(implicit ctx: Context): String = { + val sb = new mutable.StringBuilder + sb ++= constraint.show + sb += '\n' + mapping.foreachBinding { case (sym, _) => + sb ++= i"$sym: ${fullBounds(sym)}\n" + } + sb.result + } +} + +@sharable object EmptyGadtConstraint extends GadtConstraint { + override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null + override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null + + override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = unsupported("EmptyGadtConstraint.isLess") + + override def isEmpty: Boolean = true + + override def contains(sym: Symbol)(implicit ctx: Context) = false + + override def addToConstraint(params: List[Symbol])(implicit ctx: Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint") + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGadtConstraint.addBound") + + override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGadtConstraint.approximation") + + override def fresh = new ProperGadtConstraint + override def restore(other: GadtConstraint): Unit = { + if (!other.isEmpty) sys.error("cannot restore a non-empty GADTMap") + } + + override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGadtConstraint" + + override def toText(printer: Printer): Texts.Text = "EmptyGadtConstraint" +} diff --git a/compiler/src/dotty/tools/dotc/core/Mode.scala b/compiler/src/dotty/tools/dotc/core/Mode.scala index 430d0b062c84..81b9fc5ea5c4 100644 --- a/compiler/src/dotty/tools/dotc/core/Mode.scala +++ b/compiler/src/dotty/tools/dotc/core/Mode.scala @@ -49,7 +49,7 @@ object Mode { /** We are in a pattern alternative */ val InPatternAlternative: Mode = newMode(7, "InPatternAlternative") - /** Allow GADTFlexType labelled types to have their bounds adjusted */ + /** Infer GADT constraints during type comparisons `A <:< B` */ val GADTflexible: Mode = newMode(8, "GADTflexible") /** Assume -language:strictEquality */ diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index 2f568dfe7750..869c8330a5a3 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -196,15 +196,6 @@ class OrderingConstraint(private val boundsMap: ParamBounds, def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds = entry(param).bounds - def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type = - (nonParamBounds(param).lo /: minLower(param))(_ | _) - - def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type = - (nonParamBounds(param).hi /: minUpper(param))(_ & _) - - def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds = - nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param)) - def typeVarOfParam(param: TypeParamRef): Type = { val entries = boundsMap(param.binder) if (entries == null) NoType diff --git a/compiler/src/dotty/tools/dotc/core/Symbols.scala b/compiler/src/dotty/tools/dotc/core/Symbols.scala index ef81b8bb3bf9..b3d41754dc87 100644 --- a/compiler/src/dotty/tools/dotc/core/Symbols.scala +++ b/compiler/src/dotty/tools/dotc/core/Symbols.scala @@ -209,16 +209,10 @@ trait Symbols { this: Context => modFlags | PackageCreationFlags, clsFlags | PackageCreationFlags, Nil, decls) - /** Define a new symbol associated with a Bind or pattern wildcard and - * make it gadt narrowable. - */ - def newPatternBoundSymbol(name: Name, info: Type, span: Span): Symbol = { + /** Define a new symbol associated with a Bind or pattern wildcard and, by default, make it gadt narrowable. */ + def newPatternBoundSymbol(name: Name, info: Type, span: Span, addToGadt: Boolean = true): Symbol = { val sym = newSymbol(owner, name, Case, info, coord = span) - if (name.isTypeName) { - val bounds = info.bounds - gadt.addBound(sym, bounds.lo, isUpper = false) - gadt.addBound(sym, bounds.hi, isUpper = true) - } + if (addToGadt && name.isTypeName) gadt.addToConstraint(sym) sym } diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 8a69184f0846..0c9d6416e35d 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -149,6 +149,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { this.leftRoot = tp1 } else this.approx = a + if (savedApprox.gadt) { + this.approx = this.approx.addGadt + } try recur(tp1, tp2) catch { case ex: Throwable => handleRecursive("subtype", i"$tp1 <:< $tp2", ex, weight = 2) @@ -442,8 +445,18 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { val gbounds2 = gadtBounds(tp2.symbol) (gbounds2 != null) && (isSubTypeWhenFrozen(tp1, gbounds2.lo) || + (tp1 match { + case tp1: NamedType if ctx.gadt.contains(tp1.symbol) => + // Note: since we approximate constrained types only with their non-param bounds, + // we need to manually handle the case when we're comparing two constrained types, + // one of which is constrained to be a subtype of another. + // We do not need similar code in fourthTry, since we only need to care about + // comparing two constrained types, and that case will be handled here first. + ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) + case _ => false + }) || narrowGADTBounds(tp2, tp1, approx, isUpper = false)) && - GADTusage(tp2.symbol) + { tp1.isRef(NothingClass) || GADTusage(tp2.symbol) } } isSubApproxHi(tp1, info2.lo) || compareGADT || fourthTry @@ -702,7 +715,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { (gbounds1 != null) && (isSubTypeWhenFrozen(gbounds1.hi, tp2) || narrowGADTBounds(tp1, tp2, approx, isUpper = true)) && - GADTusage(tp1.symbol) + { tp2.isRef(AnyClass) || GADTusage(tp1.symbol) } } isSubType(hi1, tp2, approx.addLow) || compareGADT case _ => @@ -830,8 +843,10 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { gadtBoundsContain(tycon1sym, tycon2) || gadtBoundsContain(tycon2sym, tycon1) ) && - isSubType(tycon1.prefix, tycon2.prefix) && - isSubArgs(args1, args2, tp1, tparams) + isSubType(tycon1.prefix, tycon2.prefix) && { + val tyconIsInjective = tycon1sym.isClass || tycon2sym.isClass + isSubArgs(args1, args2, tp1, tparams, tyconIsInjective) + } if (res && touchedGADTs) GADTused = true res case _ => @@ -1087,7 +1102,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { * @param tp1 The applied type containing `args1` * @param tparams2 The type parameters of the type constructor applied to `args2` */ - def isSubArgs(args1: List[Type], args2: List[Type], tp1: Type, tparams2: List[ParamInfo]): Boolean = { + def isSubArgs(args1: List[Type], args2: List[Type], tp1: Type, tparams2: List[ParamInfo], inferGadtBounds: Boolean = false): Boolean = { /** The bounds of parameter `tparam`, where all references to type paramneters * are replaced by corresponding arguments (or their approximations in the case of * wildcard arguments). @@ -1151,8 +1166,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { case arg1: TypeBounds => compareCaptured(arg1, arg2) case _ => - (v > 0 || isSubType(arg2, arg1)) && - (v < 0 || isSubType(arg1, arg2)) + var nextApprox = if (inferGadtBounds) FreshApprox else FreshApprox.addGadt + (v > 0 || isSubType(arg2, arg1, nextApprox)) && + (v < 0 || isSubType(arg1, arg2, nextApprox)) } } @@ -1240,16 +1256,52 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { */ private def either(op1: => Boolean, op2: => Boolean): Boolean = { val preConstraint = constraint - op1 && { - val leftConstraint = constraint - constraint = preConstraint - if (!(op2 && subsumes(leftConstraint, constraint, preConstraint))) { - if (constr != noPrinter && !subsumes(constraint, leftConstraint, preConstraint)) - constr.println(i"CUT - prefer $leftConstraint over $constraint") - constraint = leftConstraint - } - true - } || op2 + + if (ctx.mode.is(Mode.GADTflexible)) { + val preGadt = ctx.gadt.fresh + // if GADTflexible mode is on, we always have a ProperGadtConstraint + val pre = preGadt.asInstanceOf[ProperGadtConstraint] + if (op1) { + val leftConstraint = constraint + val leftGadt = ctx.gadt.fresh + constraint = preConstraint + ctx.gadt.restore(preGadt) + if (op2) { + if (pre.subsumes(leftGadt, ctx.gadt, preGadt) && subsumes(leftConstraint, constraint, preConstraint)) { + gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $leftGadt") + constr.println(i"CUT - prefer $constraint over $leftConstraint") + true + } else if (pre.subsumes(ctx.gadt, leftGadt, preGadt) && subsumes(constraint, leftConstraint, preConstraint)) { + gadts.println(i"GADT CUT - prefer $leftGadt over ${ctx.gadt}") + constr.println(i"CUT - prefer $leftConstraint over $constraint") + constraint = leftConstraint + ctx.gadt.restore(leftGadt) + true + } else { + gadts.println(i"GADT CUT - no constraint is preferable, reverting to $preGadt") + constr.println(i"CUT - no constraint is preferable, reverting to $preConstraint") + constraint = preConstraint + ctx.gadt.restore(preGadt) + true + } + } else { + constraint = leftConstraint + ctx.gadt.restore(leftGadt) + true + } + } else op2 + } else { + op1 && { + val leftConstraint = constraint + constraint = preConstraint + if (!(op2 && subsumes(leftConstraint, constraint, preConstraint))) { + if (constr != noPrinter && !subsumes(constraint, leftConstraint, preConstraint)) + constr.println(i"CUT - prefer $leftConstraint over $constraint") + constraint = leftConstraint + } + true + } || op2 + } } /** Does type `tp1` have a member with name `name` whose normalized type is a subtype of @@ -1364,7 +1416,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { * Test that the resulting bounds are still satisfiable. */ private def narrowGADTBounds(tr: NamedType, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = { - val boundImprecise = if (isUpper) approx.high else approx.low + val boundImprecise = approx.high || approx.low || approx.gadt ctx.mode.is(Mode.GADTflexible) && !frozenConstraint && !boundImprecise && { val tparam = tr.symbol gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") @@ -2100,6 +2152,7 @@ object TypeComparer { private val LoApprox = 1 private val HiApprox = 2 + private val GadtApprox = 4 /** The approximation state indicates how the pair of types currently compared * relates to the types compared originally. @@ -2109,14 +2162,17 @@ object TypeComparer { */ class ApproxState(private val bits: Int) extends AnyVal { override def toString: String = { - val lo = if ((bits & LoApprox) != 0) "LoApprox" else "" - val hi = if ((bits & HiApprox) != 0) "HiApprox" else "" - lo ++ hi + val lo = if (low) "LoApprox" else "" + val hi = if (high) "HiApprox" else "" + val g = if (gadt) "GadtApprox" else "" + lo ++ hi ++ g } def addLow: ApproxState = new ApproxState(bits | LoApprox) def addHigh: ApproxState = new ApproxState(bits | HiApprox) + def addGadt: ApproxState = new ApproxState(bits | GadtApprox) def low: Boolean = (bits & LoApprox) != 0 def high: Boolean = (bits & HiApprox) != 0 + def gadt: Boolean = (bits & GadtApprox) != 0 } val NoApprox: ApproxState = new ApproxState(0) @@ -2125,7 +2181,7 @@ object TypeComparer { * compare (approximations of) this pair of types. It's converted to `NoApprox` * in `isSubType`, but also leads to `leftRoot` being set there. */ - val FreshApprox: ApproxState = new ApproxState(4) + val FreshApprox: ApproxState = new ApproxState(1 << 31) /** Show trace of comparison operations when performing `op` as result string */ def explaining[T](say: String => Unit)(op: Context => T)(implicit ctx: Context): T = { diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index f41a710dcde5..664016b99f38 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -387,7 +387,7 @@ trait TypeOps { this: Context => // TODO: Make standalone object. val bound1 = massage(bound) if (bound1 ne bound) { if (checkCtx eq ctx) checkCtx = ctx.fresh.setFreshGADTBounds - if (!checkCtx.gadt.contains(sym)) checkCtx.gadt.addEmptyBounds(sym) + if (!checkCtx.gadt.contains(sym)) checkCtx.gadt.addToConstraint(sym) checkCtx.gadt.addBound(sym, bound1, fromBelow) typr.println("install GADT bound $bound1 for when checking F-bounded $sym") } diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 49908bacdcd6..33efe736b3c8 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -3704,7 +3704,12 @@ object Types { // ----- Skolem types ----------------------------------------------- - /** A skolem type reference with underlying type `info` */ + /** A skolem type reference with underlying type `info`. + * + * For Dotty, a skolem type is a singleton type of some unknown value of type `info`. + * Note that care is needed when creating them, since not all types need to be inhabited. + * A skolem is equal to itself and no other type. + */ case class SkolemType(info: Type) extends UncachedProxyType with ValueType with SingletonType { override def underlying(implicit ctx: Context): Type = info def derivedSkolemType(info: Type)(implicit ctx: Context): SkolemType = @@ -3863,10 +3868,10 @@ object Types { def contextInfo(tp: Type): Type = tp match { case tp: TypeParamRef => val constraint = ctx.typerState.constraint - if (constraint.entry(tp).exists) constraint.fullBounds(tp) + if (constraint.entry(tp).exists) ctx.typeComparer.fullBounds(tp) else NoType case tp: TypeRef => - val bounds = ctx.gadt.bounds(tp.symbol) + val bounds = ctx.gadt.fullBounds(tp.symbol) if (bounds == null) NoType else bounds case tp: TypeVar => tp.underlying diff --git a/compiler/src/dotty/tools/dotc/printing/Formatting.scala b/compiler/src/dotty/tools/dotc/printing/Formatting.scala index 6b6a6845565a..408d369d84a4 100644 --- a/compiler/src/dotty/tools/dotc/printing/Formatting.scala +++ b/compiler/src/dotty/tools/dotc/printing/Formatting.scala @@ -170,7 +170,7 @@ object Formatting { case sym: Symbol => val info = if (ctx.gadt.contains(sym)) - sym.info & ctx.gadt.bounds(sym) + sym.info & ctx.gadt.fullBounds(sym) else sym.info s"is a ${ctx.printer.kindString(sym)}${sym.showExtendedLocation}${addendum("bounds", info)}" @@ -190,7 +190,7 @@ object Formatting { case param: TermParamRef => false case skolem: SkolemType => true case sym: Symbol => - ctx.gadt.contains(sym) && ctx.gadt.bounds(sym) != TypeBounds.empty + ctx.gadt.contains(sym) && ctx.gadt.fullBounds(sym) != TypeBounds.empty case _ => assert(false, "unreachable") false diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 9efd80c1424c..844533376725 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -208,7 +208,10 @@ class PlainPrinter(_ctx: Context) extends Printer { else { val constr = ctx.typerState.constraint val bounds = - if (constr.contains(tp)) constr.fullBounds(tp.origin)(ctx.addMode(Mode.Printing)) + if (constr.contains(tp)) { + val ctx0 = ctx.addMode(Mode.Printing) + ctx0.typeComparer.fullBounds(tp.origin) + } else TypeBounds.empty if (bounds.isTypeAlias) toText(bounds.lo) ~ (Str("^") provided ctx.settings.YprintDebug.value) else if (ctx.settings.YshowVarBounds.value) "(" ~ toText(tp.origin) ~ "?" ~ toText(bounds) ~ ")" diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index e4d71a68488a..375d11bedc30 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -401,9 +401,9 @@ class TreeChecker extends Phase with SymTransformer { } } - override def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type, gadtSyms: Set[Symbol])(implicit ctx: Context): CaseDef = { + override def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = { withPatSyms(tpd.patVars(tree.pat.asInstanceOf[tpd.Tree])) { - super.typedCase(tree, selType, pt, gadtSyms) + super.typedCase(tree, selType, pt) } } diff --git a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala index 03b67a7a91fe..bd07958cbb91 100644 --- a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala +++ b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala @@ -128,8 +128,8 @@ object ErrorReporting { case tp: TypeParamRef => constraint.entry(tp) match { case bounds: TypeBounds => - if (variance < 0) apply(constraint.fullUpperBound(tp)) - else if (variance > 0) apply(constraint.fullLowerBound(tp)) + if (variance < 0) apply(ctx.typeComparer.fullUpperBound(tp)) + else if (variance > 0) apply(ctx.typeComparer.fullLowerBound(tp)) else tp case NoType => tp case instType => apply(instType) diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index 1fa53e79582a..cdc68f2214b7 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -345,7 +345,7 @@ object Implicits { * @param level The level where the reference was found * @param tstate The typer state to be committed if this alternative is chosen */ - case class SearchSuccess(tree: Tree, ref: TermRef, level: Int)(val tstate: TyperState, val gstate: GADTMap) extends SearchResult with Showable + case class SearchSuccess(tree: Tree, ref: TermRef, level: Int)(val tstate: TyperState, val gstate: GadtConstraint) extends SearchResult with Showable /** A failed search */ case class SearchFailure(tree: Tree) extends SearchResult { @@ -397,21 +397,29 @@ object Implicits { * what was expected */ override def clarify(tp: Type)(implicit ctx: Context): Type = { - val map = new TypeMap { - def apply(t: Type): Type = t match { - case t: TypeParamRef => - constraint.entry(t) match { - case NoType => t - case bounds: TypeBounds => constraint.fullBounds(t) - case t1 => t1 - } - case t: TypeVar => - t.instanceOpt.orElse(apply(t.origin)) - case _ => - mapOver(t) + val ctx0 = ctx + locally { + implicit val ctx = ctx0.fresh.setTyperState { + val ts = ctx0.typerState.fresh() + ts.constraint_=(constraint)(ctx0) + ts + } + val map = new TypeMap { + def apply(t: Type): Type = t match { + case t: TypeParamRef => + constraint.entry(t) match { + case NoType => t + case bounds: TypeBounds => ctx.typeComparer.fullBounds(t) + case t1 => t1 + } + case t: TypeVar => + t.instanceOpt.orElse(apply(t.origin)) + case _ => + mapOver(t) + } } + map(tp) } - map(tp) } def explanation(implicit ctx: Context): String = diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index eb0674802cf6..05d8a7568acc 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -184,8 +184,6 @@ object Inferencing { * * Invariant refinement can be assumed if `PatternType`'s class(es) are final or * case classes (because of `RefChecks#checkCaseClassInheritanceInvariant`). - * - * TODO: Update so that GADT symbols can be variant, and we special case final class types in patterns */ def constrainPatternType(tp: Type, pt: Type)(implicit ctx: Context): Boolean = { def refinementIsInvariant(tp: Type): Boolean = tp match { @@ -209,8 +207,9 @@ object Inferencing { } val widePt = if (ctx.scala2Mode || refinementIsInvariant(tp)) pt else widenVariantParams(pt) - trace(i"constraining pattern type $tp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") { - tp <:< widePt + val narrowTp = SkolemType(tp) + trace(i"constraining pattern type $narrowTp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") { + narrowTp <:< widePt } } @@ -263,7 +262,7 @@ object Inferencing { * 0 if unconstrained, or constraint is from below and above. */ private def instDirection(param: TypeParamRef)(implicit ctx: Context): Int = { - val constrained = ctx.typerState.constraint.fullBounds(param) + val constrained = ctx.typeComparer.fullBounds(param) val original = param.binder.paramInfos(param.paramNum) val cmp = ctx.typeComparer val approxBelow = @@ -298,17 +297,21 @@ object Inferencing { if (v == 1) tvar.instantiate(fromBelow = false) else if (v == -1) tvar.instantiate(fromBelow = true) else { - val bounds = ctx.typerState.constraint.fullBounds(tvar.origin) + val bounds = ctx.typeComparer.fullBounds(tvar.origin) if (bounds.hi <:< bounds.lo || bounds.hi.classSymbol.is(Final) || fromScala2x) tvar.instantiate(fromBelow = false) else { - val wildCard = ctx.newPatternBoundSymbol(UniqueName.fresh(tvar.origin.paramName), bounds, span) + // since the symbols we're creating may have inter-dependencies in their bounds, + // we add them to the GADT constraint later, simultaneously + val wildCard = ctx.newPatternBoundSymbol(UniqueName.fresh(tvar.origin.paramName), bounds, span, addToGadt = false) tvar.instantiateWith(wildCard.typeRef) patternBound += wildCard } } } - patternBound.toList + val res = patternBound.toList + if (res.nonEmpty) ctx.gadt.addToConstraint(res) + res } type VarianceMap = SimpleIdentityMap[TypeVar, Integer] diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index 193a92f0000a..f1ff0070469c 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -534,6 +534,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { /** An extractor for terms equivalent to `new C(args)`, returning the class `C`, * a list of bindings, and the arguments `args`. Can see inside blocks and Inlined nodes and can * follow a reference to an inline value binding to its right hand side. + * * @return optionally, a triple consisting of * - the class `C` * - the arguments `args` @@ -729,7 +730,6 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { def reduceInlineMatch(scrutinee: Tree, scrutType: Type, cases: List[CaseDef], typer: Typer)(implicit ctx: Context): MatchRedux = { val isImplicit = scrutinee.isEmpty - val gadtSyms = typer.gadtSyms(scrutType) /** Try to match pattern `pat` against scrutinee reference `scrut`. If successful add * bindings for variables bound in this pattern to `caseBindingMap`. @@ -821,11 +821,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { } def registerAsGadtSyms(typeBinds: TypeBindsMap)(implicit ctx: Context): Unit = - typeBinds.foreachBinding { case (sym, _) => - val TypeBounds(lo, hi) = sym.info.bounds - ctx.gadt.addBound(sym, lo, isUpper = false) - ctx.gadt.addBound(sym, hi, isUpper = true) - } + if (typeBinds.size > 0) ctx.gadt.addToConstraint(typeBinds.keys) pat match { case Typed(pat1, tpt) => @@ -920,7 +916,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { } if (!isImplicit) caseBindingMap += ((NoSymbol, scrutineeBinding)) - val gadtCtx = typer.gadtContext(gadtSyms).addMode(Mode.GADTflexible) + val gadtCtx = ctx.fresh.setFreshGADTBounds.addMode(Mode.GADTflexible) if (reducePattern(caseBindingMap, scrutineeSym.termRef, cdef.pat)(gadtCtx)) { val (caseBindings, from, to) = substBindings(caseBindingMap.toList, mutable.ListBuffer(), Nil, Nil) val guardOK = cdef.guard.isEmpty || { diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index bbf1541fe222..ba38ea9b75bc 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1335,8 +1335,15 @@ class Namer { typer: Typer => // it would be erased to BoxedUnit. def dealiasIfUnit(tp: Type) = if (tp.isRef(defn.UnitClass)) defn.UnitType else tp - var rhsCtx = ctx.addMode(Mode.InferringReturnType) + var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType) if (sym.isInlineMethod) rhsCtx = rhsCtx.addMode(Mode.InlineableBody) + if (typeParams.nonEmpty) { + // we'll be typing an expression from a polymorphic definition's body, + // so we must allow constraining its type parameters + // compare with typedDefDef, see tests/pos/gadt-inference.scala + rhsCtx.setFreshGADTBounds + rhsCtx.gadt.addToConstraint(typeParams) + } def rhsType = typedAheadExpr(mdef.rhs, (inherited orElse rhsProto).widenExpr)(rhsCtx).tpe // Approximate a type `tp` with a type that does not contain skolem types. diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index f07563ec49f3..6faac5b0614b 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1047,37 +1047,8 @@ class Typer extends Namer assignType(cpy.Match(tree)(sel, cases1), sel, cases1) } - /** gadtSyms = "all type parameters of enclosing methods that appear - * non-variantly in the selector type" todo: should typevars - * which appear with variances +1 and -1 (in different - * places) be considered as well? - */ - def gadtSyms(selType: Type)(implicit ctx: Context): Set[Symbol] = trace(i"GADT syms of $selType", gadts) { - val accu = new TypeAccumulator[Set[Symbol]] { - def apply(tsyms: Set[Symbol], t: Type): Set[Symbol] = { - val tsyms1 = t match { - case tr: TypeRef if (tr.symbol is TypeParam) && tr.symbol.owner.isTerm && variance == 0 => - tsyms + tr.symbol - case _ => - tsyms - } - foldOver(tsyms1, t) - } - } - accu(Set.empty, selType) - } - - /** Context with fresh GADT bounds for all gadtSyms */ - def gadtContext(gadtSyms: Set[Symbol])(implicit ctx: Context): Context = { - val gadtCtx = ctx.fresh.setFreshGADTBounds - for (sym <- gadtSyms) - if (!gadtCtx.gadt.contains(sym)) gadtCtx.gadt.addEmptyBounds(sym) - gadtCtx - } - def typedCases(cases: List[untpd.CaseDef], selType: Type, pt: Type)(implicit ctx: Context): List[CaseDef] = { - val gadts = gadtSyms(selType) - cases.mapconserve(typedCase(_, selType, pt, gadts)) + cases.mapconserve(typedCase(_, selType, pt)) } /** - strip all instantiated TypeVars from pattern types. @@ -1096,7 +1067,7 @@ class Typer extends Namer if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(sym) else ctx.error(new DuplicateBind(b, cdef), b.sourcePos) if (!ctx.isAfterTyper) { - val bounds = ctx.gadt.bounds(sym) + val bounds = ctx.gadt.fullBounds(sym) if (bounds != null) sym.info = bounds } b @@ -1105,9 +1076,9 @@ class Typer extends Namer } /** Type a case. */ - def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type, gadtSyms: Set[Symbol])(implicit ctx: Context): CaseDef = track("typedCase") { + def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = track("typedCase") { val originalCtx = ctx - val gadtCtx = gadtContext(gadtSyms) + val gadtCtx: Context = ctx.fresh.setFreshGADTBounds def caseRest(pat: Tree)(implicit ctx: Context) = { val pat1 = indexPattern(tree).transform(pat) @@ -1132,8 +1103,6 @@ class Typer extends Namer def typedTypeCase(cdef: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = { def caseRest(implicit ctx: Context) = { val pat1 = checkSimpleKinded(typedType(cdef.pat)(ctx.addMode(Mode.Pattern))) - if (!ctx.isAfterTyper) - constrainPatternType(pat1.tpe, selType)(ctx.addMode(Mode.GADTflexible)) val pat2 = indexPattern(cdef).transform(pat1) val body1 = typedType(cdef.body, pt) assignType(cpy.CaseDef(cdef)(pat2, EmptyTree, body1), pat2, body1) @@ -1537,19 +1506,28 @@ class Typer extends Namer if (sym is ImplicitOrImplied) checkImplicitConversionDefOK(sym) val tpt1 = checkSimpleKinded(typedType(tpt)) - var rhsCtx = ctx - if (sym.isConstructor && !sym.isPrimaryConstructor && tparams1.nonEmpty) { - // for secondary constructors we need a context that "knows" - // that their type parameters are aliases of the class type parameters. - // See pos/i941.scala - rhsCtx = ctx.fresh.setFreshGADTBounds - (tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) => - val tr = tparam.typeRef - rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = false) - rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = true) + val rhsCtx = ctx.fresh + if (tparams1.nonEmpty) { + rhsCtx.setFreshGADTBounds + if (!sym.isConstructor) { + // we're typing a polymorphic definition's body, + // so we allow constraining all of its type parameters + // constructors are an exception as we don't allow constraining type params of classes + rhsCtx.gadt.addToConstraint(tparams1.map(_.symbol)) + } else if (!sym.isPrimaryConstructor) { + // otherwise, for secondary constructors we need a context that "knows" + // that their type parameters are aliases of the class type parameters. + // See pos/i941.scala + rhsCtx.gadt.addToConstraint(tparams1.map(_.symbol)) + (tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) => + val tr = tparam.typeRef + rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = false) + rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = true) + } } } - if (sym.isInlineMethod) rhsCtx = rhsCtx.addMode(Mode.InlineableBody) + + if (sym.isInlineMethod) rhsCtx.addMode(Mode.InlineableBody) val rhs1 = typedExpr(ddef.rhs, tpt1.tpe.widenExpr)(rhsCtx) if (sym.isInlineMethod) { diff --git a/tests/neg/classOf.check b/tests/neg/classOf.check index 7c761b8af5be..b8416e4007e3 100644 --- a/tests/neg/classOf.check +++ b/tests/neg/classOf.check @@ -2,5 +2,7 @@ Test.C{I = String} is not a class type [116..117] in classOf.scala T is not a class type + +where: T is a type in method f2 with bounds <: String [72..73] in classOf.scala T is not a class type diff --git a/tests/neg/creative-gadt-constraints.scala b/tests/neg/creative-gadt-constraints.scala new file mode 100644 index 000000000000..a8869de5f75d --- /dev/null +++ b/tests/neg/creative-gadt-constraints.scala @@ -0,0 +1,66 @@ +object buffer { + object EssaInt { + def unapply(i: Int): Some[Int] = Some(i) + } + + case class Inv[T](t: T) + + enum EQ[A, B] { case Refl[T]() extends EQ[T, T] } + enum SUB[A, +B] { case Refl[T]() extends SUB[T, T] } // A <: B + + def test_eq1[A, B](eq: EQ[A, B], a: A, b: B): B = + Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) + Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) | Sko(Int) + eq match { case EQ.Refl() => // a = b + val success: A = b + val fail: A = 0 // error + 0 // error + } + } + } + + def test_eq2[A, B](eq: EQ[A, B], a: A, b: B): B = + Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) + Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) + eq match { case EQ.Refl() => // a = b + val success: A = b + val fail: A = 0 // error + 0 // error + } + } + } + + def test_sub1[A, B](sub: SUB[A, B], a: A, b: B): B = + Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) + Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) | Sko(Int) + sub match { case SUB.Refl() => // b >: a + val success: B = a + val fail: A = 0 // error + 0 // error + } + } + } + + def test_sub2[A, B](sub: SUB[A, B], a: A, b: B): B = + Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) + Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) | Sko(Int) + sub match { case SUB.Refl() => // b >: a + val success: B = a + val fail: A = 0 // error + 0 // error + } + } + } + + + def test_sub_eq[A, B, C](sub: SUB[A|B, C], eqA: EQ[A, 5], eqB: EQ[B, 6]): C = + sub match { case SUB.Refl() => // C >: A | B + eqA match { case EQ.Refl() => // A = 5 + eqB match { case EQ.Refl() => // B = 6 + val fail1: A = 0 // error + val fail2: B = 0 // error + 0 // error + } + } + } +} diff --git a/tests/neg/gadt-alias-injectivity.scala b/tests/neg/gadt-alias-injectivity.scala new file mode 100644 index 000000000000..34203a45f4f8 --- /dev/null +++ b/tests/neg/gadt-alias-injectivity.scala @@ -0,0 +1,48 @@ +object Test { + enum EQ[A, B] { + case Refl[T]() extends EQ[T, T] + } + import EQ._ + + object A { + type Foo[+X] = (X, X) + def foo[X, Y](x: X, eq: EQ[Foo[X], Foo[Y]]): Y = eq match { + case Refl() => x + } + } + + object B { + type Foo[X] = (X, X) + def foo[X, Y](x: X, eq: EQ[Foo[X], Foo[Y]]): Y = eq match { + case Refl() => x + } + } + + object C { + type Foo[+X] = Int | (X, X) + def foo[X, Y](x: X, eq: EQ[Foo[X], Foo[Y]]): Y = eq match { + case Refl() => x + } + } + + object D { + type Foo[+X] = (Int, Int) + def foo[X, Y](x: X, eq: EQ[Foo[X], Foo[Y]]): Y = eq match { + case Refl() => x // error + } + } + + trait E { + type Foo[+X] <: Int | (X, X) + def foo[X, Y](x: X, eq: EQ[Foo[X], Foo[Y]]): Y = eq match { + case Refl() => x // error + } + } + + trait F { + type Foo[X] >: Int | (X, X) + def foo[X, Y](x: X, eq: EQ[Foo[X], Foo[Y]]): Y = eq match { + case Refl() => x // error + } + } +} diff --git a/tests/neg/gadt-no-approx.scala b/tests/neg/gadt-no-approx.scala new file mode 100644 index 000000000000..eef0d82cba21 --- /dev/null +++ b/tests/neg/gadt-no-approx.scala @@ -0,0 +1,10 @@ +object `gadt-no-approx` { + def fo[U](u: U): U = + (0 : Int) match { + case _: u.type => + val i: Int = (??? : U) // error + // potentially could compile + // val i2: Int = u + u + } +} diff --git a/tests/neg/gadt-uninjectivity.scala b/tests/neg/gadt-uninjectivity.scala index f1d0fc59000a..30ebac32b735 100644 --- a/tests/neg/gadt-uninjectivity.scala +++ b/tests/neg/gadt-uninjectivity.scala @@ -4,7 +4,7 @@ object uninjectivity { def absurd1[F[_], X, Y](eq: EQ[F[X], F[Y]], x: X): Y = eq match { case Refl() => - x // should be an error + x // error } def absurd2[F[_], G[_]](eq: EQ[F[Int], G[Int]], fi: F[Int], fs: F[String]): G[Int] = eq match { diff --git a/tests/neg/int-extractor.scala b/tests/neg/int-extractor.scala new file mode 100644 index 000000000000..8534c5a1bc00 --- /dev/null +++ b/tests/neg/int-extractor.scala @@ -0,0 +1,31 @@ +object Test { + object EssaInt { + def unapply(i: Int): Some[Int] = Some(i) + } + + def foo1[T](t: T): T = t match { + case EssaInt(_) => + 0 // error + } + + def foo2[T](t: T): T = t match { + case EssaInt(_) => t match { + case EssaInt(_) => + 0 // error + } + } + + case class Inv[T](t: T) + + def bar1[T](t: T): T = Inv(t) match { + case Inv(EssaInt(_)) => + 0 // error + } + + def bar2[T](t: T): T = t match { + case Inv(EssaInt(_)) => t match { + case Inv(EssaInt(_)) => + 0 // error + } + } +} diff --git a/tests/neg/invariant-gadt.scala b/tests/neg/invariant-gadt.scala new file mode 100644 index 000000000000..ac335f57743f --- /dev/null +++ b/tests/neg/invariant-gadt.scala @@ -0,0 +1,27 @@ +object `invariant-gadt` { + case class Invariant[T](value: T) + + def unsound0[T](t: T): T = Invariant(t) match { + case Invariant(_: Int) => + (0: Any) // error + } + + def unsound1[T](t: T): T = Invariant(t) match { + case Invariant(_: Int) => + 0 // error + } + + def unsound2[T](t: T): T = Invariant(t) match { + case Invariant(value) => value match { + case _: Int => + 0 // error + } + } + + def unsoundTwice[T](t: T): T = Invariant(t) match { + case Invariant(_: Int) => Invariant(t) match { + case Invariant(_: Int) => + 0 // error + } + } +} diff --git a/tests/neg/typeclass-derivation2.scala b/tests/neg/typeclass-derivation2.scala index 33c64494e9c5..ddb6517fb869 100644 --- a/tests/neg/typeclass-derivation2.scala +++ b/tests/neg/typeclass-derivation2.scala @@ -111,6 +111,13 @@ object TypeLevel { * It informs that type `T` has shape `S` and also implements runtime reflection on `T`. */ abstract class Shaped[T, S <: Shape] extends Reflected[T] + + // substitute for erasedValue that allows precise matching + final abstract class Type[-A, +B] + type Subtype[t] = Type[_, t] + type Supertype[t] = Type[t, _] + type Exactly[t] = Type[t, t] + erased def typeOf[T]: Type[T, T] = ??? } // An algebraic datatype @@ -203,7 +210,7 @@ trait Show[T] { def show(x: T): String } object Show { - import scala.compiletime.erasedValue + import scala.compiletime.{erasedValue, error} import TypeLevel._ inline def tryShow[T](x: T): String = implicit match { @@ -229,9 +236,14 @@ object Show { inline def showCases[T, Alts <: Tuple](r: Reflected[T], x: T): String = inline erasedValue[Alts] match { case _: (Shape.Case[alt, elems] *: alts1) => - x match { - case x: `alt` => showCase[T, elems](r, x) - case _ => showCases[T, alts1](r, x) + inline typeOf[alt] match { + case _: Subtype[T] => + x match { + case x: `alt` => showCase[T, elems](r, x) + case _ => showCases[T, alts1](r, x) + } + case _ => + error("invalid call to showCases: one of Alts is not a subtype of T") } case _: Unit => throw new MatchError(x) diff --git a/tests/pos/gadt-EQK.scala b/tests/pos/gadt-EQK.scala index b713de2d833b..0c1fbfe03f81 100644 --- a/tests/pos/gadt-EQK.scala +++ b/tests/pos/gadt-EQK.scala @@ -18,16 +18,4 @@ object EQK { fa : G[Int] } } - - def m2[F[_], G[_], A](fa: F[A], a: A, eq: EQ[F[A], G[Int]], eqk: EQK[F, G]): Int = - eqk match { - case ReflK() => eq match { - case Refl() => - val r1: F[Int] = fa - val r2: G[A] = fa - val r3: F[Int] = r2 - a - } - } - } diff --git a/tests/pos/gadt-accumulatable.scala b/tests/pos/gadt-accumulatable.scala new file mode 100644 index 000000000000..ce4cf347538d --- /dev/null +++ b/tests/pos/gadt-accumulatable.scala @@ -0,0 +1,37 @@ +object `gadt-accumulatable` { + sealed abstract class Or[+G,+B] extends Product with Serializable + final case class Good[+G](g: G) extends Or[G,Nothing] + final case class Bad[+B](b: B) extends Or[Nothing,B] + + sealed trait Validation[+E] extends Product with Serializable + case object Pass extends Validation[Nothing] + case class Fail[E](error: E) extends Validation[E] + + sealed abstract class Every[+T] protected (underlying: Vector[T]) extends /*PartialFunction[Int, T] with*/ Product with Serializable + final case class One[+T](loneElement: T) extends Every[T](Vector(loneElement)) + final case class Many[+T](firstElement: T, secondElement: T, otherElements: T*) extends Every[T](firstElement +: secondElement +: Vector(otherElements: _*)) + + class Accumulatable[G, ERR, EVERY[_]] { } + + def convertOrToAccumulatable[G, ERR, EVERY[b] <: Every[b]](accumulatable: G Or EVERY[ERR]): Accumulatable[G, ERR, EVERY] = { + new Accumulatable[G, ERR, EVERY] { + def when[OTHERERR >: ERR](validations: (G => Validation[OTHERERR])*): G Or Every[OTHERERR] = { + accumulatable match { + case Good(g) => + val results = validations flatMap (_(g) match { case Fail(x) => val z: OTHERERR = x; Seq(x); case Pass => Seq.empty}) + results.length match { + case 0 => Good(g) + case 1 => Bad(One(results.head)) + case _ => + val first = results.head + val tail = results.tail + val second = tail.head + val rest = tail.tail + Bad(Many(first, second, rest: _*)) + } + case Bad(myBad) => Bad(myBad) + } + } + } + } +} diff --git a/tests/pos/gadt-all-params.scala b/tests/pos/gadt-all-params.scala new file mode 100644 index 000000000000..b5d7baecc283 --- /dev/null +++ b/tests/pos/gadt-all-params.scala @@ -0,0 +1,9 @@ +object `gadt-all-params` { + enum Expr[T] { + case UnitLit extends Expr[Unit] + } + + def foo[T >: TT <: TT, TT](e: Expr[T]): T = e match { + case Expr.UnitLit => () + } +} diff --git a/tests/pos/gadt-inference.scala b/tests/pos/gadt-inference.scala new file mode 100644 index 000000000000..e625e4823dc0 --- /dev/null +++ b/tests/pos/gadt-inference.scala @@ -0,0 +1,44 @@ +object `gadt-inference` { + enum Expr[T] { + case StrLit(s: String) extends Expr[String] + case IntLit(i: Int) extends Expr[Int] + } + import Expr._ + + def eval[T](e: Expr[T]) = + e match { + case StrLit(s) => + val a = (??? : T) : String + s : T + case IntLit(i) => + val a = (??? : T) : Int + i : T + } + + def nested[T](o: Option[Expr[T]]) = + o match { + case Some(e) => e match { + case StrLit(s) => + val a = (??? : T) : String + s : T + case IntLit(i) => + val a = (??? : T) : Int + i : T + } + case None => ??? + } + + def local[T](e: Expr[T]) = { + def eval[T](e: Expr[T]) = + e match { + case StrLit(s) => + val a = (??? : T) : String + s : T + case IntLit(i) => + val a = (??? : T) : Int + i : T + } + + eval(e) : T + } +} diff --git a/tests/pos/precise-pattern-type.scala b/tests/pos/precise-pattern-type.scala new file mode 100644 index 000000000000..856672fafbf2 --- /dev/null +++ b/tests/pos/precise-pattern-type.scala @@ -0,0 +1,16 @@ +object `precise-pattern-type` { + class Type { + def isType: Boolean = true + } + + class Tree[-T >: Null] { + def tpe: T @annotation.unchecked.uncheckedVariance = ??? + } + + case class Select[-T >: Null](qual: Tree[T]) extends Tree[T] + + def test[T <: Tree[Type]](tree: T) = tree match { + case Select(q) => + q.tpe.isType + } +} diff --git a/tests/run-macros/tasty-extractors-3.check b/tests/run-macros/tasty-extractors-3.check index 2e3b9f23e983..35c88a7598f5 100644 --- a/tests/run-macros/tasty-extractors-3.check +++ b/tests/run-macros/tasty-extractors-3.check @@ -10,6 +10,8 @@ Type.SymRef(IsClassDefSymbol(), Type.ThisType(Type.SymRef(IsPackageDe Type.SymRef(IsTypeDefSymbol(), NoPrefix()) +Type.SymRef(IsTypeDefSymbol(), NoPrefix()) + TypeBounds(Type.SymRef(IsClassDefSymbol(), Type.SymRef(IsPackageDefSymbol(), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<>), NoPrefix())))), Type.SymRef(IsClassDefSymbol(), Type.SymRef(IsPackageDefSymbol(), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<>), NoPrefix()))))) Type.SymRef(IsClassDefSymbol(), Type.SymRef(IsPackageDefSymbol(), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<>), NoPrefix())))) diff --git a/tests/run/gadt-injectivity-unsoundness.scala b/tests/run/gadt-injectivity-unsoundness.scala deleted file mode 100644 index 192a82afb539..000000000000 --- a/tests/run/gadt-injectivity-unsoundness.scala +++ /dev/null @@ -1,19 +0,0 @@ -object Test { - sealed trait EQ[A, B] - final case class Refl[T]() extends EQ[T, T] - - def absurd[F[_], X, Y](eq: EQ[F[X], F[Y]], x: X): Y = eq match { - case Refl() => x - } - - var ex: Exception = _ - try { - type Unsoundness[X] = Int - val s: String = absurd[Unsoundness, Int, String](Refl(), 0) - } catch { - case e: ClassCastException => ex = e - } - - def main(args: Array[String]) = - assert(ex != null) -} diff --git a/tests/run/typeclass-derivation2.scala b/tests/run/typeclass-derivation2.scala index 8ac7cec4487c..f8812b461d48 100644 --- a/tests/run/typeclass-derivation2.scala +++ b/tests/run/typeclass-derivation2.scala @@ -113,6 +113,13 @@ object TypeLevel { * It informs that type `T` has shape `S` and also implements runtime reflection on `T`. */ abstract class Shaped[T, S <: Shape] extends Reflected[T] + + // substitute for erasedValue that allows precise matching + final abstract class Type[-A, +B] + type Subtype[t] = Type[_, t] + type Supertype[t] = Type[t, _] + type Exactly[t] = Type[t, t] + erased def typeOf[T]: Type[T, T] = ??? } // An algebraic datatype @@ -217,7 +224,7 @@ trait Eq[T] { } object Eq { - import scala.compiletime.erasedValue + import scala.compiletime.{erasedValue, error} import TypeLevel._ inline def tryEql[T](x: T, y: T) = implicit match { @@ -239,8 +246,13 @@ object Eq { inline def eqlCases[T, Alts <: Tuple](xm: Mirror, ym: Mirror, ordinal: Int, n: Int): Boolean = inline erasedValue[Alts] match { case _: (Shape.Case[alt, elems] *: alts1) => - if (n == ordinal) eqlElems[elems](xm, ym, 0) - else eqlCases[T, alts1](xm, ym, ordinal, n + 1) + inline typeOf[alt] match { + case _: Subtype[T] => + if (n == ordinal) eqlElems[elems](xm, ym, 0) + else eqlCases[T, alts1](xm, ym, ordinal, n + 1) + case _ => + error("invalid call to eqlCases: one of Alts is not a subtype of T") + } case _: Unit => false } @@ -271,7 +283,7 @@ trait Pickler[T] { } object Pickler { - import scala.compiletime.{erasedValue, constValue} + import scala.compiletime.{erasedValue, constValue, error} import TypeLevel._ def nextInt(buf: mutable.ListBuffer[Int]): Int = try buf.head finally buf.trimStart(1) @@ -294,12 +306,17 @@ object Pickler { inline def pickleCases[T, Alts <: Tuple](r: Reflected[T], buf: mutable.ListBuffer[Int], x: T, n: Int): Unit = inline erasedValue[Alts] match { case _: (Shape.Case[alt, elems] *: alts1) => - x match { - case x: `alt` => - buf += n - pickleCase[T, elems](r, buf, x) + inline typeOf[alt] match { + case _: Subtype[T] => + x match { + case x: `alt` => + buf += n + pickleCase[T, elems](r, buf, x) + case _ => + pickleCases[T, alts1](r, buf, x, n + 1) + } case _ => - pickleCases[T, alts1](r, buf, x, n + 1) + error("invalid pickleCases call: one of Alts is not a subtype of T") } case _: Unit => } @@ -362,7 +379,7 @@ trait Show[T] { def show(x: T): String } object Show { - import scala.compiletime.erasedValue + import scala.compiletime.{erasedValue, error} import TypeLevel._ inline def tryShow[T](x: T): String = implicit match { @@ -388,9 +405,15 @@ object Show { inline def showCases[T, Alts <: Tuple](r: Reflected[T], x: T): String = inline erasedValue[Alts] match { case _: (Shape.Case[alt, elems] *: alts1) => - x match { - case x: `alt` => showCase[T, elems](r, x) - case _ => showCases[T, alts1](r, x) + inline typeOf[alt] match { + case _: Subtype[T] => + x match { + case x: `alt` => + showCase[T, elems](r, x) + case _ => showCases[T, alts1](r, x) + } + case _ => + error("invalid call to showCases: one of Alts is not a subtype of T") } case _: Unit => throw new MatchError(x)