Skip to content

Commit 2842762

Browse files
committed
Preserve singletons in unions when they're explicitly written in the code
Note that we do not keep singletons in pattern alternatives like `case x @ (1 | 2)` because if there are many alternatives like `JavaScanner#fetchToken`, we end up with deep subtyping checks. Fixes #829
1 parent b927f66 commit 2842762

File tree

6 files changed

+29
-10
lines changed

6 files changed

+29
-10
lines changed

src/dotty/tools/dotc/core/TypeComparer.scala

+8-8
Original file line numberDiff line numberDiff line change
@@ -790,9 +790,9 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
790790
(defn.AnyType /: tps)(glb)
791791

792792
/** The least upper bound of two types
793-
* @note We do not admit singleton types in or-types as lubs.
793+
* @param keepSingletons If true, do not widen singletons when forming an OrType
794794
*/
795-
def lub(tp1: Type, tp2: Type): Type = /*>|>*/ ctx.traceIndented(s"lub(${tp1.show}, ${tp2.show})", subtyping, show = true) /*<|<*/ {
795+
def lub(tp1: Type, tp2: Type, keepSingletons: Boolean = false): Type = /*>|>*/ ctx.traceIndented(s"lub(${tp1.show}, ${tp2.show}, $keepSingletons)", subtyping, show = true) /*<|<*/ {
796796
if (tp1 eq tp2) tp1
797797
else if (!tp1.exists) tp1
798798
else if (!tp2.exists) tp2
@@ -805,8 +805,8 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
805805
val t2 = mergeIfSuper(tp2, tp1)
806806
if (t2.exists) t2
807807
else {
808-
val tp1w = tp1.widen
809-
val tp2w = tp2.widen
808+
val tp1w = if (keepSingletons) tp1.widenExpr else tp1.widen
809+
val tp2w = if (keepSingletons) tp2.widenExpr else tp2.widen
810810
if ((tp1 ne tp1w) || (tp2 ne tp2w)) lub(tp1w, tp2w)
811811
else orType(tp1w, tp2w) // no need to check subtypes again
812812
}
@@ -816,7 +816,7 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
816816

817817
/** The least upper bound of a list of types */
818818
final def lub(tps: List[Type]): Type =
819-
(defn.NothingType /: tps)(lub)
819+
(defn.NothingType /: tps)(lub(_, _))
820820

821821
/** Merge `t1` into `tp2` if t1 is a subtype of some &-summand of tp2.
822822
*/
@@ -1207,9 +1207,9 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
12071207
super.hasMatchingMember(name, tp1, tp2)
12081208
}
12091209

1210-
override def lub(tp1: Type, tp2: Type) =
1211-
traceIndented(s"lub(${show(tp1)}, ${show(tp2)})") {
1212-
super.lub(tp1, tp2)
1210+
override def lub(tp1: Type, tp2: Type, keepSingletons: Boolean = false) =
1211+
traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, $keepSingletons)") {
1212+
super.lub(tp1, tp2, keepSingletons)
12131213
}
12141214

12151215
override def glb(tp1: Type, tp2: Type) =

src/dotty/tools/dotc/core/TypeOps.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ trait TypeOps { this: Context => // TODO: Make standalone object.
246246
case AndType(l, r) =>
247247
simplify(l, theMap) & simplify(r, theMap)
248248
case OrType(l, r) =>
249-
simplify(l, theMap) | simplify(r, theMap)
249+
ctx.typeComparer.lub(simplify(l, theMap), simplify(r, theMap), keepSingletons = true)
250250
case _ =>
251251
(if (theMap != null) theMap else new SimplifyMap).mapOver(tp)
252252
}

src/dotty/tools/dotc/typer/TypeAssigner.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ trait TypeAssigner {
366366
tree.withType(left.tpe & right.tpe)
367367

368368
def assignType(tree: untpd.OrTypeTree, left: Tree, right: Tree)(implicit ctx: Context) =
369-
tree.withType(left.tpe | right.tpe)
369+
tree.withType(ctx.typeComparer.lub(left.tpe, right.tpe, keepSingletons = true))
370370

371371
// RefinedTypeTree is missing, handled specially in Typer and Unpickler.
372372

test/dotc/tests.scala

+1
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class tests extends CompilerTest {
160160
@Test def neg_validate = compileFile(negDir, "validate", xerrors = 18)
161161
@Test def neg_validateParsing = compileFile(negDir, "validate-parsing", xerrors = 7)
162162
@Test def neg_validateRefchecks = compileFile(negDir, "validate-refchecks", xerrors = 2)
163+
@Test def neg_singletonsLubs = compileFile(negDir, "singletons-lubs", xerrors = 2)
163164

164165
@Test def run_all = runFiles(runDir)
165166

tests/neg/singletons-lubs.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
object Test {
2+
def oneOrTwo(x: 1 | 2): 1 | 2 = x
3+
def test: Unit = {
4+
val foo: 3 | 4 = 1 // error
5+
oneOrTwo(foo) // error
6+
}
7+
}

tests/pos/singletons-lubs.scala

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
object Test {
2+
def oneOrTwo(x: 1 | 2): 1 | 2 = x
3+
def test: Unit = {
4+
val foo: 1 | 2 = 1
5+
oneOrTwo(oneOrTwo(foo))
6+
1 match {
7+
case x: (1 | 2) => oneOrTwo(x)
8+
//case x @ (1 | 2) => oneOrTwo(x) // disallowed to avoid deep subtyping checks
9+
}
10+
}
11+
}

0 commit comments

Comments
 (0)