Skip to content

Ban classes that incompatibly refine type params #14820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1765,7 +1765,9 @@ class Definitions {
Set[Symbol](ComparableClass, ProductClass, SerializableClass,
// add these for now, until we had a chance to retrofit 2.13 stdlib
// we should do a more through sweep through it then.
requiredClass("scala.collection.IterableFactoryDefaults"),
requiredClass("scala.collection.SortedOps"),
requiredClass("scala.collection.StrictOptimizedSetOps"),
requiredClass("scala.collection.StrictOptimizedSortedSetOps"),
requiredClass("scala.collection.generic.DefaultSerializable"),
requiredClass("scala.collection.generic.IsIterable"),
Expand Down
41 changes: 22 additions & 19 deletions compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,9 @@ trait PatternTypeConstrainer { self: TypeComparer =>
* are used to infer type arguments to Unapply trees.
*
* ## Invariant refinement
* Essentially, we say that `D[B] extends C[B]` s.t. refines parameter `A` of `trait C[A]` invariantly if
* when `c: C[T]` and `c` is instance of `D`, then necessarily `c: D[T]`. This is violated if `A` is variant:
* Essentially, we say that `D[B] extends C[B]` refines parameter `A` of `trait C[A]` invariantly if
* when `c: C[T]` and `c` is instance of `D`, then necessarily `c: D[T]`.
* This is violated if `A` is variant and `C` is mixed in with an incompatible type argument:
*
* trait C[+A]
* trait D[+B](val b: B) extends C[B]
Expand All @@ -224,29 +225,30 @@ trait PatternTypeConstrainer { self: TypeComparer =>
* }
*
* It'd be unsound for us to say that `t <: T`, even though that follows from `D[t] <: C[T]`.
* Note, however, that if `D` was a final class, we *could* rely on that relationship.
* To support typical case classes, we also assume that this relationship holds for them and their parent traits.
* This is enforced by checking that classes inheriting from case classes do not extend the parent traits of those
* case classes without also appropriately extending the relevant case class
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
* Note, however, that if `D` was a concrete class, we can rely on that relationship.
* We can assume this relationship holds for them and their parent traits
* by checking that classes inheriting from those classes do not mix-in any parent traits
* with a type parameter that isn't the same type, a subtype, or a super type, depending on if the
* trait's parameter is invariant, covariant or contravariant, respectively
* (see `RefChecks#checkClassInheritanceInvariant`).
*/
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, forceInvariantRefinement: Boolean): Boolean = {
def refinementIsInvariant(tp: Type): Boolean = tp match {
case tp: SingletonType => true
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
case tp: ClassInfo => tp.cls.is(Final)
case tp: TypeProxy => refinementIsInvariant(tp.superType)
case _ => false
}
def refinementIsInvariant2(tp: Type): Boolean = tp match
case tp: SingletonType => true
case tp: ClassInfo => !tp.cls.isOneOf(AbstractOrTrait) || tp.cls.isOneOf(Private | Sealed)
case tp: TypeProxy => refinementIsInvariant2(tp.superType)
case _ => false

def widenVariantParams(tp: Type) = tp match {
case tp @ AppliedType(tycon, args) =>
val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) =>
if (tparam.paramVarianceSign != 0) TypeBounds.empty else arg
)
tp.derivedAppliedType(tycon, args1)
case tp =>
tp
}
extension (tp: Type) def isAbstract: Boolean = tp.stripped match
case _: TypeParamRef => true
case tp: TypeRef => !tp.symbol.isClass
case _ => false

