diff --git a/compiler/src/dotty/tools/dotc/core/Hashable.scala b/compiler/src/dotty/tools/dotc/core/Hashable.scala index 713555bed517..d62fe71f03df 100644 --- a/compiler/src/dotty/tools/dotc/core/Hashable.scala +++ b/compiler/src/dotty/tools/dotc/core/Hashable.scala @@ -108,6 +108,9 @@ trait Hashable { protected final def doHash(bs: Binders, x1: Any, tp2: Type, tps3: List[Type]): Int = finishHash(bs, hashing.mix(hashSeed, x1.hashCode), 1, tp2, tps3) + protected final def doHash(x1: Any, x2: Int): Int = + finishHash(hashing.mix(hashing.mix(hashSeed, x1.hashCode), x2), 1) + protected final def doHash(x1: Int, x2: Int): Int = finishHash(hashing.mix(hashing.mix(hashSeed, x1), x2), 1) diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 363a9e34230e..b9420280fce8 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -1217,7 +1217,7 @@ object Types { case _: TypeRef | _: MethodOrPoly => this // fast path for most frequent cases case tp: TermRef => // fast path for next most frequent case if tp.isOverloaded then tp else tp.underlying.widen - case tp: SingletonType => tp.underlying.widen + case tp: SingletonType if tp.isSoft => tp.underlying.widen case tp: ExprType => tp.resultType.widen case tp => val tp1 = tp.stripped @@ -1230,7 +1230,7 @@ object Types { * base type by applying one or more `underlying` dereferences. */ final def widenSingleton(using Context): Type = stripped match { - case tp: SingletonType if !tp.isOverloaded => tp.underlying.widenSingleton + case tp: SingletonType if tp.isSoft && !tp.isOverloaded => tp.underlying.widenSingleton case _ => this } @@ -2025,8 +2025,12 @@ object Types { /** A marker trait for types that are guaranteed to contain only a * single non-null value (they might contain null in addition). */ - trait SingletonType extends TypeProxy with ValueType { + trait SingletonType extends TypeProxy with ValueType with Softenable { def isOverloaded(using Context): Boolean = false + + /** Overriden in [[ConstantType]]. + */ + override def isSoft = true } /** A trait for types that bind other types that refer to them. @@ -2074,6 +2078,10 @@ object Types { } } + trait Softenable { + def isSoft: Boolean + } + // --- NamedTypes ------------------------------------------------------------------ abstract class NamedType extends CachedProxyType with ValueType { self => @@ -2855,15 +2863,15 @@ object Types { abstract case class ConstantType(value: Constant) extends CachedProxyType with SingletonType { override def underlying(using Context): Type = value.tpe - override def computeHash(bs: Binders): Int = doHash(value) + override def computeHash(bs: Binders): Int = doHash(value, if isSoft then 0 else 1) } - final class CachedConstantType(value: Constant) extends ConstantType(value) + final class CachedConstantType(value: Constant, override val isSoft: Boolean = true) extends ConstantType(value) object ConstantType { - def apply(value: Constant)(using Context): ConstantType = { + def apply(value: Constant, soft: Boolean = true)(using Context): ConstantType = { assertUnerased() - unique(new CachedConstantType(value)) + unique(new CachedConstantType(value, soft)) } } @@ -3206,9 +3214,8 @@ object Types { TypeComparer.liftIfHK(tp1, tp2, AndType.make(_, _, checkValid = false), makeHk, _ | _) } - abstract case class OrType(tp1: Type, tp2: Type) extends AndOrType { + abstract case class OrType(tp1: Type, tp2: Type) extends AndOrType, Softenable { def isAnd: Boolean = false - def isSoft: Boolean private var myBaseClassesPeriod: Period = Nowhere private var myBaseClasses: List[ClassSymbol] = _ /** Base classes of are the intersection of the operand base classes. */ @@ -3282,7 +3289,12 @@ object Types { else tp1.atoms | tp2.atoms val tp1w = tp1.widenSingletons val tp2w = tp2.widenSingletons - myWidened = if ((tp1 eq tp1w) && (tp2 eq tp2w)) this else tp1w | tp2w + myWidened = + if isSoft then + if ((tp1 eq tp1w) && (tp2 eq tp2w)) this else tp1w | tp2w + else + derivedOrType(tp1w, tp2w) + atomsRunId = ctx.runId override def atoms(using Context): Atoms = diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala index af186e825591..3785b2141e6e 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala @@ -438,7 +438,7 @@ class TreeUnpickler(reader: TastyReader, case BYNAMEtype => ExprType(readType()) case _ => - ConstantType(readConstant(tag)) + ConstantType(readConstant(tag), soft = false) } if (tag < firstLengthTreeTag) readSimpleType() else readLengthType() diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index 4824031f12bc..0a731dadaa0b 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -935,7 +935,7 @@ trait Implicits: case TypeBounds(lo, hi) if lo.ne(hi) && !t.symbol.is(Opaque) => apply(hi) case _ => t } - case t: SingletonType => + case t: SingletonType if t.isSoft => apply(t.widen) case t: RefinedType => apply(t.parent) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 4c0c10f7c026..4fa48611ae4e 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1879,7 +1879,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer def typedSingletonTypeTree(tree: untpd.SingletonTypeTree)(using Context): SingletonTypeTree = { val ref1 = typedExpr(tree.ref) checkStable(ref1.tpe, tree.srcPos, "singleton type") - assignType(cpy.SingletonTypeTree(tree)(ref1), ref1) + + val ref2 = ref1.tpe match + case ConstantType(c) => ref1.withType(ConstantType(c, soft = false)) + case _ => ref1 + + assignType(cpy.SingletonTypeTree(tree)(ref2), ref2) } def typedRefinedTypeTree(tree: untpd.RefinedTypeTree)(using Context): TypTree = { diff --git a/tests/pos/widen-singletons.scala b/tests/pos/widen-singletons.scala new file mode 100644 index 000000000000..3a731c9c5818 --- /dev/null +++ b/tests/pos/widen-singletons.scala @@ -0,0 +1,12 @@ +object Test: + def is2(x: 2) = true + + def testValType() = + val x: 2 = 2 + val v = x + is2(v) + + def testDefReturnType() = + def f(): 2 = 2 + val v = f() + is2(v) diff --git a/tests/pos/widen-union.scala b/tests/pos/widen-union.scala index b0b64f0dc6c6..385c9a8da946 100644 --- a/tests/pos/widen-union.scala +++ b/tests/pos/widen-union.scala @@ -13,7 +13,6 @@ object Test2: || xs.corresponds(ys)(consistent(_, _)) // error, found: Any, required: Int | String object Test3: - def g[X](x: X | String): Int = ??? def y: Boolean | String = ??? g[Boolean](y) @@ -21,4 +20,23 @@ object Test3: g[Boolean](identity(y)) g(identity(y)) +object TestSingletonsInUnions: + def is2Or3(a: 2 | 3) = true + + def testValType() = + val x: 2 | 3 = 2 + val v = x + is2Or3(v) + + def testDefReturnType() = + def f(): 2 | 3 = 2 + val v = f() + is2Or3(v) + + def testSoftUnionInHardUnion() = + def isStringOr3(a: String | 3) = true + def f(x: String): x.type | 3 = 3 + val b: Boolean = true + val v = f(if b then "a" else "b") + isStringOr3(v)