Skip to content

Commit dca0650

Browse files
Introduce custom math_utils::isfinite
This works-around a bug in std::isinfite recently introduced in SYCLOS where std::isfinite(-0.0) returns false for floating point types, including sycl::half, float, double.
1 parent e8772dc commit dca0650

File tree

13 files changed

+130
-45
lines changed

13 files changed

+130
-45
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <cstdint>
3030
#include <type_traits>
3131

32+
#include "utils/math_utils.hpp"
3233
#include "utils/offset_utils.hpp"
3334
#include "utils/type_dispatch.hpp"
3435
#include "utils/type_utils.hpp"
@@ -48,6 +49,7 @@ namespace atan2
4849
namespace py = pybind11;
4950
namespace td_ns = dpctl::tensor::type_dispatch;
5051
namespace tu_ns = dpctl::tensor::type_utils;
52+
namespace mu_ns = dpctl::tensor::math_utils;
5153

5254
template <typename argT1, typename argT2, typename resT> struct Atan2Functor
5355
{
@@ -58,7 +60,7 @@ template <typename argT1, typename argT2, typename resT> struct Atan2Functor
5860
resT operator()(const argT1 &in1, const argT2 &in2)
5961
{
6062
if (std::isinf(in2) && !std::signbit(in2)) {
61-
if (std::isfinite(in1)) {
63+
if (mu_ns::isfinite(in1)) {
6264
return std::copysign(resT(0), in1);
6365
}
6466
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
#include "kernels/elementwise_functions/common.hpp"
3333

34+
#include "utils/math_utils.hpp"
3435
#include "utils/offset_utils.hpp"
3536
#include "utils/type_dispatch.hpp"
3637
#include "utils/type_utils.hpp"
@@ -47,6 +48,7 @@ namespace cos
4748

4849
namespace py = pybind11;
4950
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace mu_ns = dpctl::tensor::math_utils;
5052

5153
using dpctl::tensor::type_utils::is_complex;
5254

@@ -73,8 +75,8 @@ template <typename argT, typename resT> struct CosFunctor
7375
realT const &in_re = std::real(in);
7476
realT const &in_im = std::imag(in);
7577

76-
const bool in_re_finite = std::isfinite(in_re);
77-
const bool in_im_finite = std::isfinite(in_im);
78+
const bool in_re_finite = mu_ns::isfinite(in_re);
79+
const bool in_im_finite = mu_ns::isfinite(in_im);
7880

7981
/*
8082
* Handle the nearly-non-exceptional cases where

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
#include "kernels/elementwise_functions/common.hpp"
3333

34+
#include "utils/math_utils.hpp"
3435
#include "utils/offset_utils.hpp"
3536
#include "utils/type_dispatch.hpp"
3637
#include "utils/type_utils.hpp"
@@ -47,6 +48,7 @@ namespace cosh
4748

4849
namespace py = pybind11;
4950
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace mu_ns = dpctl::tensor::math_utils;
5052

5153
using dpctl::tensor::type_utils::is_complex;
5254

@@ -73,8 +75,8 @@ template <typename argT, typename resT> struct CoshFunctor
7375
const realT x = std::real(in);
7476
const realT y = std::imag(in);
7577

76-
const bool xfinite = std::isfinite(x);
77-
const bool yfinite = std::isfinite(y);
78+
const bool xfinite = mu_ns::isfinite(x);
79+
const bool yfinite = mu_ns::isfinite(y);
7880

7981
/*
8082
* Handle the nearly-non-exceptional cases where

dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
#include "kernels/elementwise_functions/common.hpp"
3333

34+
#include "utils/math_utils.hpp"
3435
#include "utils/offset_utils.hpp"
3536
#include "utils/type_dispatch.hpp"
3637
#include "utils/type_utils.hpp"
@@ -47,6 +48,7 @@ namespace exp
4748

4849
namespace py = pybind11;
4950
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace mu_ns = dpctl::tensor::math_utils;
5052

5153
using dpctl::tensor::type_utils::is_complex;
5254

@@ -71,8 +73,8 @@ template <typename argT, typename resT> struct ExpFunctor
7173

7274
const realT x = std::real(in);
7375
const realT y = std::imag(in);
74-
if (std::isfinite(x)) {
75-
if (std::isfinite(y)) {
76+
if (mu_ns::isfinite(x)) {
77+
if (mu_ns::isfinite(y)) {
7678
return std::exp(in);
7779
}
7880
else {
@@ -93,7 +95,7 @@ template <typename argT, typename resT> struct ExpFunctor
9395
if (y == realT(0)) {
9496
return resT{x, y};
9597
}
96-
else if (std::isfinite(y)) {
98+
else if (mu_ns::isfinite(y)) {
9799
return resT{x * std::cos(y), x * std::sin(y)};
98100
}
99101
else {
@@ -102,7 +104,7 @@ template <typename argT, typename resT> struct ExpFunctor
102104
}
103105
}
104106
else { /* x is -inf */
105-
if (std::isfinite(y)) {
107+
if (mu_ns::isfinite(y)) {
106108
realT exp_x = std::exp(x);
107109
return resT{exp_x * std::cos(y), exp_x * std::sin(y)};
108110
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
#include "kernels/elementwise_functions/common.hpp"
3535

36+
#include "utils/math_utils.hpp"
3637
#include "utils/offset_utils.hpp"
3738
#include "utils/type_dispatch.hpp"
3839
#include "utils/type_utils.hpp"
@@ -49,6 +50,7 @@ namespace expm1
4950

5051
namespace py = pybind11;
5152
namespace td_ns = dpctl::tensor::type_dispatch;
53+
namespace mu_ns = dpctl::tensor::math_utils;
5254

5355
using dpctl::tensor::type_utils::is_complex;
5456

@@ -78,7 +80,7 @@ template <typename argT, typename resT> struct Expm1Functor
7880
if (std::isinf(x)) {
7981
if (x > realT(0)) {
8082
// positive infinity cases
81-
if (!std::isfinite(y)) {
83+
if (!mu_ns::isfinite(y)) {
8284
return resT{x, std::numeric_limits<realT>::quiet_NaN()};
8385
}
8486
else if (y == realT(0)) {
@@ -91,7 +93,7 @@ template <typename argT, typename resT> struct Expm1Functor
9193
}
9294
else {
9395
// negative infinity cases
94-
if (!std::isfinite(y)) {
96+
if (!mu_ns::isfinite(y)) {
9597
// copy sign of y to guarantee
9698
// conj(expm1(x)) == expm1(conj(x))
9799
return resT{realT(-1), std::copysign(realT(0), y)};
@@ -114,11 +116,8 @@ template <typename argT, typename resT> struct Expm1Functor
114116
}
115117

116118
// x, y finite numbers
117-
realT cosY_val;
118-
auto cosY_val_multi_ptr = sycl::address_space_cast<
119-
sycl::access::address_space::global_space,
120-
sycl::access::decorated::yes>(&cosY_val);
121-
const realT sinY_val = sycl::sincos(y, cosY_val_multi_ptr);
119+
const realT sinY_val = std::sin(y);
120+
const realT cosY_val = std::cos(y);
122121
const realT sinhalfY_val = std::sin(y / 2);
123122

124123
const realT res_re =

dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <cstdint>
3131
#include <type_traits>
3232

33+
#include "utils/math_utils.hpp"
3334
#include "utils/offset_utils.hpp"
3435
#include "utils/type_dispatch.hpp"
3536
#include "utils/type_utils.hpp"
@@ -46,6 +47,7 @@ namespace isfinite
4647

4748
namespace py = pybind11;
4849
namespace td_ns = dpctl::tensor::type_dispatch;
50+
namespace mu_ns = dpctl::tensor::math_utils;
4951

5052
using dpctl::tensor::type_utils::is_complex;
5153
using dpctl::tensor::type_utils::vec_cast;
@@ -68,33 +70,19 @@ template <typename argT, typename resT> struct IsFiniteFunctor
6870
resT operator()(const argT &in) const
6971
{
7072
if constexpr (is_complex<argT>::value) {
71-
const bool real_isfinite = std::isfinite(std::real(in));
72-
const bool imag_isfinite = std::isfinite(std::imag(in));
73+
const bool real_isfinite = mu_ns::isfinite(std::real(in));
74+
const bool imag_isfinite = mu_ns::isfinite(std::imag(in));
7375
return (real_isfinite && imag_isfinite);
7476
}
7577
else if constexpr (std::is_same<argT, bool>::value ||
7678
std::is_integral<argT>::value)
7779
{
7880
return constant_value;
7981
}
80-
else if constexpr (std::is_same_v<argT, sycl::half>) {
81-
return sycl::isfinite(in);
82-
}
8382
else {
84-
return std::isfinite(in);
83+
return mu_ns::isfinite(in);
8584
}
8685
}
87-
88-
template <int vec_sz>
89-
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
90-
{
91-
auto const &res_vec = sycl::isfinite(in);
92-
93-
using deducedT = typename std::remove_cv_t<
94-
std::remove_reference_t<decltype(res_vec)>>::element_type;
95-
96-
return vec_cast<bool, deducedT, vec_sz>(res_vec);
97-
}
9886
};
9987

10088
template <typename argT,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <limits>
3232
#include <type_traits>
3333

34+
#include "utils/math_utils.hpp"
3435
#include "utils/offset_utils.hpp"
3536
#include "utils/type_dispatch.hpp"
3637
#include "utils/type_utils.hpp"
@@ -48,6 +49,7 @@ namespace logaddexp
4849
{
4950

5051
namespace py = pybind11;
52+
namespace mu_ns = dpctl::tensor::math_utils;
5153
namespace td_ns = dpctl::tensor::type_dispatch;
5254
namespace tu_ns = dpctl::tensor::type_utils;
5355

@@ -73,7 +75,7 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
7375

7476
#pragma unroll
7577
for (int i = 0; i < vec_sz; ++i) {
76-
if (std::isfinite(diff[i])) {
78+
if (mu_ns::isfinite(diff[i])) {
7779
res[i] = std::max<resT>(in1[i], in2[i]) +
7880
impl_finite<resT>(-std::abs(diff[i]));
7981
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
#include "kernels/elementwise_functions/common.hpp"
3333

34+
#include "utils/math_utils.hpp"
3435
#include "utils/offset_utils.hpp"
3536
#include "utils/type_dispatch.hpp"
3637
#include "utils/type_utils.hpp"
@@ -46,6 +47,7 @@ namespace sin
4647
{
4748

4849
namespace py = pybind11;
50+
namespace mu_ns = dpctl::tensor::math_utils;
4951
namespace td_ns = dpctl::tensor::type_dispatch;
5052

5153
using dpctl::tensor::type_utils::is_complex;
@@ -72,8 +74,8 @@ template <typename argT, typename resT> struct SinFunctor
7274
realT const &in_re = std::real(in);
7375
realT const &in_im = std::imag(in);
7476

75-
const bool in_re_finite = std::isfinite(in_re);
76-
const bool in_im_finite = std::isfinite(in_im);
77+
const bool in_re_finite = mu_ns::isfinite(in_re);
78+
const bool in_im_finite = mu_ns::isfinite(in_im);
7779
/*
7880
* Handle the nearly-non-exceptional cases where
7981
* real and imaginary parts of input are finite.

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
#include "kernels/elementwise_functions/common.hpp"
3333

34+
#include "utils/math_utils.hpp"
3435
#include "utils/offset_utils.hpp"
3536
#include "utils/type_dispatch.hpp"
3637
#include "utils/type_utils.hpp"
@@ -46,6 +47,7 @@ namespace sinh
4647
{
4748

4849
namespace py = pybind11;
50+
namespace mu_ns = dpctl::tensor::math_utils;
4951
namespace td_ns = dpctl::tensor::type_dispatch;
5052

5153
using dpctl::tensor::type_utils::is_complex;
@@ -71,8 +73,8 @@ template <typename argT, typename resT> struct SinhFunctor
7173
const realT x = std::real(in);
7274
const realT y = std::imag(in);
7375

74-
const bool xfinite = std::isfinite(x);
75-
const bool yfinite = std::isfinite(y);
76+
const bool xfinite = mu_ns::isfinite(x);
77+
const bool yfinite = mu_ns::isfinite(y);
7678

7779
/*
7880
* Handle the nearly-non-exceptional cases where

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
#include "kernels/elementwise_functions/common.hpp"
3636

37+
#include "utils/math_utils.hpp"
3738
#include "utils/offset_utils.hpp"
3839
#include "utils/type_dispatch.hpp"
3940
#include "utils/type_utils.hpp"
@@ -49,6 +50,7 @@ namespace sqrt
4950
{
5051

5152
namespace py = pybind11;
53+
namespace mu_ns = dpctl::tensor::math_utils;
5254
namespace td_ns = dpctl::tensor::type_dispatch;
5355

5456
using dpctl::tensor::type_utils::is_complex;
@@ -117,14 +119,16 @@ template <typename argT, typename resT> struct SqrtFunctor
117119
else if (std::isinf(x)) { // x is an infinity
118120
// y is either finite, or nan
119121
if (std::signbit(x)) { // x == -inf
120-
return {(std::isfinite(y) ? zero : y), std::copysign(p_inf, y)};
122+
return {(mu_ns::isfinite(y) ? zero : y),
123+
std::copysign(p_inf, y)};
121124
}
122125
else {
123-
return {p_inf, (std::isfinite(y) ? std::copysign(zero, y) : y)};
126+
return {p_inf,
127+
(mu_ns::isfinite(y) ? std::copysign(zero, y) : y)};
124128
}
125129
}
126130
else { // x is finite
127-
if (std::isfinite(y)) {
131+
if (mu_ns::isfinite(y)) {
128132
#ifdef USE_STD_SQRT_FOR_COMPLEX_TYPES
129133
return std::sqrt(z);
130134
#else

dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
#include "kernels/elementwise_functions/common.hpp"
3434

35+
#include "utils/math_utils.hpp"
3536
#include "utils/offset_utils.hpp"
3637
#include "utils/type_dispatch.hpp"
3738
#include "utils/type_utils.hpp"
@@ -47,6 +48,7 @@ namespace tan
4748
{
4849

4950
namespace py = pybind11;
51+
namespace mu_ns = dpctl::tensor::math_utils;
5052
namespace td_ns = dpctl::tensor::type_dispatch;
5153

5254
using dpctl::tensor::type_utils::is_complex;
@@ -94,7 +96,7 @@ template <typename argT, typename resT> struct TanFunctor
9496
* case is only needed to avoid a spurious invalid exception when
9597
* y is infinite.
9698
*/
97-
if (!std::isfinite(x)) {
99+
if (!mu_ns::isfinite(x)) {
98100
if (std::isnan(x)) {
99101
const realT tanh_re = x;
100102
const realT tanh_im = (y == realT(0) ? y : x * y);
@@ -111,7 +113,7 @@ template <typename argT, typename resT> struct TanFunctor
111113
* tanh(0 + i NAN) = 0 + i NaN
112114
* tanh(0 +- i Inf) = 0 + i NaN
113115
*/
114-
if (!std::isfinite(y)) {
116+
if (!mu_ns::isfinite(y)) {
115117
if (x == realT(0)) {
116118
return resT{q_nan, x};
117119
}

0 commit comments

Comments
 (0)