Skip to content

Commit 88f491f

Browse files
committed
more accurate sqrt function
1 parent 91fdc06 commit 88f491f

File tree

1 file changed

+87
-30
lines changed

1 file changed

+87
-30
lines changed

src/lib.rs

Lines changed: 87 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -281,40 +281,69 @@ impl<T: Float> Complex<T> {
281281
///
282282
/// The branch satisfies `-π/2 ≤ arg(sqrt(z)) ≤ π/2`.
283283
#[inline]
284-
pub fn sqrt(self) -> Self {
285-
if self.im.is_zero() {
286-
if self.re.is_sign_positive() {
287-
// simple positive real √r, and copy `im` for its sign
288-
Self::new(self.re.sqrt(), self.im)
289-
} else {
290-
// √(r e^(iπ)) = √r e^(iπ/2) = i√r
291-
// √(r e^(-iπ)) = √r e^(-iπ/2) = -i√r
292-
let re = T::zero();
293-
let im = (-self.re).sqrt();
294-
if self.im.is_sign_positive() {
295-
Self::new(re, im)
296-
} else {
297-
Self::new(re, -im)
298-
}
299-
}
300-
} else if self.re.is_zero() {
301-
// √(r e^(iπ/2)) = √r e^(iπ/4) = √(r/2) + i√(r/2)
302-
// √(r e^(-iπ/2)) = √r e^(-iπ/4) = √(r/2) - i√(r/2)
303-
let one = T::one();
304-
let two = one + one;
305-
let x = (self.im.abs() / two).sqrt();
306-
if self.im.is_sign_positive() {
307-
Self::new(x, x)
284+
#[allow(clippy::eq_op)]
285+
pub fn sqrt(mut self) -> Self {
286+
// TODO: rounding for very tiny subnormal numbers isn't perfect yet so
287+
// the assert shown fails in the very worst case this leads to about
288+
// 10% accuracy loss (see example below). As the magnitude increase the
289+
// error quickly drops to basically zero.
290+
//
291+
// glibc handles that (but other implementations like musl and numpy do
292+
// not) by upscaling very small values. That upscaling (and particularly
293+
// it's reversal) are weird and hard to understand (and rely on mantissa
294+
// bit size which we can't get out of the trait). In general the glibc
295+
// implementation is ever so subtley different and I wouldn't want to
296+
// introduce bugs by trying to adapt the underflow handling.
297+
//
298+
// assert_eq!(
299+
// Complex64::new(5.212e-324, 5.212e-324).sqrt(),
300+
// Complex64::new(2.4421097261308304e-162, 1.0115549693666347e-162)
301+
// );
302+
303+
if self.re.is_zero() && self.im.is_zero() {
304+
// 0 +/- 0 i
305+
return Self::new(T::zero(), self.im);
306+
}
307+
if self.im.is_infinite() {
308+
// inf +/- inf i
309+
return Self::new(T::infinity(), self.im);
310+
}
311+
if self.re.is_nan() {
312+
// nan + nan i
313+
return Self::new(self.re, (self.im - self.im) / (self.im - self.im));
314+
}
315+
if self.re.is_infinite() {
316+
// √(inf +/- NaN i) = inf +/- NaN i
317+
// √(inf +/- x i) = inf +/- 0 i
318+
// √(-inf +/- NaN i) = NaN +/- inf i
319+
// √(-inf +/- x i) = 0 +/- inf i
320+
321+
if self.re.is_sign_negative() {
322+
return Self::new((self.im - self.im).abs(), self.re.copysign(self.im));
308323
} else {
309-
Self::new(x, -x)
324+
return Self::new(self.re, (self.im - self.im).copysign(self.im));
310325
}
326+
}
327+
let two = T::one() + T::one();
328+
let four = two + two;
329+
let overflow = T::max_value() / (T::one() + T::sqrt(two));
330+
let max_magnitude = self.re.abs().max(self.im.abs());
331+
let scale = max_magnitude >= overflow;
332+
if scale {
333+
self = self / four;
334+
}
335+
if self.re.is_sign_negative() {
336+
let tmp = ((-self.re + self.norm()) / two).sqrt();
337+
self.re = self.im.abs() / (two * tmp);
338+
self.im = tmp.copysign(self.im);
311339
} else {
312-
// formula: sqrt(r e^(it)) = sqrt(r) e^(it/2)
313-
let one = T::one();
314-
let two = one + one;
315-
let (r, theta) = self.to_polar();
316-
Self::from_polar(r.sqrt(), theta / two)
340+
self.re = ((self.re + self.norm()) / two).sqrt();
341+
self.im = self.im / (two * self.re);
317342
}
343+
if scale {
344+
self = self * two;
345+
}
346+
self
318347
}
319348

320349
/// Computes the principal value of the cube root of `self`.
@@ -2065,6 +2094,34 @@ pub(crate) mod test {
20652094
}
20662095
}
20672096

2097+
#[test]
2098+
fn test_sqrt_nan() {
2099+
assert!(close_naninf(
2100+
Complex64::new(f64::INFINITY, f64::NAN).sqrt(),
2101+
Complex64::new(f64::INFINITY, f64::NAN),
2102+
));
2103+
assert!(close_naninf(
2104+
Complex64::new(f64::NEG_INFINITY, -f64::NAN).sqrt(),
2105+
Complex64::new(f64::NAN, f64::NEG_INFINITY),
2106+
));
2107+
assert!(close_naninf(
2108+
Complex64::new(f64::NEG_INFINITY, f64::NAN).sqrt(),
2109+
Complex64::new(f64::NAN, f64::INFINITY),
2110+
));
2111+
for x in (-100..100).map(f64::from) {
2112+
// √(inf + x i) = inf + 0 i
2113+
assert!(close_naninf(
2114+
Complex64::new(f64::INFINITY, x).sqrt(),
2115+
Complex64::new(f64::INFINITY, 0.0.copysign(x)),
2116+
));
2117+
// √(-inf + x i) = 0 + inf i
2118+
assert!(close_naninf(
2119+
Complex64::new(f64::NEG_INFINITY, x).sqrt(),
2120+
Complex64::new(0.0, f64::INFINITY.copysign(x)),
2121+
));
2122+
}
2123+
}
2124+
20682125
#[test]
20692126
fn test_cbrt() {
20702127
assert!(close(_0_0i.cbrt(), _0_0i));

0 commit comments

Comments
 (0)