File tree Expand file tree Collapse file tree 2 files changed +18
-11
lines changed
dpctl/tensor/libtensor/include/kernels/elementwise_functions Expand file tree Collapse file tree 2 files changed +18
-11
lines changed Original file line number Diff line number Diff line change @@ -114,21 +114,22 @@ template <typename argT, typename resT> struct Expm1Functor
114114 }
115115
116116 // x, y finite numbers
117- realT cosY_val;
118- auto cosY_val_multi_ptr = sycl::address_space_cast<
119- sycl::access::address_space::private_space,
120- sycl::access::decorated::yes>(&cosY_val);
121- const realT sinY_val = sycl::sincos (y, cosY_val_multi_ptr);
122- const realT sinhalfY_val = std::sin (y / 2 );
117+ const realT cosY_val = std::cos (y);
118+ const realT sinY_val = (y == 0 ) ? y : std::sin (y);
119+ const realT sinhalfY_val = (y == 0 ) ? y : std::sin (y / 2 );
123120
124121 const realT res_re =
125122 std::expm1 (x) * cosY_val - 2 * sinhalfY_val * sinhalfY_val;
126- const realT res_im = std::exp (x) * sinY_val;
123+ realT res_im = std::exp (x) * sinY_val;
127124 return resT{res_re, res_im};
128125 }
129126 else {
130127 static_assert (std::is_floating_point_v<argT> ||
131128 std::is_same_v<argT, sycl::half>);
129+ static_assert (std::is_same_v<argT, resT>);
130+ if (in == 0 ) {
131+ return in;
132+ }
132133 return std::expm1 (in);
133134 }
134135 }
Original file line number Diff line number Diff line change @@ -81,11 +81,15 @@ template <typename argT, typename resT> struct SinFunctor
8181 */
8282 if (in_re_finite && in_im_finite) {
8383#ifdef USE_SYCL_FOR_COMPLEX_TYPES
84- return exprm_ns::sin (
84+ resT res = exprm_ns::sin (
8585 exprm_ns::complex <realT>(in)); // std::sin(in);
8686#else
87- return std::sin (in);
87+ resT res = std::sin (in);
8888#endif
89+ if (in_re == realT (0 )) {
90+ res.real (std::copysign (realT (0 ), in_re));
91+ }
92+ return res;
8993 }
9094
9195 /*
@@ -176,8 +180,10 @@ template <typename argT, typename resT> struct SinFunctor
176180 return resT{sinh_im, -sinh_re};
177181 }
178182 else {
179- static_assert (std::is_floating_point_v<argT> ||
180- std::is_same_v<argT, sycl::half>);
183+ static_assert (std::is_same_v<argT, resT>);
184+ if (in == 0 ) {
185+ return in;
186+ }
181187 return std::sin (in);
182188 }
183189 }
You can’t perform that action at this time.
0 commit comments