Skip to content

Commit 89e57aa

Browse files
authored
Merge pull request #12562 from dotty-staging/implicit-sort
Sort implicits with a proper comparison function
2 parents 001e9cf + 171cea2 commit 89e57aa

File tree

2 files changed

+91
-27
lines changed

2 files changed

+91
-27
lines changed

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

+5
Original file line numberDiff line numberDiff line change
@@ -1806,6 +1806,11 @@ object SymDenotations {
18061806
def baseClasses(implicit onBehalf: BaseData, ctx: Context): List[ClassSymbol] =
18071807
baseData._1
18081808

1809+
/** Like `baseClasses.length` but more efficient. */
1810+
def baseClassesLength(using BaseData, Context): Int =
1811+
// `+ 1` because the baseClassSet does not include the current class unlike baseClasses
1812+
baseClassSet.classIds.length + 1
1813+
18091814
/** A bitset that contains the superId's of all base classes */
18101815
private def baseClassSet(implicit onBehalf: BaseData, ctx: Context): BaseClassSet =
18111816
baseData._2

compiler/src/dotty/tools/dotc/typer/Implicits.scala

+86-27
Original file line numberDiff line numberDiff line change
@@ -1247,53 +1247,112 @@ trait Implicits:
12471247
|Consider using the scala.util.NotGiven class to implement similar functionality.""",
12481248
ctx.source.atSpan(span))
12491249

1250-
/** A relation that influences the order in which implicits are tried.
1250+
/** Compare the length of the baseClasses of two symbols (except for objects,
1251+
* where we use the length of the companion class instead if it's bigger).
1252+
*
1253+
* This relation is meant to approximate `Applications#compareOwner` while also
1254+
* inducing a total ordering: `compareOwner` returns `0` for unrelated symbols
1255+
* and therefore only induces a partial ordering, meaning it cannot be used
1256+
* as a sorting function (see `java.util.Comparator#compare`).
1257+
*/
1258+
def compareBaseClassesLength(sym1: Symbol, sym2: Symbol): Int =
1259+
def len(sym: Symbol) =
1260+
if sym.is(ModuleClass) && sym.companionClass.exists then
1261+
Math.max(sym.asClass.baseClassesLength, sym.companionClass.asClass.baseClassesLength)
1262+
else if sym.isClass then
1263+
sym.asClass.baseClassesLength
1264+
else
1265+
0
1266+
len(sym1) - len(sym2)
1267+
1268+
/** A relation that influences the order in which eligible implicits are tried.
1269+
*
12511270
* We prefer (in order of importance)
12521271
* 1. more deeply nested definitions
12531272
* 2. definitions with fewer implicit parameters
1254-
* 3. definitions in subclasses
1273+
* 3. definitions whose owner has more parents (see `compareBaseClassesLength`)
12551274
* The reason for (2) is that we want to fail fast if the search type
12561275
* is underconstrained. So we look for "small" goals first, because that
12571276
* will give an ambiguity quickly.
12581277
*/
1259-
def prefer(cand1: Candidate, cand2: Candidate): Boolean =
1260-
val level1 = cand1.level
1261-
val level2 = cand2.level
1262-
if level1 > level2 then return true
1263-
if level1 < level2 then return false
1264-
val sym1 = cand1.ref.symbol
1265-
val sym2 = cand2.ref.symbol
1278+
def compareEligibles(e1: Candidate, e2: Candidate): Int =
1279+
if e1 eq e2 then return 0
1280+
val cmpLevel = e1.level - e2.level
1281+
if cmpLevel != 0 then return -cmpLevel // 1.
1282+
val sym1 = e1.ref.symbol
1283+
val sym2 = e2.ref.symbol
12661284
val arity1 = sym1.info.firstParamTypes.length
12671285
val arity2 = sym2.info.firstParamTypes.length
1268-
if arity1 < arity2 then return true
1269-
if arity1 > arity2 then return false
1270-
compareOwner(sym1.owner, sym2.owner) == 1
1286+
val cmpArity = arity1 - arity2
1287+
if cmpArity != 0 then return cmpArity // 2.
1288+
val cmpBcs = compareBaseClassesLength(sym1.owner, sym2.owner)
1289+
-cmpBcs // 3.
12711290

1272-
/** Sort list of implicit references according to `prefer`.
1291+
/** Check if `ord` respects the contract of `Ordering`.
1292+
*
1293+
* More precisely, we check that its `compare` method respects the invariants listed
1294+
* in https://docs.oracle.com/javase/8/docs/api/java/util/Comparator.html#compare-T-T-
1295+
*/
1296+
def validateOrdering(ord: Ordering[Candidate]): Unit =
1297+
for
1298+
x <- eligible
1299+
y <- eligible
1300+
cmpXY = Integer.signum(ord.compare(x, y))
1301+
cmpYX = Integer.signum(ord.compare(y, x))
1302+
z <- eligible
1303+
cmpXZ = Integer.signum(ord.compare(x, z))
1304+
cmpYZ = Integer.signum(ord.compare(y, z))
1305+
do
1306+
def reportViolation(msg: String): Unit =
1307+
Console.err.println(s"Internal error: comparison function violated ${msg.stripMargin}")
1308+
def showCandidate(c: Candidate): String =
1309+
s"$c (${c.ref.symbol.showLocated})"
1310+
1311+
if cmpXY != -cmpYX then
1312+
reportViolation(
1313+
s"""signum(cmp(x, y)) == -signum(cmp(y, x)) given:
1314+
|x = ${showCandidate(x)}
1315+
|y = ${showCandidate(y)}
1316+
|cmpXY = $cmpXY
1317+
|cmpYX = $cmpYX""")
1318+
if cmpXY != 0 && cmpXY == cmpYZ && cmpXZ != cmpXY then
1319+
reportViolation(
1320+
s"""transitivity given:
1321+
|x = ${showCandidate(x)}
1322+
|y = ${showCandidate(y)}
1323+
|z = ${showCandidate(z)}
1324+
|cmpXY = $cmpXY
1325+
|cmpXZ = $cmpXZ
1326+
|cmpYZ = $cmpYZ""")
1327+
if cmpXY == 0 && cmpXZ != cmpYZ then
1328+
reportViolation(
1329+
s"""cmp(x, y) == 0 implies that signum(cmp(x, z)) == signum(cmp(y, z)) given:
1330+
|x = ${showCandidate(x)}
1331+
|y = ${showCandidate(y)}
1332+
|z = ${showCandidate(z)}
1333+
|cmpXY = $cmpXY
1334+
|cmpXZ = $cmpXZ
1335+
|cmpYZ = $cmpYZ""")
1336+
end validateOrdering
1337+
1338+
/** Sort list of implicit references according to `compareEligibles`.
12731339
* This is just an optimization that aims at reducing the average
12741340
* number of candidates to be tested.
12751341
*/
1276-
def sort(eligible: List[Candidate]) = eligible match {
1342+
def sort(eligible: List[Candidate]) = eligible match
12771343
case Nil => eligible
12781344
case e1 :: Nil => eligible
12791345
case e1 :: e2 :: Nil =>
1280-
if (prefer(e2, e1)) e2 :: e1 :: Nil
1346+
if compareEligibles(e2, e1) < 0 then e2 :: e1 :: Nil
12811347
else eligible
12821348
case _ =>
1283-
try eligible.sortWith(prefer)
1349+
val ord: Ordering[Candidate] = (a, b) => compareEligibles(a, b)
1350+
try eligible.sorted(using ord)
12841351
catch case ex: IllegalArgumentException =>
1285-
// diagnostic to see what went wrong
1286-
for
1287-
e1 <- eligible
1288-
e2 <- eligible
1289-
if prefer(e1, e2)
1290-
e3 <- eligible
1291-
if prefer(e2, e3) && !prefer(e1, e3)
1292-
do
1293-
val es = List(e1, e2, e3)
1294-
println(i"transitivity violated for $es%, %\n ${es.map(_.implicitRef.underlyingRef.symbol.showLocated)}%, %")
1352+
// This exception being thrown probably means that our comparison
1353+
// function is broken, check if that's the case
1354+
validateOrdering(ord)
12951355
throw ex
1296-
}
12971356

12981357
rank(sort(eligible), NoMatchingImplicitsFailure, Nil)
12991358
end searchImplicit

0 commit comments

Comments
 (0)