Skip to content

Commit 592f664

Browse files
committed
Ban classes that incompatibly refine type params
In upickle there was a misuse of Any in a contravariant position.
1 parent e560c2d commit 592f664

16 files changed

+287
-63
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+2
Original file line numberDiff line numberDiff line change
@@ -1765,7 +1765,9 @@ class Definitions {
17651765
Set[Symbol](ComparableClass, ProductClass, SerializableClass,
17661766
// add these for now, until we had a chance to retrofit 2.13 stdlib
17671767
// we should do a more through sweep through it then.
1768+
requiredClass("scala.collection.IterableFactoryDefaults"),
17681769
requiredClass("scala.collection.SortedOps"),
1770+
requiredClass("scala.collection.StrictOptimizedSetOps"),
17691771
requiredClass("scala.collection.StrictOptimizedSortedSetOps"),
17701772
requiredClass("scala.collection.generic.DefaultSerializable"),
17711773
requiredClass("scala.collection.generic.IsIterable"),

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

+22-19
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,9 @@ trait PatternTypeConstrainer { self: TypeComparer =>
209209
* are used to infer type arguments to Unapply trees.
210210
*
211211
* ## Invariant refinement
212-
* Essentially, we say that `D[B] extends C[B]` s.t. refines parameter `A` of `trait C[A]` invariantly if
213-
* when `c: C[T]` and `c` is instance of `D`, then necessarily `c: D[T]`. This is violated if `A` is variant:
212+
* Essentially, we say that `D[B] extends C[B]` refines parameter `A` of `trait C[A]` invariantly if
213+
* when `c: C[T]` and `c` is instance of `D`, then necessarily `c: D[T]`.
214+
* This is violated if `A` is variant and `C` is mixed in with an incompatible type argument:
214215
*
215216
* trait C[+A]
216217
* trait D[+B](val b: B) extends C[B]
@@ -224,29 +225,30 @@ trait PatternTypeConstrainer { self: TypeComparer =>
224225
* }
225226
*
226227
* It'd be unsound for us to say that `t <: T`, even though that follows from `D[t] <: C[T]`.
227-
* Note, however, that if `D` was a final class, we *could* rely on that relationship.
228-
* To support typical case classes, we also assume that this relationship holds for them and their parent traits.
229-
* This is enforced by checking that classes inheriting from case classes do not extend the parent traits of those
230-
* case classes without also appropriately extending the relevant case class
231-
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
228+
* Note, however, that if `D` was a concrete class, we can rely on that relationship.
229+
* We can assume this relationship holds for them and their parent traits
230+
* by checking that classes inheriting from those classes do not mix-in any parent traits
231+
* with a type parameter that isn't the same type, a subtype, or a super type, depending on if the
232+
* trait's parameter is invariant, covariant or contravariant, respectively
233+
* (see `RefChecks#checkClassInheritanceInvariant`).
232234
*/
233235
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, forceInvariantRefinement: Boolean): Boolean = {
234236
def refinementIsInvariant(tp: Type): Boolean = tp match {
235237
case tp: SingletonType => true
236-
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
238+
case tp: ClassInfo => tp.cls.is(Final)
237239
case tp: TypeProxy => refinementIsInvariant(tp.superType)
238240
case _ => false
239241
}
242+
def refinementIsInvariant2(tp: Type): Boolean = tp match
243+
case tp: SingletonType => true
244+
case tp: ClassInfo => !tp.cls.isOneOf(AbstractOrTrait) || tp.cls.isOneOf(Private | Sealed)
245+
case tp: TypeProxy => refinementIsInvariant2(tp.superType)
246+
case _ => false
240247

241-
def widenVariantParams(tp: Type) = tp match {
242-
case tp @ AppliedType(tycon, args) =>
243-
val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) =>
244-
if (tparam.paramVarianceSign != 0) TypeBounds.empty else arg
245-
)
246-
tp.derivedAppliedType(tycon, args1)
247-
case tp =>
248-
tp
249-
}
248+
extension (tp: Type) def isAbstract: Boolean = tp.stripped match
249+
case _: TypeParamRef => true
250+
case tp: TypeRef => !tp.symbol.isClass
251+
case _ => false
250252

251253
val patternCls = patternTp.classSymbol
252254
val scrutineeCls = scrutineeTp.classSymbol
@@ -269,10 +271,11 @@ trait PatternTypeConstrainer { self: TypeComparer =>
269271
val result =
270272
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
271273
val variance = param.paramVarianceSign
272-
if variance == 0 || assumeInvariantRefinement ||
274+
if variance == 0 || assumeInvariantRefinement
275+
|| refinementIsInvariant2(patternTp) && (argP.isAbstract || patternTp.argInfos.contains(argP))
273276
// As a special case, when pattern and scrutinee types have the same type constructor,
274277
// we infer better bounds for pattern-bound abstract types.
275-
argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol
278+
|| argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol
276279
then
277280
val TypeBounds(loS, hiS) = argS.bounds
278281
val TypeBounds(loP, hiP) = argP.bounds

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

+6-3
Original file line numberDiff line numberDiff line change
@@ -2074,9 +2074,9 @@ object SymDenotations {
20742074
required: FlagSet = EmptyFlags, excluded: FlagSet = EmptyFlags)(using Context): Denotation =
20752075
membersNamedNoShadowingBasedOnFlags(name, required, excluded).asSeenFrom(pre).toDenot(pre)
20762076

2077-
/** Compute tp.baseType(this) */
2078-
final def baseTypeOf(tp: Type)(using Context): Type = {
2079-
val btrCache = baseTypeCache
2077+
/** Compute tp.baseType(this) or tp.baseType(this, without) */
2078+
final def baseTypeOf(tp: Type, without: Option[Symbol] = None)(using Context): Type = {
2079+
val btrCache = if without.isEmpty then baseTypeCache else new BaseTypeMap()
20802080
def inCache(tp: Type) = tp match
20812081
case tp: CachedType => btrCache.contains(tp)
20822082
case _ => false
@@ -2130,6 +2130,8 @@ object SymDenotations {
21302130
val baseTp =
21312131
if (tpSym eq symbol)
21322132
tp
2133+
else if without.exists(tpSym eq _) then
2134+
defn.AnyType
21332135
else if (isOwnThis)
21342136
if (clsd.baseClassSet.contains(symbol))
21352137
if (symbol.isStatic && symbol.typeParams.isEmpty) symbol.typeRef
@@ -2156,6 +2158,7 @@ object SymDenotations {
21562158
btrCache(tp) = NoPrefix
21572159
val baseTp =
21582160
if (tycon.typeSymbol eq symbol) tp
2161+
else if without.exists(tycon.typeSymbol eq _) then defn.AnyType
21592162
else (tycon.typeParams: @unchecked) match {
21602163
case LambdaParam(_, _) :: _ =>
21612164
recur(tp.superType)

compiler/src/dotty/tools/dotc/core/Types.scala

+6
Original file line numberDiff line numberDiff line change
@@ -1129,6 +1129,12 @@ object Types {
11291129
}
11301130
}
11311131

1132+
/** `basetype`, but ignoring any base classes that have the given `without` class symbol. */
1133+
final def baseTypeWithout(base: Symbol, without: Symbol)(using Context): Type =
1134+
base.denot match
1135+
case classd: ClassDenotation => classd.baseTypeOf(this, Some(without))
1136+
case _ => NoType
1137+
11321138
def & (that: Type)(using Context): Type = {
11331139
record("&")
11341140
TypeComparer.glb(this, that)

compiler/src/dotty/tools/dotc/typer/RefChecks.scala

+26-12
Original file line numberDiff line numberDiff line change
@@ -770,16 +770,18 @@ object RefChecks {
770770
}
771771
}
772772

773-
/** Check that inheriting a case class does not constitute a variant refinement
774-
* of a base type of the case class. It is because of this restriction that we
775-
* can assume invariant refinement for case classes in `constrainPatternType`.
773+
/** Check that inheriting a class does not constitute a variant refinement
774+
* of a base type of the class. It is because of this restriction that we
775+
* can assume invariant refinement for concrete classes in `constrainPatternType`.
776776
*/
777-
def checkCaseClassInheritanceInvariant() =
778-
for (caseCls <- clazz.info.baseClasses.tail.find(_.is(Case)))
779-
for (baseCls <- caseCls.info.baseClasses.tail)
777+
def checkClassInheritanceInvariant() =
778+
for (middle <- clazz.info.baseClasses.tail.filter(!_.isTransparentTrait))
779+
for (baseCls <- middle.info.baseClasses.tail)
780780
if (baseCls.typeParams.exists(_.paramVarianceSign != 0))
781-
for (problem <- variantInheritanceProblems(baseCls, caseCls, "non-variant", "case "))
781+
val middleStr = if middle.is(Case) then "case " else ""
782+
for (problem <- variantInheritanceProblems(baseCls, middle, "variant", middleStr))
782783
report.errorOrMigrationWarning(problem(), clazz.srcPos, from = `3.0`)
784+
783785
checkNoAbstractMembers()
784786
if (abstractErrors.isEmpty)
785787
checkNoAbstractDecls(clazz)
@@ -788,7 +790,7 @@ object RefChecks {
788790
report.error(abstractErrorMessage, clazz.srcPos)
789791

790792
checkMemberTypesOK()
791-
checkCaseClassInheritanceInvariant()
793+
checkClassInheritanceInvariant()
792794
}
793795

794796
if (!clazz.is(Trait)) {
@@ -825,16 +827,28 @@ object RefChecks {
825827
*/
826828
def variantInheritanceProblems(
827829
baseCls: Symbol, middle: Symbol, baseStr: String, middleStr: String): Option[() => String] = {
830+
if baseCls == middle then return None
828831
val superBT = self.baseType(middle)
829-
val thisBT = self.baseType(baseCls)
830832
val combinedBT = superBT.baseType(baseCls)
831-
if (combinedBT =:= thisBT) None // ok
833+
val withoutMiddleBT = self.baseTypeWithout(baseCls, middle)
834+
val allOk = (combinedBT, withoutMiddleBT) match
835+
case (AppliedType(tycon, args1), AppliedType(_, args2)) =>
836+
val superBTArgs = superBT.argInfos.toSet
837+
tycon.typeParams.lazyZip(args1).lazyZip(args2).forall { (param, arg1, arg2) =>
838+
if superBTArgs.contains(arg1) then
839+
val variance = param.paramVarianceSign
840+
(variance > 0 || (arg2 <:< arg1)) &&
841+
(variance < 0 || (arg1 <:< arg2))
842+
else true // e.g. CovBoth in neg/i11834
843+
}
844+
case _ => combinedBT =:= self.baseType(baseCls)
845+
if allOk then None // ok
832846
else
833847
Some(() =>
834848
em"""illegal inheritance: $clazz inherits conflicting instances of $baseStr base $baseCls.
835849
|
836-
| Direct basetype: $thisBT
837-
| Basetype via $middleStr$middle: $combinedBT""")
850+
| Basetype via $middleStr$middle: $combinedBT
851+
| Basetype without $middleStr$middle: $withoutMiddleBT""")
838852
}
839853

840854
/* Returns whether there is a symbol declared in class `inclazz`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// like neg-custom-args/isInstanceOf/3324g,
2+
// but verifying the fatal type test/unchecked warning
3+
// emitted during Erasure
4+
// by not being trumped by the fatal refcheck warning on C subclass
5+
class Test {
6+
trait A[+T]
7+
class B[T] extends A[T]
8+
9+
def quux[T](a: A[T]): Unit = a match {
10+
case _: B[T] => // error
11+
}
12+
}
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,5 @@
11
class Test {
22
trait A[+T]
33
class B[T] extends A[T]
4-
class C[T] extends B[Any] with A[T]
5-
6-
def foo[T](c: C[T]): Unit = c match {
7-
case _: B[T] => // error
8-
}
9-
10-
def bar[T](b: B[T]): Unit = b match {
11-
case _: A[T] =>
12-
}
13-
14-
def quux[T](a: A[T]): Unit = a match {
15-
case _: B[T] => // error!!
16-
}
17-
18-
quux(new C[Int])
4+
class C[T] extends B[Any] with A[T] // error
195
}

tests/neg/i11018.scala

+15-5
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,19 @@ trait CTrait[+A](val a: A) {
1313
trait DTrait[+B] extends CTrait[B]
1414
trait DClass[+B] extends CClass[B]
1515

16-
final class F1 extends DTrait[Foo] with CTrait[Bar](new Bar) // error: illegal parameter
17-
final class F2 extends CTrait[Bar](new Bar) with DTrait[Foo] // error: illegal parameter
18-
final class F3 extends DClass[Foo] with CClass[Bar](new Bar) // error: illegal parameter
19-
final class F4 extends CClass[Bar](new Bar) with DClass[Foo] // error: illegal parameter
16+
final class F1 // error: illegal inheritance
17+
extends DTrait[Foo]
18+
with CTrait[Bar](new Bar) // error: illegal parameter
19+
final class F2 // error: illegal inheritance
20+
extends CTrait[Bar](new Bar) // error: illegal parameter
21+
with DTrait[Foo]
22+
final class F3 // error: illegal inheritance
23+
extends DClass[Foo]
24+
with CClass[Bar](new Bar) // error: illegal parameter
25+
final class F4 // error: illegal inheritance
26+
extends CClass[Bar](new Bar) // error: illegal parameter
27+
with DClass[Foo]
2028

21-
final class F5 extends DTrait[Foo] with CTrait[Foo & Bar](new Bar with Foo { def name = "hello"}) // ok
29+
final class F5 // error: illegal inheritance
30+
extends DTrait[Foo]
31+
with CTrait[Foo & Bar](new Bar with Foo { def name = "hello"})

0 commit comments

Comments
 (0)