diff --git a/compiler/src/dotty/tools/backend/jvm/BTypesFromSymbols.scala b/compiler/src/dotty/tools/backend/jvm/BTypesFromSymbols.scala index cff63a06ff20..78dfd8e0a869 100644 --- a/compiler/src/dotty/tools/backend/jvm/BTypesFromSymbols.scala +++ b/compiler/src/dotty/tools/backend/jvm/BTypesFromSymbols.scala @@ -13,7 +13,6 @@ import dotty.tools.dotc.core.Phases._ import dotty.tools.dotc.core.Symbols._ import dotty.tools.dotc.core.Phases.Phase import dotty.tools.dotc.transform.SymUtils._ -import dotty.tools.dotc.util.WeakHashSet /** * This class mainly contains the method classBTypeFromSymbol, which extracts the necessary @@ -49,7 +48,6 @@ class BTypesFromSymbols[I <: DottyBackendInterface](val int: I) extends BTypes { def newAnyRefMap[K <: AnyRef, V](): mutable.AnyRefMap[K, V] = new mutable.AnyRefMap[K, V]() def newWeakMap[K, V](): mutable.WeakHashMap[K, V] = new mutable.WeakHashMap[K, V]() def recordCache[T <: Clearable](cache: T): T = cache - def newWeakSet[K >: Null <: AnyRef](): WeakHashSet[K] = new WeakHashSet[K]() def newMap[K, V](): mutable.HashMap[K, V] = new mutable.HashMap[K, V]() def newSet[K](): mutable.Set[K] = new mutable.HashSet[K] } @@ -60,7 +58,6 @@ class BTypesFromSymbols[I <: DottyBackendInterface](val int: I) extends BTypes { def newWeakMap[K, V](): collection.mutable.WeakHashMap[K, V] def newMap[K, V](): collection.mutable.HashMap[K, V] def newSet[K](): collection.mutable.Set[K] - def newWeakSet[K >: Null <: AnyRef](): dotty.tools.dotc.util.WeakHashSet[K] def newAnyRefMap[K <: AnyRef, V](): collection.mutable.AnyRefMap[K, V] } diff --git a/compiler/src/dotty/tools/backend/jvm/DottyBackendInterface.scala b/compiler/src/dotty/tools/backend/jvm/DottyBackendInterface.scala index a5aa8abd1c7c..64e667d95b68 100644 --- a/compiler/src/dotty/tools/backend/jvm/DottyBackendInterface.scala +++ b/compiler/src/dotty/tools/backend/jvm/DottyBackendInterface.scala @@ -12,7 +12,6 @@ import scala.annotation.threadUnsafe import scala.collection.generic.Clearable 
import scala.collection.mutable import scala.reflect.ClassTag -import dotty.tools.dotc.util.WeakHashSet import dotty.tools.io.AbstractFile import scala.tools.asm.AnnotationVisitor import dotty.tools.dotc.core._ diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 6ef1856f5cfa..c6bb37ff9a85 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -559,7 +559,7 @@ object Contexts { def platform: Platform = base.platform def pendingUnderlying: util.HashSet[Type] = base.pendingUnderlying def uniqueNamedTypes: Uniques.NamedTypeUniques = base.uniqueNamedTypes - def uniques: util.HashSet[Type] = base.uniques + def uniques: util.WeakHashSet[Type] = base.uniques def initialize()(using Context): Unit = base.initialize() } diff --git a/compiler/src/dotty/tools/dotc/core/Uniques.scala b/compiler/src/dotty/tools/dotc/core/Uniques.scala index 5b1ae1a499e9..d706875f58dd 100644 --- a/compiler/src/dotty/tools/dotc/core/Uniques.scala +++ b/compiler/src/dotty/tools/dotc/core/Uniques.scala @@ -4,9 +4,11 @@ package core import Types._, Contexts._, util.Stats._, Hashable._, Names._ import config.Config import Decorators._ -import util.{HashSet, Stats} +import util.{WeakHashSet, Stats} +import WeakHashSet.Entry +import scala.annotation.tailrec -class Uniques extends HashSet[Type](Config.initialUniquesCapacity): +class Uniques extends WeakHashSet[Type](Config.initialUniquesCapacity): override def hash(x: Type): Int = x.hash override def isEqual(x: Type, y: Type) = x.eql(y) @@ -32,7 +34,7 @@ object Uniques: if tp.hash == NotCached then tp else ctx.uniques.put(tp).asInstanceOf[T] - final class NamedTypeUniques extends HashSet[NamedType](Config.initialUniquesCapacity * 4) with Hashable: + final class NamedTypeUniques extends WeakHashSet[NamedType](Config.initialUniquesCapacity * 4) with Hashable: override def hash(x: NamedType): Int = x.hash def 
enterIfNew(prefix: Type, designator: Designator, isTerm: Boolean)(using Context): NamedType = @@ -43,17 +45,25 @@ object Uniques: else new CachedTypeRef(prefix, designator, h) if h == NotCached then newType else + // Inlined from WeakHashSet#put Stats.record(statsItem("put")) - var idx = index(h) - var e = entryAt(idx) - while e != null do - if (e.prefix eq prefix) && (e.designator eq designator) && (e.isTerm == isTerm) then return e - idx = nextIndex(idx) - e = entryAt(idx) - addEntryAt(idx, newType) + removeStaleEntries() + val bucket = index(h) + val oldHead = table(bucket) + + @tailrec + def linkedListLoop(entry: Entry[NamedType]): NamedType = entry match + case null => addEntryAt(bucket, newType, h, oldHead) + case _ => + val e = entry.get + if e != null && (e.prefix eq prefix) && (e.designator eq designator) && (e.isTerm == isTerm) then e + else linkedListLoop(entry.tail) + + linkedListLoop(oldHead) + end if end NamedTypeUniques - final class AppliedUniques extends HashSet[AppliedType](Config.initialUniquesCapacity * 2) with Hashable: + final class AppliedUniques extends WeakHashSet[AppliedType](Config.initialUniquesCapacity * 2) with Hashable: override def hash(x: AppliedType): Int = x.hash def enterIfNew(tycon: Type, args: List[Type]): AppliedType = @@ -62,13 +72,21 @@ object Uniques: if monitored then recordCaching(h, classOf[CachedAppliedType]) if h == NotCached then newType else + // Inlined from WeakHashSet#put Stats.record(statsItem("put")) - var idx = index(h) - var e = entryAt(idx) - while e != null do - if (e.tycon eq tycon) && e.args.eqElements(args) then return e - idx = nextIndex(idx) - e = entryAt(idx) - addEntryAt(idx, newType) + removeStaleEntries() + val bucket = index(h) + val oldHead = table(bucket) + + @tailrec + def linkedListLoop(entry: Entry[AppliedType]): AppliedType = entry match + case null => addEntryAt(bucket, newType, h, oldHead) + case _ => + val e = entry.get + if e != null && (e.tycon eq tycon) && e.args.eqElements(args) then e 
+ else linkedListLoop(entry.tail) + + linkedListLoop(oldHead) + end if end AppliedUniques end Uniques diff --git a/compiler/src/dotty/tools/dotc/util/MutableSet.scala b/compiler/src/dotty/tools/dotc/util/MutableSet.scala index bedb079f18ca..6e3ae7628eb6 100644 --- a/compiler/src/dotty/tools/dotc/util/MutableSet.scala +++ b/compiler/src/dotty/tools/dotc/util/MutableSet.scala @@ -8,7 +8,7 @@ abstract class MutableSet[T] extends ReadOnlySet[T]: def +=(x: T): Unit /** Like `+=` but return existing element equal to `x` of it exists, - * `x` itself otherwose. + * `x` itself otherwise. */ def put(x: T): T diff --git a/compiler/src/dotty/tools/dotc/util/WeakHashSet.scala b/compiler/src/dotty/tools/dotc/util/WeakHashSet.scala index 265f6e78cad2..3dc5761c0244 100644 --- a/compiler/src/dotty/tools/dotc/util/WeakHashSet.scala +++ b/compiler/src/dotty/tools/dotc/util/WeakHashSet.scala @@ -1,10 +1,10 @@ -/** Taken from the original implementation of WeakHashSet in scala-reflect +/** Adapted from the original implementation of WeakHashSet in scala-reflect */ package dotty.tools.dotc.util import java.lang.ref.{ReferenceQueue, WeakReference} -import scala.annotation.tailrec +import scala.annotation.{ constructorOnly, tailrec } import scala.collection.mutable /** @@ -17,12 +17,10 @@ import scala.collection.mutable * This set implementation is not in general thread safe without external concurrency control. However it behaves * properly when GC concurrently collects elements in this set. 
*/ -final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) extends mutable.Set[A] { +abstract class WeakHashSet[A <: AnyRef](initialCapacity: Int = 8, loadFactor: Double = 0.5) extends MutableSet[A] { import WeakHashSet._ - def this() = this(initialCapacity = WeakHashSet.defaultInitialCapacity, loadFactor = WeakHashSet.defaultLoadFactor) - type This = WeakHashSet[A] /** @@ -30,12 +28,12 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e * the removeStaleEntries() method works through the queue to remove * stale entries from the table */ - private val queue = new ReferenceQueue[A] + protected val queue = new ReferenceQueue[A] /** * the number of elements in this set */ - private var count = 0 + protected var count = 0 /** * from a specified initial capacity compute the capacity we'll use as being the next @@ -52,40 +50,26 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e /** * the underlying table of entries which is an array of Entry linked lists */ - private var table = new Array[Entry[A]](computeCapacity) + protected var table = new Array[Entry[A]](computeCapacity) /** * the limit at which we'll increase the size of the hash table */ - private var threshold = computeThreshold + protected var threshold = computeThreshold private def computeThreshold: Int = (table.size * loadFactor).ceil.toInt - def get(elem: A): Option[A] = Option(findEntry(elem)) + protected def hash(key: A): Int + protected def isEqual(x: A, y: A): Boolean = x.equals(y) - /** - * find the bucket associated with an element's hash code - */ - private def bucketFor(hash: Int): Int = { - // spread the bits around to try to avoid accidental collisions using the - // same algorithm as java.util.HashMap - var h = hash - h ^= h >>> 20 ^ h >>> 12 - h ^= h >>> 7 ^ h >>> 4 - - // this is finding h % table.length, but takes advantage of the - // fact that table length is a power of 2, - // if you don't do bit flipping 
in your head, if table.length - // is binary 100000.. (with n 0s) then table.length - 1 - // is 1111.. with n 1's. - // In other words this masks on the last n bits in the hash - h & (table.length - 1) - } + /** Turn hashcode `x` into a table index */ + protected def index(x: Int): Int = x & (table.length - 1) /** * remove a single entry from a linked list in a given bucket */ private def remove(bucket: Int, prevEntry: Entry[A], entry: Entry[A]): Unit = { + Stats.record(statsItem("remove")) prevEntry match { case null => table(bucket) = entry.tail case _ => prevEntry.tail = entry.tail @@ -96,14 +80,14 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e /** * remove entries associated with elements that have been gc'ed */ - private def removeStaleEntries(): Unit = { + protected def removeStaleEntries(): Unit = { def poll(): Entry[A] = queue.poll().asInstanceOf[Entry[A]] @tailrec def queueLoop(): Unit = { val stale = poll() if (stale != null) { - val bucket = bucketFor(stale.hash) + val bucket = index(stale.hash) @tailrec def linkedListLoop(prevEntry: Entry[A], entry: Entry[A]): Unit = if (stale eq entry) remove(bucket, prevEntry, entry) @@ -121,7 +105,8 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e /** * Double the size of the internal table */ - private def resize(): Unit = { + protected def resize(): Unit = { + Stats.record(statsItem("resize")) val oldTable = table table = new Array[Entry[A]](oldTable.size * 2) threshold = computeThreshold @@ -132,7 +117,7 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e def linkedListLoop(entry: Entry[A]): Unit = entry match { case null => () case _ => - val bucket = bucketFor(entry.hash) + val bucket = index(entry.hash) val oldNext = entry.tail entry.tail = table(bucket) table(bucket) = entry @@ -145,103 +130,76 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e tableLoop(0) } - def 
contains(elem: A): Boolean = findEntry(elem) ne null - - // from scala.reflect.internal.Set, find an element or null if it isn't contained - def findEntry(elem: A): A = elem match { + def lookup(elem: A): A | Null = elem match { case null => throw new NullPointerException("WeakHashSet cannot hold nulls") case _ => + Stats.record(statsItem("lookup")) removeStaleEntries() - val hash = elem.hashCode - val bucket = bucketFor(hash) + val bucket = index(hash(elem)) @tailrec def linkedListLoop(entry: Entry[A]): A = entry match { case null => null.asInstanceOf[A] case _ => val entryElem = entry.get - if (elem.equals(entryElem)) entryElem + if (entryElem != null && isEqual(elem, entryElem)) entryElem else linkedListLoop(entry.tail) } linkedListLoop(table(bucket)) } - // add an element to this set unless it's already in there and return the element - def findEntryOrUpdate(elem: A): A = elem match { + + protected def addEntryAt(bucket: Int, elem: A, elemHash: Int, oldHead: Entry[A]): A = { + Stats.record(statsItem("addEntryAt")) + table(bucket) = new Entry(elem, elemHash, oldHead, queue) + count += 1 + if (count > threshold) resize() + elem + } + + def put(elem: A): A = elem match { case null => throw new NullPointerException("WeakHashSet cannot hold nulls") case _ => + Stats.record(statsItem("put")) removeStaleEntries() - val hash = elem.hashCode - val bucket = bucketFor(hash) + val h = hash(elem) + val bucket = index(h) val oldHead = table(bucket) - def add() = { - table(bucket) = new Entry(elem, hash, oldHead, queue) - count += 1 - if (count > threshold) resize() - elem - } - @tailrec def linkedListLoop(entry: Entry[A]): A = entry match { - case null => add() + case null => addEntryAt(bucket, elem, h, oldHead) case _ => val entryElem = entry.get - if (elem.equals(entryElem)) entryElem + if (entryElem != null && isEqual(elem, entryElem)) entryElem else linkedListLoop(entry.tail) } linkedListLoop(oldHead) } - // add an element to this set unless it's already in there and return this set - override def addOne(elem: 
A): this.type = elem match { - case null => throw new NullPointerException("WeakHashSet cannot hold nulls") - case _ => - removeStaleEntries() - val hash = elem.hashCode - val bucket = bucketFor(hash) - val oldHead = table(bucket) + def +=(elem: A): Unit = put(elem) - def add(): Unit = { - table(bucket) = new Entry(elem, hash, oldHead, queue) - count += 1 - if (count > threshold) resize() - } - - @tailrec - def linkedListLoop(entry: Entry[A]): Unit = entry match { - case null => add() - case _ if elem.equals(entry.get) => () - case _ => linkedListLoop(entry.tail) - } - - linkedListLoop(oldHead) - this - } - - // remove an element from this set and return this set - override def subtractOne(elem: A): this.type = elem match { - case null => this + def -=(elem: A): Unit = elem match { + case null => case _ => + Stats.record(statsItem("-=")) removeStaleEntries() - val bucket = bucketFor(elem.hashCode) + val bucket = index(hash(elem)) @tailrec def linkedListLoop(prevEntry: Entry[A], entry: Entry[A]): Unit = entry match { case null => () - case _ if elem.equals(entry.get) => remove(bucket, prevEntry, entry) + case _ if { val entryElem = entry.get; entryElem != null && isEqual(elem, entryElem) } => remove(bucket, prevEntry, entry) case _ => linkedListLoop(entry, entry.tail) } linkedListLoop(null, table(bucket)) - this } - // empty this set - override def clear(): Unit = { + def clear(): Unit = { table = new Array[Entry[A]](table.size) threshold = computeThreshold count = 0 @@ -251,21 +209,11 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e queueLoop() } - // true if this set is empty - override def empty: This = new WeakHashSet[A](initialCapacity, loadFactor) - - // the number of elements in this set - override def size: Int = { + def size: Int = { removeStaleEntries() count } - override def isEmpty: Boolean = size == 0 - override def foreach[U](f: A => U): Unit = iterator foreach f - - // It has the `()` because iterator runs `removeStaleEntries()` - override def toList(): List[A] = 
iterator.toList - // Iterator over all the elements in this set in no particular order override def iterator: Iterator[A] = { removeStaleEntries() @@ -318,6 +266,12 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e } } + protected def statsItem(op: String): String = { + val prefix = "WeakHashSet." + val suffix = getClass.getSimpleName + s"$prefix$op $suffix" + } + /** * Diagnostic information about the internals of this set. Not normally * needed by ordinary code, but may be useful for diagnosing performance problems @@ -338,9 +292,9 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e assert(entry.get != null, s"$entry had a null value indicated that gc activity was happening during diagnostic validation or that a null value was inserted") computedCount += 1 val cachedHash = entry.hash - val realHash = entry.get.hashCode + val realHash = hash(entry.get) assert(cachedHash == realHash, s"for $entry cached hash was $cachedHash but should have been $realHash") - val computedBucket = bucketFor(realHash) + val computedBucket = index(realHash) assert(computedBucket == bucket, s"for $entry the computed bucket was $computedBucket but should have been $bucket") entry = entry.tail @@ -386,11 +340,6 @@ object WeakHashSet { * A single entry in a WeakHashSet. 
It's a WeakReference plus a cached hash code and * a link to the next Entry in the same bucket */ - private class Entry[A](element: A, val hash:Int, var tail: Entry[A], queue: ReferenceQueue[A]) extends WeakReference[A](element, queue) - - private final val defaultInitialCapacity = 16 - private final val defaultLoadFactor = .75 + class Entry[A](@constructorOnly element: A, val hash:Int, var tail: Entry[A], @constructorOnly queue: ReferenceQueue[A]) extends WeakReference[A](element, queue) - def apply[A <: AnyRef](initialCapacity: Int = defaultInitialCapacity, loadFactor: Double = defaultLoadFactor): WeakHashSet[A] = - new WeakHashSet(initialCapacity, loadFactor) }