diff --git a/src/main/scala/scala/collection/decorators/MutableBitSetDecorator.scala b/src/main/scala/scala/collection/decorators/MutableBitSetDecorator.scala new file mode 100644 index 0000000..4b49ae5 --- /dev/null +++ b/src/main/scala/scala/collection/decorators/MutableBitSetDecorator.scala @@ -0,0 +1,120 @@ +package scala.collection.decorators + +import scala.collection.{BitSetOps, mutable} + +class MutableBitSetDecorator(protected val bs: mutable.BitSet) { + + import BitSetDecorator._ + import BitSetOps._ + + /** + * Updates this BitSet to the left shift of itself by the given shift distance. + * The shift distance may be negative, in which case this method performs a right shift. + * @param shiftBy shift distance, in bits + * @return the BitSet itself + */ + def <<=(shiftBy: Int): mutable.BitSet = { + + if (bs.nwords == 0 || bs.nwords == 1 && bs.word(0) == 0) () + else if (shiftBy > 0) shiftLeftInPlace(shiftBy) + else if (shiftBy < 0) shiftRightInPlace(-shiftBy) + + bs + } + + /** + * Updates this BitSet to the right shift of itself by the given shift distance. + * The shift distance may be negative, in which case this method performs a left shift. + * @param shiftBy shift distance, in bits + * @return the BitSet itself + */ + def >>=(shiftBy: Int): mutable.BitSet = { + + if (bs.nwords == 0 || bs.nwords == 1 && bs.word(0) == 0) () + else if (shiftBy > 0) shiftRightInPlace(shiftBy) + else if (shiftBy < 0) shiftLeftInPlace(-shiftBy) + + bs + } + + private def shiftLeftInPlace(shiftBy: Int): Unit = { + + val bitOffset = shiftBy & WordMask + val wordOffset = shiftBy >>> LogWL + + var significantWordCount = bs.nwords + while (significantWordCount > 0 && bs.word(significantWordCount - 1) == 0) { + significantWordCount -= 1 + } + + if (bitOffset == 0) { + val newSize = significantWordCount + wordOffset + require(newSize <= MaxSize) + ensureCapacity(newSize) + System.arraycopy(bs.elems, 0, bs.elems, wordOffset, significantWordCount) + } else { + val revBitOffset = WordLength - bitOffset + val extraBits = bs.elems(significantWordCount - 1) >>> revBitOffset + val extraWordCount = if (extraBits == 0) 0 else 1 + val newSize = significantWordCount + wordOffset + extraWordCount + require(newSize <= MaxSize) + ensureCapacity(newSize) + var i = significantWordCount - 1 + var previous = bs.elems(i) + while (i > 0) { + val current = bs.elems(i - 1) + bs.elems(i + wordOffset) = (current >>> revBitOffset) | (previous << bitOffset) + previous = current + i -= 1 + } + bs.elems(wordOffset) = previous << bitOffset + if (extraWordCount != 0) bs.elems(newSize - 1) = extraBits + } + java.util.Arrays.fill(bs.elems, 0, wordOffset, 0) + } + + private def shiftRightInPlace(shiftBy: Int): Unit = { + + val bitOffset = shiftBy & WordMask + + if (bitOffset == 0) { + val wordOffset = shiftBy >>> LogWL + val newSize = bs.nwords - wordOffset + if (newSize > 0) { + System.arraycopy(bs.elems, wordOffset, bs.elems, 0, newSize) + java.util.Arrays.fill(bs.elems, newSize, bs.nwords, 0) + } else bs.clear() + } else { + val wordOffset = (shiftBy >>> LogWL) + 1 + val extraBits = bs.elems(bs.nwords - 1) >>> bitOffset + val extraWordCount = if (extraBits == 0) 0 else 1 + val newSize = bs.nwords - wordOffset + extraWordCount + if (newSize > 0) { + val revBitOffset = WordLength - bitOffset + var previous = bs.elems(wordOffset - 1) + var i = wordOffset + while (i < bs.nwords) { + val current = bs.elems(i) + bs.elems(i - wordOffset) = (previous >>> bitOffset) | (current << revBitOffset) + previous = current + i += 1 + } + if (extraWordCount != 0) bs.elems(newSize - 1) = extraBits + java.util.Arrays.fill(bs.elems, newSize, bs.nwords, 0) + } else bs.clear() + } + } + + protected final def ensureCapacity(idx: Int): Unit = { + // Copied from mutable.BitSet.ensureCapacity (which is inaccessible from here). + require(idx < MaxSize) + if (idx >= bs.nwords) { + var newlen = bs.nwords + while (idx >= newlen) newlen = math.min(newlen * 2, MaxSize) + val elems1 = new Array[Long](newlen) + Array.copy(bs.elems, 0, elems1, 0, bs.nwords) + bs.elems = elems1 + } + } + +} diff --git a/src/main/scala/scala/collection/decorators/package.scala b/src/main/scala/scala/collection/decorators/package.scala index 77f2d0e..547891a 100644 --- a/src/main/scala/scala/collection/decorators/package.scala +++ b/src/main/scala/scala/collection/decorators/package.scala @@ -20,4 +20,7 @@ package object decorators { implicit def bitSetDecorator[C <: BitSet with BitSetOps[C]](bs: C): BitSetDecorator[C] = new BitSetDecorator(bs) + implicit def mutableBitSetDecorator(bs: mutable.BitSet): MutableBitSetDecorator = + new MutableBitSetDecorator(bs) + } diff --git a/src/test/scala/scala/collection/decorators/MutableBitSetDecoratorTest.scala b/src/test/scala/scala/collection/decorators/MutableBitSetDecoratorTest.scala new file mode 100644 index 0000000..cb8004a --- /dev/null +++ b/src/test/scala/scala/collection/decorators/MutableBitSetDecoratorTest.scala @@ -0,0 +1,107 @@ +package scala.collection.decorators + +import org.junit.{Assert, Test} + +import scala.collection.mutable.BitSet + +class MutableBitSetDecoratorTest { + + import Assert.{assertEquals, assertSame} + import BitSet.empty + + @Test + def shiftEmptyLeftInPlace(): Unit = { + for (shiftBy <- 0 to 128) { + val bs = empty + bs <<= shiftBy + assertEquals(empty, bs) + assertEquals(empty.nwords, bs.nwords) + } + } + + @Test + def shiftLowestBitLeftInPlace(): Unit = { + for (shiftBy <- 0 to 128) { + val bs = BitSet(0) + bs <<= shiftBy + assertEquals(BitSet(shiftBy), bs) + } + } + + @Test + def shiftNegativeLeftInPlace(): Unit = { + val bs = BitSet(1) + bs <<= -1 + assertEquals(BitSet(0), bs) + } + + @Test + def largeShiftLeftInPlace(): Unit = { + for (shiftBy <- 0 to 128) { + val bs = BitSet(0 to 300 by 5: _*) + val expected = bs.map(_ + shiftBy) + bs <<= shiftBy + assertEquals(expected, bs) + } + } + + @Test + def skipZeroWordsOnShiftLeftInPlace(): Unit = { + val bs = BitSet(5 * 64 - 1) + bs <<= 64 + assertEquals(BitSet(6 * 64 - 1), bs) + assertEquals(8, bs.nwords) + } + + @Test + def shiftEmptyRightInPlace(): Unit = { + for (shiftBy <- 0 to 128) { + val bs = empty + bs >>= shiftBy + assertEquals(empty, bs) + assertEquals(empty.nwords, bs.nwords) + } + } + + @Test + def shiftLowestBitRightInPlace(): Unit = { + val bs = BitSet(0) + bs >>= 0 + assertEquals(BitSet(0), bs) + + for (shiftBy <- 1 to 128) { + val bs = BitSet(0) + bs >>= shiftBy + assertEquals(empty, bs) + } + } + + @Test + def shiftToLowestBitRightInPlace(): Unit = { + for (shiftBy <- 0 to 128) { + val bs = BitSet(shiftBy) + bs >>= shiftBy + assertEquals(BitSet(0), bs) + } + } + + @Test + def shiftNegativeRightInPlace(): Unit = { + val bs = BitSet(0) + bs >>= -1 + assertEquals(BitSet(1), bs) + } + + @Test + def largeShiftRightInPlace(): Unit = { + for (shiftBy <- 0 to 128) { + val bs = BitSet(0 to 300 by 5: _*) + val expected = bs.collect { + case b if b >= shiftBy => b - shiftBy + } + bs >>= shiftBy + assertEquals(expected, bs) + } + } + +}