File tree Expand file tree Collapse file tree 2 files changed +17
-5
lines changed
extensions/ufunc/elementwise_functions
kernels/elementwise_functions Expand file tree Collapse file tree 2 files changed +17
-5
lines changed Original file line number Diff line number Diff line change @@ -62,10 +62,15 @@ namespace td_ns = dpctl::tensor::type_dispatch;
6262template <typename T>
6363struct OutputType
6464{
65- using value_type =
66- typename std::disjunction<td_ns::TypeMapResultEntry<T, float >,
67- td_ns::TypeMapResultEntry<T, double >,
68- td_ns::DefaultResultEntry<void >>::result_type;
65+ /* *
66+ * scipy>=1.16 assumes the type pair 'e->d' for float16 input, but dpnp maps
67+ * 'e->f' instead and does not add an extra 'e->d' kernel (even when fp64 is supported) to reduce memory footprint
68+ */
69+ using value_type = typename std::disjunction<
70+ td_ns::TypeMapResultEntry<T, sycl::half, float >,
71+ td_ns::TypeMapResultEntry<T, float >,
72+ td_ns::TypeMapResultEntry<T, double >,
73+ td_ns::DefaultResultEntry<void >>::result_type;
6974};
7075
7176using dpnp::kernels::erf::ErfFunctor;
Original file line number Diff line number Diff line change @@ -45,7 +45,14 @@ struct ErfFunctor
4545
4646 Tp operator ()(const argT &x) const
4747 {
48- return sycl::erf (x);
48+ if constexpr (std::is_same_v<argT, sycl::half> &&
49+ std::is_same_v<Tp, float >) {
50+ // cast sycl::half to float before calling erf, so the result is computed in float for accuracy reasons
51+ return sycl::erf (float (x));
52+ }
53+ else {
54+ return sycl::erf (x);
55+ }
4956 }
5057};
5158} // namespace dpnp::kernels::erf
You can’t perform that action at this time.
0 commit comments