Skip to content

Commit 230edd1

Browse files
committed
Better fix: baseclasses intersection that takes bottom types into account
1 parent e65305d commit 230edd1

File tree

4 files changed

+80
-10
lines changed

4 files changed

+80
-10
lines changed

compiler/src/dotty/tools/dotc/core/TypeOps.scala

+10-6
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,16 @@ object TypeOps:
194194
*/
195195
def orDominator(tp: Type)(using Context): Type = {
196196

197-
/** a faster version of cs1 intersect cs2 */
198-
def intersect(cs1: List[ClassSymbol], cs2: List[ClassSymbol]): List[ClassSymbol] = {
199-
val cs2AsSet = new util.HashSet[ClassSymbol](128)
200-
cs2.foreach(cs2AsSet += _)
201-
cs1.filter(cs2AsSet.contains)
202-
}
197+
/** a faster version of cs1 intersect cs2 that treats bottom types correctly */
198+
def intersect(cs1: List[ClassSymbol], cs2: List[ClassSymbol]): List[ClassSymbol] =
199+
if cs1.head == defn.NothingClass then cs2
200+
else if cs2.head == defn.NothingClass then cs1
201+
else if cs1.head == defn.NullClass && !ctx.explicitNulls && cs2.head.derivesFrom(defn.ObjectClass) then cs2
202+
else if cs2.head == defn.NullClass && !ctx.explicitNulls && cs1.head.derivesFrom(defn.ObjectClass) then cs1
203+
else
204+
val cs2AsSet = new util.HashSet[ClassSymbol](128)
205+
cs2.foreach(cs2AsSet += _)
206+
cs1.filter(cs2AsSet.contains)
203207

204208
/** The minimal set of classes in `cs` which derive all other classes in `cs` */
205209
def dominators(cs: List[ClassSymbol], accu: List[ClassSymbol]): List[ClassSymbol] = (cs: @unchecked) match {

compiler/src/dotty/tools/dotc/core/Types.scala

+1-4
Original file line numberDiff line numberDiff line change
@@ -3148,10 +3148,7 @@ object Types {
31483148
/** Replace or type by the closest non-or type above it */
31493149
def join(using Context): Type = {
31503150
if (myJoinPeriod != ctx.period) {
3151-
myJoin =
3152-
if tp1 frozen_<:< tp2 then tp2
3153-
else if tp2 frozen_<:< tp1 then tp1
3154-
else TypeOps.orDominator(this)
3151+
myJoin = TypeOps.orDominator(this)
31553152
core.println(i"join of $this == $myJoin")
31563153
assert(myJoin != this)
31573154
myJoinPeriod = ctx.period

tests/pos/i11968a.scala

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
2+
class A {
3+
def get(): Int = 0
4+
}
5+
6+
class B extends A {}
7+
8+
class C extends A {}
9+
10+
def test1 = {
11+
val s: String | Null = ???
12+
val l = s.length
13+
14+
val a: A | Null = new A
15+
a.get()
16+
17+
val bc: B | C = new B
18+
bc.get()
19+
20+
val bcn: B | (C | Null) = new C
21+
bcn.get()
22+
23+
val bnc: (B | Null) | C = null
24+
bnc.get()
25+
26+
val abcn: A | B | C | Null = new A
27+
abcn.get()
28+
}
29+
30+
def test2 = {
31+
val s: String | Nothing = ???
32+
val l = s.length
33+
34+
val a: A | Nothing = new A
35+
a.get()
36+
37+
val bc: B | C = new B
38+
bc.get()
39+
40+
val bcn: B | (C | Nothing) = new C
41+
bcn.get()
42+
43+
val bnc: (B | Nothing) | C = new B
44+
bnc.get()
45+
46+
val abcn: A | B | C | Nothing = new A
47+
abcn.get()
48+
}

tests/pos/i11981.scala

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
object Main:
2+
class Null
3+
type Optional[A] = A | Null
4+
5+
val maybeInt: Optional[Int] = 1
6+
7+
// simplest typeclass
8+
trait TC[F[_]]
9+
10+
// given instances for our Optional and standard Option[_]
11+
given g1: TC[Optional] = ???
12+
given g2: TC[Option] = ???
13+
14+
def summonTC[F[_], A](f: F[A])(using TC[F]): Unit = ???
15+
16+
summonTC(Option(42)) // OK
17+
18+
summonTC[Optional, Int](maybeInt) // OK
19+
20+
summonTC(maybeInt)
21+

0 commit comments

Comments
 (0)