@@ -209,8 +209,9 @@ trait PatternTypeConstrainer { self: TypeComparer =>
209
209
* are used to infer type arguments to Unapply trees.
210
210
*
211
211
* ## Invariant refinement
212
- * Essentially, we say that `D[B] extends C[B]` s.t. refines parameter `A` of `trait C[A]` invariantly if
213
- * when `c: C[T]` and `c` is instance of `D`, then necessarily `c: D[T]`. This is violated if `A` is variant:
212
+ * Essentially, we say that `D[B] extends C[B]` refines parameter `A` of `trait C[A]` invariantly if
213
+ * when `c: C[T]` and `c` is instance of `D`, then necessarily `c: D[T]`.
214
+ * This is violated if `A` is variant and `C` is mixed in with an incompatible type argument:
214
215
*
215
216
* trait C[+A]
216
217
* trait D[+B](val b: B) extends C[B]
@@ -224,29 +225,30 @@ trait PatternTypeConstrainer { self: TypeComparer =>
224
225
* }
225
226
*
226
227
* It'd be unsound for us to say that `t <: T`, even though that follows from `D[t] <: C[T]`.
227
- * Note, however, that if `D` was a final class, we *could* rely on that relationship.
228
- * To support typical case classes, we also assume that this relationship holds for them and their parent traits.
229
- * This is enforced by checking that classes inheriting from case classes do not extend the parent traits of those
230
- * case classes without also appropriately extending the relevant case class
231
- * (see `RefChecks#checkCaseClassInheritanceInvariant`).
228
+ * Note, however, that if `D` was a concrete class, we can rely on that relationship.
229
+ * We can assume this relationship holds for them and their parent traits
230
+ * by checking that classes inheriting from those classes do not mix-in any parent traits
231
+ * with a type parameter that isn't the same type, a subtype, or a super type, depending on if the
232
+ * trait's parameter is invariant, covariant or contravariant, respectively
233
+ * (see `RefChecks#checkClassInheritanceInvariant`).
232
234
*/
233
235
def constrainSimplePatternType (patternTp : Type , scrutineeTp : Type , forceInvariantRefinement : Boolean ): Boolean = {
234
236
def refinementIsInvariant (tp : Type ): Boolean = tp match {
235
237
case tp : SingletonType => true
236
- case tp : ClassInfo => tp.cls.is(Final ) || tp.cls.is( Case )
238
+ case tp : ClassInfo => tp.cls.is(Final )
237
239
case tp : TypeProxy => refinementIsInvariant(tp.superType)
238
240
case _ => false
239
241
}
242
+ def refinementIsInvariant2 (tp : Type ): Boolean = tp match
243
+ case tp : SingletonType => true
244
+ case tp : ClassInfo => ! tp.cls.isOneOf(AbstractOrTrait ) || tp.cls.isOneOf(Private | Sealed )
245
+ case tp : TypeProxy => refinementIsInvariant2(tp.superType)
246
+ case _ => false
240
247
241
- def widenVariantParams (tp : Type ) = tp match {
242
- case tp @ AppliedType (tycon, args) =>
243
- val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) =>
244
- if (tparam.paramVarianceSign != 0 ) TypeBounds .empty else arg
245
- )
246
- tp.derivedAppliedType(tycon, args1)
247
- case tp =>
248
- tp
249
- }
248
+ extension (tp : Type ) def isAbstract : Boolean = tp.stripped match
249
+ case _ : TypeParamRef => true
250
+ case tp : TypeRef => ! tp.symbol.isClass
251
+ case _ => false
250
252
251
253
val patternCls = patternTp.classSymbol
252
254
val scrutineeCls = scrutineeTp.classSymbol
@@ -269,10 +271,11 @@ trait PatternTypeConstrainer { self: TypeComparer =>
269
271
val result =
270
272
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
271
273
val variance = param.paramVarianceSign
272
- if variance == 0 || assumeInvariantRefinement ||
274
+ if variance == 0 || assumeInvariantRefinement
275
+ || refinementIsInvariant2(patternTp) && (argP.isAbstract || patternTp.argInfos.contains(argP))
273
276
// As a special case, when pattern and scrutinee types have the same type constructor,
274
277
// we infer better bounds for pattern-bound abstract types.
275
- argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol
278
+ || argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol
276
279
then
277
280
val TypeBounds (loS, hiS) = argS.bounds
278
281
val TypeBounds (loP, hiP) = argP.bounds
0 commit comments