Skip to content

Commit 1440699

Browse files
authored
Merge pull request #12317 from dotty-staging/fix-12306-v2
Treat Refinements more like AndTypes
2 parents e3e8154 + fbf8949 commit 1440699

File tree

2 files changed

+97
-57
lines changed

2 files changed

+97
-57
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 74 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
169169
private inline def inFrozenGadtAndConstraint[T](inline op: T): T =
170170
inFrozenGadtIf(true)(inFrozenConstraint(op))
171171

172+
extension (sym: Symbol)
173+
private inline def onGadtBounds(inline op: TypeBounds => Boolean): Boolean =
174+
val bounds = gadtBounds(sym)
175+
bounds != null && op(bounds)
176+
172177
protected def isSubType(tp1: Type, tp2: Type, a: ApproxState): Boolean = {
173178
val savedApprox = approx
174179
val savedLeftRoot = leftRoot
@@ -465,19 +470,15 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
465470
case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
466471
case _ => true
467472

468-
// An & on the left side loses information. We compensate by also trying the join.
469-
// This is less ad-hoc than it looks since we produce joins in type inference,
470-
// and then need to check that they are indeed supertypes of the original types
471-
// under -Ycheck. Test case is i7965.scala.
472-
def containsAnd(tp: Type): Boolean = tp.dealiasKeepRefiningAnnots match
473-
case tp: AndType => true
474-
case OrType(tp1, tp2) => containsAnd(tp1) || containsAnd(tp2)
475-
case _ => false
476-
477473
widenOK
478474
|| joinOK
479475
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
480476
|| containsAnd(tp1) && inFrozenGadt(recur(tp1.join, tp2))
477+
// An & on the left side loses information. We compensate by also trying the join.
478+
// This is less ad-hoc than it looks since we produce joins in type inference,
479+
// and then need to check that they are indeed supertypes of the original types
480+
// under -Ycheck. Test case is i7965.scala.
481+
481482
case tp1: MatchType =>
482483
val reduced = tp1.reduced
483484
if (reduced.exists) recur(reduced, tp2) else thirdTry
@@ -489,11 +490,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
489490

490491
def thirdTryNamed(tp2: NamedType): Boolean = tp2.info match {
491492
case info2: TypeBounds =>
492-
def compareGADT: Boolean = {
493-
val gbounds2 = gadtBounds(tp2.symbol)
494-
(gbounds2 != null) &&
495-
(isSubTypeWhenFrozen(tp1, gbounds2.lo) ||
496-
(tp1 match {
493+
def compareGADT: Boolean =
494+
tp2.symbol.onGadtBounds(gbounds2 =>
495+
isSubTypeWhenFrozen(tp1, gbounds2.lo)
496+
|| tp1.match
497497
case tp1: NamedType if ctx.gadt.contains(tp1.symbol) =>
498498
// Note: since we approximate constrained types only with their non-param bounds,
499499
// we need to manually handle the case when we're comparing two constrained types,
@@ -502,10 +502,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
502502
// comparing two constrained types, and that case will be handled here first.
503503
ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol)
504504
case _ => false
505-
}) ||
506-
narrowGADTBounds(tp2, tp1, approx, isUpper = false)) &&
507-
{ isBottom(tp1) || GADTusage(tp2.symbol) }
508-
}
505+
|| narrowGADTBounds(tp2, tp1, approx, isUpper = false))
506+
&& (isBottom(tp1) || GADTusage(tp2.symbol))
507+
509508
isSubApproxHi(tp1, info2.lo) || compareGADT || tryLiftedToThis2 || fourthTry
510509

511510
case _ =>
@@ -559,31 +558,35 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
559558
case tp2: TypeParamRef =>
560559
compareTypeParamRef(tp2)
561560
case tp2: RefinedType =>
562-
def compareRefinedSlow: Boolean = {
561+
def compareRefinedSlow: Boolean =
563562
val name2 = tp2.refinedName
564-
recur(tp1, tp2.parent) &&
565-
(name2 == nme.WILDCARD || hasMatchingMember(name2, tp1, tp2))
566-
}
567-
def compareRefined: Boolean = {
563+
recur(tp1, tp2.parent)
564+
&& (name2 == nme.WILDCARD || hasMatchingMember(name2, tp1, tp2))
565+
566+
def compareRefined: Boolean =
568567
val tp1w = tp1.widen
569568
val skipped2 = skipMatching(tp1w, tp2)
570-
if ((skipped2 eq tp2) || !Config.fastPathForRefinedSubtype)
571-
tp1 match {
572-
case tp1: AndType =>
573-
// Delay calling `compareRefinedSlow` because looking up a member
574-
// of an `AndType` can lead to a cascade of subtyping checks
575-
// This twist is needed to make collection/generic/ParFactory.scala compile
576-
fourthTry || compareRefinedSlow
577-
case tp1: HKTypeLambda =>
578-
// HKTypeLambdas do not have members.
579-
fourthTry
580-
case _ =>
581-
compareRefinedSlow || fourthTry
582-
}
569+
if (skipped2 eq tp2) || !Config.fastPathForRefinedSubtype then
570+
if containsAnd(tp1) then
571+
tp2.parent match
572+
case _: RefinedType | _: AndType =>
573+
// maximally decompose RHS to limit the bad effects of the `either` that is necessary
574+
// since LHS contains an AndType
575+
recur(tp1, decomposeRefinements(tp2, Nil))
576+
case _ =>
577+
// Delay calling `compareRefinedSlow` because looking up a member
578+
// of an `AndType` can lead to a cascade of subtyping checks
579+
// This twist is needed to make collection/generic/ParFactory.scala compile
580+
fourthTry || compareRefinedSlow
581+
else if tp1.isInstanceOf[HKTypeLambda] then
582+
// HKTypeLambdas do not have members.
583+
fourthTry
584+
else
585+
compareRefinedSlow || fourthTry
583586
else // fast path, in particular for refinements resulting from parameterization.
584587
isSubRefinements(tp1w.asInstanceOf[RefinedType], tp2, skipped2) &&
585588
recur(tp1, skipped2)
586-
}
589+
587590
compareRefined
588591
case tp2: RecType =>
589592
def compareRec = tp1.safeDealias match {
@@ -751,13 +754,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
751754
case tp1: TypeRef =>
752755
tp1.info match {
753756
case TypeBounds(_, hi1) =>
754-
def compareGADT = {
755-
val gbounds1 = gadtBounds(tp1.symbol)
756-
(gbounds1 != null) &&
757-
(isSubTypeWhenFrozen(gbounds1.hi, tp2) ||
758-
narrowGADTBounds(tp1, tp2, approx, isUpper = true)) &&
759-
{ tp2.isAny || GADTusage(tp1.symbol) }
760-
}
757+
def compareGADT =
758+
tp1.symbol.onGadtBounds(gbounds1 =>
759+
isSubTypeWhenFrozen(gbounds1.hi, tp2)
760+
|| narrowGADTBounds(tp1, tp2, approx, isUpper = true))
761+
&& (tp2.isAny || GADTusage(tp1.symbol))
762+
761763
isSubType(hi1, tp2, approx.addLow) || compareGADT || tryLiftedToThis1
762764
case _ =>
763765
def isNullable(tp: Type): Boolean = tp.widenDealias match {
@@ -1033,17 +1035,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
10331035

10341036
var touchedGADTs = false
10351037
var gadtIsInstantiated = false
1036-
def byGadtBounds(sym: Symbol, tp: Type, fromAbove: Boolean): Boolean = {
1037-
touchedGADTs = true
1038-
val b = gadtBounds(sym)
1039-
def boundsDescr = if b == null then "null" else b.show
1040-
b != null && inFrozenGadt {
1041-
if fromAbove then isSubType(b.hi, tp) else isSubType(tp, b.lo)
1042-
} && {
1043-
gadtIsInstantiated = b.isInstanceOf[TypeAlias]
1044-
true
1045-
}
1046-
}
1038+
1039+
extension (sym: Symbol)
1040+
inline def byGadtBounds(inline op: TypeBounds => Boolean): Boolean =
1041+
touchedGADTs = true
1042+
sym.onGadtBounds(
1043+
b => op(b) && { gadtIsInstantiated = b.isInstanceOf[TypeAlias]; true })
10471044

10481045
def byGadtOrdering: Boolean =
10491046
ctx.gadt.contains(tycon1sym)
@@ -1052,8 +1049,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
10521049

10531050
val res = (
10541051
tycon1sym == tycon2sym && isSubPrefix(tycon1.prefix, tycon2.prefix)
1055-
|| byGadtBounds(tycon1sym, tycon2, fromAbove = true)
1056-
|| byGadtBounds(tycon2sym, tycon1, fromAbove = false)
1052+
|| tycon1sym.byGadtBounds(b => isSubTypeWhenFrozen(b.hi, tycon2))
1053+
|| tycon2sym.byGadtBounds(b => isSubTypeWhenFrozen(tycon1, b.lo))
10571054
|| byGadtOrdering
10581055
) && {
10591056
// There are two cases in which we can assume injectivity.
@@ -1691,6 +1688,26 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
16911688
else op2
16921689
end necessaryEither
16931690

1691+
/** Decompose into conjunction of types each of which has only a single refinement */
1692+
def decomposeRefinements(tp: Type, refines: List[(Name, Type)]): Type = tp match
1693+
case RefinedType(parent, rname, rinfo) =>
1694+
decomposeRefinements(parent, (rname, rinfo) :: refines)
1695+
case AndType(tp1, tp2) =>
1696+
AndType(decomposeRefinements(tp1, refines), decomposeRefinements(tp2, refines))
1697+
case _ =>
1698+
refines.map(RefinedType(tp, _, _): Type).reduce(AndType(_, _))
1699+
1700+
/** Can comparing this type on the left lead to an either? This is the case if
1701+
* the type is and AndType or contains embedded occurrences of AndTypes
1702+
*/
1703+
def containsAnd(tp: Type): Boolean = tp match
1704+
case tp: AndType => true
1705+
case OrType(tp1, tp2) => containsAnd(tp1) || containsAnd(tp2)
1706+
case tp: TypeParamRef => containsAnd(bounds(tp).hi)
1707+
case tp: TypeRef => containsAnd(tp.info.hiBound) || tp.symbol.onGadtBounds(gbounds => containsAnd(gbounds.hi))
1708+
case tp: TypeProxy => containsAnd(tp.superType)
1709+
case _ => false
1710+
16941711
/** Does type `tp1` have a member with name `name` whose normalized type is a subtype of
16951712
* the normalized type of the refinement `tp2`?
16961713
* Normalization is as follows: If `tp2` contains a skolem to its refinement type,

tests/pos/i12306.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
class Record(elems: Map[String, Any]) extends Selectable:
2+
val fields = elems.toMap
3+
def selectDynamic(name: String): Any = fields(name)
4+
object Record:
5+
def apply(elems: Map[String, Any]): Record = new Record(elems)
6+
extension [A <: Record] (a:A) {
7+
def join[B <: Record] (b:B): A & B = {
8+
Record(a.fields ++ b.fields).asInstanceOf[A & B]
9+
}
10+
}
11+
12+
type Person = Record { val name: String; val age: Int }
13+
type Child = Record { val parent: String }
14+
type PersonAndChild = Record { val name: String; val age: Int; val parent: String }
15+
16+
@main def hello = {
17+
val person = Record(Map("name" -> "Emma", "age" -> 42)).asInstanceOf[Person]
18+
val child = Record(Map("parent" -> "Alice")).asInstanceOf[Child]
19+
val personAndChild = person.join(child)
20+
21+
val v1: PersonAndChild = personAndChild
22+
val v2: PersonAndChild = person.join(child)
23+
}

0 commit comments

Comments
 (0)