Skip to content

Commit b38fa9c

Browse files
authored
Merge pull request scala/scala#10259 from AminMal/issues/12706-take-drop-on-numeric-ranges
Fixes scala/bug#12706
2 parents 78ba5dc + 49bbaf2 commit b38fa9c

File tree

1 file changed

+71
-2
lines changed

1 file changed

+71
-2
lines changed

library/src/scala/collection/immutable/NumericRange.scala

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,75 @@ sealed class NumericRange[T](
132132
// are forgiving: therefore the checks are with the methods.
133133
private def locationAfterN(n: Int): T = start + (step * fromInt(n))
134134

135+
private def crossesTheEndAfterN(n: Int): Boolean = {
136+
// if we're sure that subtraction in the context of T won't overflow, we use this function
137+
// to calculate the length of the range
138+
def unsafeRangeLength(r: NumericRange[T]): T = {
139+
val diff = num.minus(r.end, r.start)
140+
val quotient = num.quot(diff, r.step)
141+
val remainder = num.rem(diff, r.step)
142+
if (!r.isInclusive && num.equiv(remainder, num.zero))
143+
num.max(quotient, num.zero)
144+
else
145+
num.max(num.plus(quotient, num.one), num.zero)
146+
}
147+
148+
// detects whether value can survive a bidirectional trip to -and then from- Int.
149+
def fitsInInteger(value: T): Boolean = num.equiv(num.fromInt(num.toInt(value)), value)
150+
151+
val stepIsInTheSameDirectionAsStartToEndVector =
152+
(num.gt(end, start) && num.gt(step, num.zero)) || (num.lt(end, start) && num.sign(step) == -num.one)
153+
154+
if (num.equiv(start, end) || n <= 0 || !stepIsInTheSameDirectionAsStartToEndVector) return n >= 1
155+
156+
val sameSign = num.equiv(num.sign(start), num.sign(end))
157+
158+
if (sameSign) { // subtraction is safe
159+
val len = unsafeRangeLength(this)
160+
if (fitsInInteger(len)) n >= num.toInt(len) else num.gteq(num.fromInt(n), len)
161+
} else {
162+
// split to two ranges, which subtraction is safe in both of them (around zero)
163+
val stepsRemainderToZero = num.rem(start, step)
164+
val walksOnZero = num.equiv(stepsRemainderToZero, num.zero)
165+
val closestToZero = if (walksOnZero) -step else stepsRemainderToZero
166+
167+
/*
168+
When splitting into two ranges, we should be super-careful about one of the sides hitting MinValue of T,
169+
so we take two steps smaller than zero to ensure unsafeRangeLength won't overflow (taking one step may overflow depending on the step).
170+
Same thing happens for MaxValue from zero, so we take one step further to ensure the safety of unsafeRangeLength.
171+
After performing such operation, there are some elements remaining in between and around zero,
172+
which their length is represented by carry.
173+
*/
174+
val (l: NumericRange[T], r: NumericRange[T], carry: Int) =
175+
if (num.lt(start, num.zero)) {
176+
if (walksOnZero) {
177+
val twoStepsAfterLargestNegativeNumber = num.plus(closestToZero, num.times(step, num.fromInt(2)))
178+
(NumericRange(start, closestToZero, step), copy(twoStepsAfterLargestNegativeNumber, end, step), 2)
179+
} else {
180+
(NumericRange(start, closestToZero, step), copy(num.plus(closestToZero, step), end, step), 1)
181+
}
182+
} else {
183+
if (walksOnZero) {
184+
val twoStepsAfterZero = num.times(step, num.fromInt(2))
185+
(copy(twoStepsAfterZero, end, step), NumericRange.inclusive(start, -step, step), 2)
186+
} else {
187+
val twoStepsAfterSmallestPositiveNumber = num.plus(closestToZero, num.times(step, num.fromInt(2)))
188+
(copy(twoStepsAfterSmallestPositiveNumber, end, step), NumericRange.inclusive(start, closestToZero, step), 2)
189+
}
190+
}
191+
192+
val leftLength = unsafeRangeLength(l)
193+
val rightLength = unsafeRangeLength(r)
194+
195+
// instead of `n >= rightLength + leftLength + curry` which may cause addition overflow,
196+
// this can be used `(n - leftLength - curry) >= rightLength` (Both in Int and T, depends on whether the lengths fit in Int)
197+
if (fitsInInteger(leftLength) && fitsInInteger(rightLength))
198+
n - num.toInt(leftLength) - carry >= num.toInt(rightLength)
199+
else
200+
num.gteq(num.minus(num.minus(num.fromInt(n), leftLength), num.fromInt(carry)), rightLength)
201+
}
202+
}
203+
135204
// When one drops everything. Can't ever have unchecked operations
136205
// like "end + 1" or "end - 1" because ranges involving Int.{ MinValue, MaxValue }
137206
// will overflow. This creates an exclusive range where start == end
@@ -140,13 +209,13 @@ sealed class NumericRange[T](
140209

141210
override def take(n: Int): NumericRange[T] = {
142211
if (n <= 0 || isEmpty) newEmptyRange(start)
143-
else if (n >= length) this
212+
else if (crossesTheEndAfterN(n)) this
144213
else new NumericRange.Inclusive(start, locationAfterN(n - 1), step)
145214
}
146215

147216
override def drop(n: Int): NumericRange[T] = {
148217
if (n <= 0 || isEmpty) this
149-
else if (n >= length) newEmptyRange(end)
218+
else if (crossesTheEndAfterN(n)) newEmptyRange(end)
150219
else copy(locationAfterN(n), end, step)
151220
}
152221

0 commit comments

Comments
 (0)