Skip to content

Commit b537220

Browse files
authored
Merge pull request #6050 from dotty-staging/fix-#6047
Fix #6047: Implement variance rules for match types
2 parents a6bab2c + 33d3622 commit b537220

File tree

19 files changed

+234
-97
lines changed

19 files changed

+234
-97
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,9 @@ class Definitions {
761761

762762
lazy val TypeBox_CAP: TypeSymbol = TypeBoxType.symbol.requiredType(tpnme.CAP)
763763

764+
lazy val MatchCaseType: TypeRef = ctx.requiredClassRef("scala.internal.MatchCase")
765+
def MatchCaseClass(implicit ctx: Context): ClassSymbol = MatchCaseType.symbol.asClass
766+
764767
lazy val NotType: TypeRef = ctx.requiredClassRef("scala.implicits.Not")
765768
def NotClass(implicit ctx: Context): ClassSymbol = NotType.symbol.asClass
766769
def NotModule(implicit ctx: Context): Symbol = NotClass.companionModule
@@ -933,6 +936,23 @@ class Definitions {
933936
}
934937
}
935938

939+
object MatchCase {
940+
def apply(pat: Type, body: Type)(implicit ctx: Context): Type =
941+
MatchCaseType.appliedTo(pat, body)
942+
def unapply(tp: Type)(implicit ctx: Context): Option[(Type, Type)] = tp match {
943+
case AppliedType(tycon, pat :: body :: Nil) if tycon.isRef(MatchCaseClass) =>
944+
Some((pat, body))
945+
case _ =>
946+
None
947+
}
948+
def isInstance(tp: Type)(implicit ctx: Context): Boolean = tp match {
949+
case AppliedType(tycon: TypeRef, _) =>
950+
tycon.name == tpnme.MatchCase && // necessary pre-filter to avoid forcing symbols
951+
tycon.isRef(MatchCaseClass)
952+
case _ => false
953+
}
954+
}
955+
936956
/** An extractor for multi-dimensional arrays.
937957
* Note that this will also extract the high bound if an
938958
* element type is a wildcard. E.g.

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ object StdNames {
337337
val Literal: N = "Literal"
338338
val LiteralAnnotArg: N = "LiteralAnnotArg"
339339
val longHash: N = "longHash"
340+
val MatchCase: N = "MatchCase"
340341
val Modifiers: N = "Modifiers"
341342
val NestedAnnotArg: N = "NestedAnnotArg"
342343
val NoFlags: N = "NoFlags"

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2134,45 +2134,46 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
21342134
}
21352135
}
21362136

2137-
var result: Type = NoType
2138-
var remainingCases = cases
2139-
while (!remainingCases.isEmpty) {
2140-
val (cas :: cass) = remainingCases
2141-
remainingCases = cass
2142-
val saved = constraint
2143-
try {
2144-
inFrozenConstraint {
2145-
val cas1 = cas match {
2146-
case cas: HKTypeLambda =>
2147-
caseLambda = constrained(cas)
2148-
caseLambda.resultType
2137+
/** Match a single case.
2138+
* @return Some(tp) if the match succeeds with type `tp`
2139+
* Some(NoType) if the match fails, and there is an overlap between pattern and scrutinee
2140+
* None if the match fails and we should consider the following cases
2141+
* because scrutinee and pattern do not overlap
2142+
*/
2143+
def matchCase(cas: Type): Option[Type] = {
2144+
val cas1 = cas match {
2145+
case cas: HKTypeLambda =>
2146+
caseLambda = constrained(cas)
2147+
caseLambda.resultType
2148+
case _ =>
2149+
cas
2150+
}
2151+
val defn.MatchCase(pat, body) = cas1
2152+
if (isSubType(scrut, pat))
2153+
// `scrut` is a subtype of `pat`: *It's a Match!*
2154+
Some {
2155+
caseLambda match {
2156+
case caseLambda: HKTypeLambda =>
2157+
val instances = paramInstances(new Array(caseLambda.paramNames.length), pat)
2158+
instantiateParams(instances)(body)
21492159
case _ =>
2150-
cas
2151-
}
2152-
val defn.FunctionOf(pat :: Nil, body, _, _) = cas1
2153-
if (isSubType(scrut, pat)) {
2154-
// `scrut` is a subtype of `pat`: *It's a Match!*
2155-
result = caseLambda match {
2156-
case caseLambda: HKTypeLambda =>
2157-
val instances = paramInstances(new Array(caseLambda.paramNames.length), pat)
2158-
instantiateParams(instances)(body)
2159-
case _ =>
2160-
body
2161-
}
2162-
remainingCases = Nil
2163-
} else if (!intersecting(scrut, pat)) {
2164-
// We found a proof that `scrut` and `pat` are incompatible.
2165-
// The search continues.
2166-
} else {
2167-
// We are stuck: this match type instanciation is irreducible.
2168-
result = NoType
2169-
remainingCases = Nil
2160+
body
21702161
}
21712162
}
2172-
}
2173-
finally constraint = saved
2163+
else if (intersecting(scrut, pat))
2164+
Some(NoType)
2165+
else
2166+
// We found a proof that `scrut` and `pat` are incompatible.
2167+
// The search continues.
2168+
None
21742169
}
2175-
result
2170+
2171+
def recur(cases: List[Type]): Type = cases match {
2172+
case cas :: cases1 => matchCase(cas).getOrElse(recur(cases1))
2173+
case Nil => NoType
2174+
}
2175+
2176+
inFrozenConstraint(recur(cases))
21762177
}
21772178
}
21782179

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3756,7 +3756,7 @@ object Types {
37563756

37573757
def caseType(tp: Type)(implicit ctx: Context): Type = tp match {
37583758
case tp: HKTypeLambda => caseType(tp.resType)
3759-
case defn.FunctionOf(_, restpe, _, _) => restpe
3759+
case defn.MatchCase(_, body) => body
37603760
}
37613761

37623762
def alternatives(implicit ctx: Context): List[Type] = cases.map(caseType)
@@ -4417,10 +4417,12 @@ object Types {
44174417

44184418
case tp: LambdaType =>
44194419
def mapOverLambda = {
4420-
variance = -variance
4420+
val restpe = tp.resultType
4421+
val saved = variance
4422+
variance = if (defn.MatchCase.isInstance(restpe)) 0 else -variance
44214423
val ptypes1 = tp.paramInfos.mapConserve(this).asInstanceOf[List[tp.PInfo]]
4422-
variance = -variance
4423-
derivedLambdaType(tp)(ptypes1, this(tp.resultType))
4424+
variance = saved
4425+
derivedLambdaType(tp)(ptypes1, this(restpe))
44244426
}
44254427
mapOverLambda
44264428

@@ -4440,7 +4442,9 @@ object Types {
44404442
derivedOrType(tp, this(tp.tp1), this(tp.tp2))
44414443

44424444
case tp: MatchType =>
4443-
derivedMatchType(tp, this(tp.bound), this(tp.scrutinee), tp.cases.mapConserve(this))
4445+
val bound1 = this(tp.bound)
4446+
val scrut1 = atVariance(0)(this(tp.scrutinee))
4447+
derivedMatchType(tp, bound1, scrut1, tp.cases.mapConserve(this))
44444448

44454449
case tp: SkolemType =>
44464450
derivedSkolemType(tp, this(tp.info))
@@ -4804,10 +4808,12 @@ object Types {
48044808
case _: BoundType | _: ThisType => x
48054809

48064810
case tp: LambdaType =>
4807-
variance = -variance
4811+
val restpe = tp.resultType
4812+
val saved = variance
4813+
variance = if (defn.MatchCase.isInstance(restpe)) 0 else -variance
48084814
val y = foldOver(x, tp.paramInfos)
4809-
variance = -variance
4810-
this(y, tp.resultType)
4815+
variance = saved
4816+
this(y, restpe)
48114817

48124818
case tp: TermRef =>
48134819
if (stopAtStatic && tp.currentSymbol.isStatic || (tp.prefix `eq` NoPrefix)) x
@@ -4835,7 +4841,9 @@ object Types {
48354841
this(this(x, tp.tp1), tp.tp2)
48364842

48374843
case tp: MatchType =>
4838-
foldOver(this(this(x, tp.bound), tp.scrutinee), tp.cases)
4844+
val x1 = this(x, tp.bound)
4845+
val x2 = atVariance(0)(this(x1, tp.scrutinee))
4846+
foldOver(x2, tp.cases)
48394847

48404848
case AnnotatedType(underlying, annot) =>
48414849
this(applyToAnnot(x, annot), underlying)

compiler/src/dotty/tools/dotc/printing/Formatting.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ object Formatting {
2828
case arg: Showable =>
2929
try arg.show
3030
catch {
31+
case ex: CyclicReference => "... (caught cyclic reference) ..."
3132
case NonFatal(ex)
3233
if !ctx.mode.is(Mode.PrintShowExceptions) &&
3334
!ctx.settings.YshowPrintErrors.value =>

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,10 @@ class PlainPrinter(_ctx: Context) extends Printer {
166166
changePrec(OrTypePrec) { toText(tp1) ~ " | " ~ atPrec(OrTypePrec + 1) { toText(tp2) } }
167167
case MatchType(bound, scrutinee, cases) =>
168168
changePrec(GlobalPrec) {
169-
def caseText(tp: Type): Text = "case " ~ toText(tp)
169+
def caseText(tp: Type): Text = tp match {
170+
case defn.MatchCase(pat, body) => "case " ~ toText(pat) ~ " => " ~ toText(body)
171+
case _ => "case " ~ toText(tp)
172+
}
170173
def casesText = Text(cases.map(caseText), "\n")
171174
atPrec(InfixPrec) { toText(scrutinee) } ~
172175
keywordStr(" match ") ~ "{" ~ casesText ~ "}" ~

compiler/src/dotty/tools/dotc/reporting/Reporter.scala

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import core.Mode
1313
import dotty.tools.dotc.core.Symbols.{Symbol, NoSymbol}
1414
import diagnostic.messages._
1515
import diagnostic._
16+
import ast.{tpd, Trees}
1617
import Message._
1718

1819
object Reporter {
@@ -89,21 +90,25 @@ trait Reporting { this: Context =>
8990
}
9091

9192
def warning(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit =
92-
reportWarning(new Warning(msg, pos))
93+
reportWarning(new Warning(msg, addInlineds(pos)))
9394

94-
def strictWarning(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit =
95-
if (this.settings.strict.value) error(msg, pos)
96-
else reportWarning(new ExtendMessage(() => msg)(_ + "\n(This would be an error under strict mode)").warning(pos))
95+
def strictWarning(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit = {
96+
val fullPos = addInlineds(pos)
97+
if (this.settings.strict.value) error(msg, fullPos)
98+
else reportWarning(new ExtendMessage(() => msg)(_ + "\n(This would be an error under strict mode)").warning(fullPos))
99+
}
97100

98101
def error(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit =
99-
reporter.report(new Error(msg, pos))
102+
reporter.report(new Error(msg, addInlineds(pos)))
100103

101-
def errorOrMigrationWarning(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit =
102-
if (ctx.scala2Mode) migrationWarning(msg, pos) else error(msg, pos)
104+
def errorOrMigrationWarning(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit = {
105+
val fullPos = addInlineds(pos)
106+
if (ctx.scala2Mode) migrationWarning(msg, fullPos) else error(msg, fullPos)
107+
}
103108

104109
def restrictionError(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit =
105110
reporter.report {
106-
new ExtendMessage(() => msg)(m => s"Implementation restriction: $m").error(pos)
111+
new ExtendMessage(() => msg)(m => s"Implementation restriction: $m").error(addInlineds(pos))
107112
}
108113

109114
def incompleteInputError(msg: => Message, pos: SourcePosition = NoSourcePosition)(implicit ctx: Context): Unit =
@@ -135,6 +140,14 @@ trait Reporting { this: Context =>
135140

136141
def debugwarn(msg: => String, pos: SourcePosition = NoSourcePosition): Unit =
137142
if (this.settings.Ydebug.value) warning(msg, pos)
143+
144+
private def addInlineds(pos: SourcePosition)(implicit ctx: Context) = {
145+
def recur(pos: SourcePosition, inlineds: List[Trees.Tree[_]]): SourcePosition = inlineds match {
146+
case inlined :: inlineds1 => pos.withOuter(recur(inlined.sourcePos, inlineds1))
147+
case Nil => pos
148+
}
149+
recur(pos, tpd.enclosingInlineds)
150+
}
138151
}
139152

140153
/**

compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ trait TypeAssigner {
487487
}
488488
HKTypeLambda.fromParams(
489489
params(new mutable.ListBuffer[TypeSymbol](), pat).toList,
490-
defn.FunctionOf(pat.tpe :: Nil, body.tpe))
490+
defn.MatchCase(pat.tpe, body.tpe))
491491
}
492492
else body.tpe
493493
tree.withType(ownType)

compiler/src/dotty/tools/dotc/typer/VarianceChecker.scala

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ object VarianceChecker {
3434
val paramVarianceStr = if (v == 0) "contra" else "co"
3535
val occursStr = variance match {
3636
case -1 => "contra"
37-
case 0 => "non"
37+
case 0 => "in"
3838
case 1 => "co"
3939
}
4040
val pos = tree.tparams
@@ -123,18 +123,19 @@ class VarianceChecker()(implicit ctx: Context) {
123123
def apply(status: Option[VarianceError], tp: Type): Option[VarianceError] = trace(s"variance checking $tp of $base at $variance", variances) {
124124
try
125125
if (status.isDefined) status
126-
else tp match {
126+
else tp.normalized match {
127127
case tp: TypeRef =>
128128
val sym = tp.symbol
129129
if (sym.variance != 0 && base.isContainedIn(sym.owner)) checkVarianceOfSymbol(sym)
130-
else if (sym.isAliasType) this(status, sym.info.bounds.hi)
131-
else foldOver(status, tp)
130+
else sym.info match {
131+
case MatchAlias(_) => foldOver(status, tp)
132+
case TypeAlias(alias) => this(status, alias)
133+
case _ => foldOver(status, tp)
134+
}
132135
case tp: MethodOrPoly =>
133136
this(status, tp.resultType) // params will be checked in their TypeDef or ValDef nodes.
134137
case AnnotatedType(_, annot) if annot.symbol == defn.UncheckedVarianceAnnot =>
135138
status
136-
case tp: MatchType =>
137-
apply(status, tp.bound)
138139
case tp: ClassInfo =>
139140
foldOver(status, tp.classParents)
140141
case _ =>
@@ -179,7 +180,7 @@ class VarianceChecker()(implicit ctx: Context) {
179180
sym.is(PrivateLocal) ||
180181
sym.name.is(InlineAccessorName) || // TODO: should we exclude all synthetic members?
181182
sym.is(TypeParam) && sym.owner.isClass // already taken care of in primary constructor of class
182-
tree match {
183+
try tree match {
183184
case defn: MemberDef if skip =>
184185
ctx.debuglog(s"Skipping variance check of ${sym.showDcl}")
185186
case tree: TypeDef =>
@@ -196,6 +197,9 @@ class VarianceChecker()(implicit ctx: Context) {
196197
vparamss foreach (_ foreach traverse)
197198
case _ =>
198199
}
200+
catch {
201+
case ex: TypeError => ctx.error(ex.toMessage, tree.sourcePos.focus)
202+
}
199203
}
200204
}
201205
}

compiler/test-resources/repl/i5218

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
scala> val tuple = (1, "2", 3L)
22
val tuple: (Int, String, Long) = (1,2,3)
33
scala> 0.0 *: tuple
4-
val res0: Double *: (Int, String, Long)(tuple) = (0.0,1,2,3)
4+
val res0: (Double, Int, String, Long) = (0.0,1,2,3)
55
scala> tuple ++ tuple
66
val res1: Int *: String *: Long *:
77
scala.Tuple.Concat[Unit, (Int, String, Long)(tuple)] = (1,2,3,1,2,3)

compiler/test/dotc/run-test-pickling.blacklist

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ t3452g
55
t7374
66
tuples1.scala
77
tuples1a.scala
8+
typeclass-derivation1.scala
89
typeclass-derivation2.scala
910
typeclass-derivation2a.scala
1011
typeclass-derivation3.scala

0 commit comments

Comments
 (0)