val patternCls = patternTp.classSymbol
val scrutineeCls = scrutineeTp.classSymbol
Expand All @@ -269,10 +271,11 @@ trait PatternTypeConstrainer { self: TypeComparer =>
val result =
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
val variance = param.paramVarianceSign
if variance == 0 || assumeInvariantRefinement ||
if variance == 0 || assumeInvariantRefinement
|| refinementIsInvariant2(patternTp) && (argP.isAbstract || patternTp.argInfos.contains(argP))
// As a special case, when pattern and scrutinee types have the same type constructor,
// we infer better bounds for pattern-bound abstract types.
argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol
|| argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol
then
val TypeBounds(loS, hiS) = argS.bounds
val TypeBounds(loP, hiP) = argP.bounds
Expand Down
9 changes: 6 additions & 3 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2074,9 +2074,9 @@ object SymDenotations {
required: FlagSet = EmptyFlags, excluded: FlagSet = EmptyFlags)(using Context): Denotation =
membersNamedNoShadowingBasedOnFlags(name, required, excluded).asSeenFrom(pre).toDenot(pre)

/** Compute tp.baseType(this) */
final def baseTypeOf(tp: Type)(using Context): Type = {
val btrCache = baseTypeCache
/** Compute tp.baseType(this) or tp.baseType(this, without) */
final def baseTypeOf(tp: Type, without: Option[Symbol] = None)(using Context): Type = {
val btrCache = if without.isEmpty then baseTypeCache else new BaseTypeMap()
def inCache(tp: Type) = tp match
case tp: CachedType => btrCache.contains(tp)
case _ => false
Expand Down Expand Up @@ -2130,6 +2130,8 @@ object SymDenotations {
val baseTp =
if (tpSym eq symbol)
tp
else if without.exists(tpSym eq _) then
defn.AnyType
else if (isOwnThis)
if (clsd.baseClassSet.contains(symbol))
if (symbol.isStatic && symbol.typeParams.isEmpty) symbol.typeRef
Expand All @@ -2156,6 +2158,7 @@ object SymDenotations {
btrCache(tp) = NoPrefix
val baseTp =
if (tycon.typeSymbol eq symbol) tp
else if without.exists(tycon.typeSymbol eq _) then defn.AnyType
else (tycon.typeParams: @unchecked) match {
case LambdaParam(_, _) :: _ =>
recur(tp.superType)
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,12 @@ object Types {
}
}

/** `basetype`, but ignoring any base classes that have the given `without` class symbol. */
final def baseTypeWithout(base: Symbol, without: Symbol)(using Context): Type =
base.denot match
case classd: ClassDenotation => classd.baseTypeOf(this, Some(without))
case _ => NoType

def & (that: Type)(using Context): Type = {
record("&")
TypeComparer.glb(this, that)
Expand Down
38 changes: 26 additions & 12 deletions compiler/src/dotty/tools/dotc/typer/RefChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -770,16 +770,18 @@ object RefChecks {
}
}

/** Check that inheriting a case class does not constitute a variant refinement
* of a base type of the case class. It is because of this restriction that we
* can assume invariant refinement for case classes in `constrainPatternType`.
/** Check that inheriting a class does not constitute a variant refinement
* of a base type of the class. It is because of this restriction that we
* can assume invariant refinement for concrete classes in `constrainPatternType`.
*/
def checkCaseClassInheritanceInvariant() =
for (caseCls <- clazz.info.baseClasses.tail.find(_.is(Case)))
for (baseCls <- caseCls.info.baseClasses.tail)
def checkClassInheritanceInvariant() =
for (middle <- clazz.info.baseClasses.tail.filter(!_.isTransparentTrait))
for (baseCls <- middle.info.baseClasses.tail)
if (baseCls.typeParams.exists(_.paramVarianceSign != 0))
for (problem <- variantInheritanceProblems(baseCls, caseCls, "non-variant", "case "))
val middleStr = if middle.is(Case) then "case " else ""
for (problem <- variantInheritanceProblems(baseCls, middle, "variant", middleStr))
report.errorOrMigrationWarning(problem(), clazz.srcPos, from = `3.0`)

checkNoAbstractMembers()
if (abstractErrors.isEmpty)
checkNoAbstractDecls(clazz)
Expand All @@ -788,7 +790,7 @@ object RefChecks {
report.error(abstractErrorMessage, clazz.srcPos)

checkMemberTypesOK()
checkCaseClassInheritanceInvariant()
checkClassInheritanceInvariant()
}

if (!clazz.is(Trait)) {
Expand Down Expand Up @@ -825,16 +827,28 @@ object RefChecks {
*/
def variantInheritanceProblems(
baseCls: Symbol, middle: Symbol, baseStr: String, middleStr: String): Option[() => String] = {
if baseCls == middle then return None
val superBT = self.baseType(middle)
val thisBT = self.baseType(baseCls)
val combinedBT = superBT.baseType(baseCls)
if (combinedBT =:= thisBT) None // ok
val withoutMiddleBT = self.baseTypeWithout(baseCls, middle)
val allOk = (combinedBT, withoutMiddleBT) match
case (AppliedType(tycon, args1), AppliedType(_, args2)) =>
val superBTArgs = superBT.argInfos.toSet
tycon.typeParams.lazyZip(args1).lazyZip(args2).forall { (param, arg1, arg2) =>
if superBTArgs.contains(arg1) then
val variance = param.paramVarianceSign
(variance > 0 || (arg2 <:< arg1)) &&
(variance < 0 || (arg1 <:< arg2))
else true // e.g. CovBoth in neg/i11834
}
case _ => combinedBT =:= self.baseType(baseCls)
if allOk then None // ok
else
Some(() =>
em"""illegal inheritance: $clazz inherits conflicting instances of $baseStr base $baseCls.
|
| Direct basetype: $thisBT
| Basetype via $middleStr$middle: $combinedBT""")
| Basetype via $middleStr$middle: $combinedBT
| Basetype without $middleStr$middle: $withoutMiddleBT""")
}

/* Returns whether there is a symbol declared in class `inclazz`
Expand Down
12 changes: 12 additions & 0 deletions tests/neg-custom-args/isInstanceOf/3324g.erasure.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// like neg-custom-args/isInstanceOf/3324g,
// but verifying the fatal type test/unchecked warning
// emitted during Erasure
// by not being trumped by the fatal refcheck warning on C subclass
class Test {
trait A[+T]
class B[T] extends A[T]

def quux[T](a: A[T]): Unit = a match {
case _: B[T] => // error
}
}
16 changes: 1 addition & 15 deletions tests/neg-custom-args/isInstanceOf/3324g.scala
Original file line number Diff line number Diff line change
@@ -1,19 +1,5 @@
class Test {
trait A[+T]
class B[T] extends A[T]
class C[T] extends B[Any] with A[T]

def foo[T](c: C[T]): Unit = c match {
case _: B[T] => // error
}

def bar[T](b: B[T]): Unit = b match {
case _: A[T] =>
}

def quux[T](a: A[T]): Unit = a match {
case _: B[T] => // error!!
}

quux(new C[Int])
class C[T] extends B[Any] with A[T] // error
}
20 changes: 15 additions & 5 deletions tests/neg/i11018.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,19 @@ trait CTrait[+A](val a: A) {
trait DTrait[+B] extends CTrait[B]
trait DClass[+B] extends CClass[B]

final class F1 extends DTrait[Foo] with CTrait[Bar](new Bar) // error: illegal parameter
final class F2 extends CTrait[Bar](new Bar) with DTrait[Foo] // error: illegal parameter
final class F3 extends DClass[Foo] with CClass[Bar](new Bar) // error: illegal parameter
final class F4 extends CClass[Bar](new Bar) with DClass[Foo] // error: illegal parameter
final class F1 // error: illegal inheritance
extends DTrait[Foo]
with CTrait[Bar](new Bar) // error: illegal parameter
final class F2 // error: illegal inheritance
extends CTrait[Bar](new Bar) // error: illegal parameter
with DTrait[Foo]
final class F3 // error: illegal inheritance
extends DClass[Foo]
with CClass[Bar](new Bar) // error: illegal parameter
final class F4 // error: illegal inheritance
extends CClass[Bar](new Bar) // error: illegal parameter
with DClass[Foo]

final class F5 extends DTrait[Foo] with CTrait[Foo & Bar](new Bar with Foo { def name = "hello"}) // ok
final class F5 // error: illegal inheritance
extends DTrait[Foo]
with CTrait[Foo & Bar](new Bar with Foo { def name = "hello"})
Loading