diff --git a/include/sycl_ext_complex.hpp b/include/sycl_ext_complex.hpp index d187022..c1fde3d 100644 --- a/include/sycl_ext_complex.hpp +++ b/include/sycl_ext_complex.hpp @@ -45,17 +45,7 @@ class complex complex& operator=(const complex&); template complex& operator= (const complex&); - template complex& operator+=(const complex&); - template complex& operator-=(const complex&); - template complex& operator*=(const complex&); - template complex& operator/=(const complex&); -}; - -template<> -class complex -{ -public: - typedef sycl::half value_type; + template complex& operator+=(const complex&);_SYCL_BEGIN_NAMESPACE_STD constexpr complex(sycl::half re = 0.0f, sycl::half im = 0.0f); explicit constexpr complex(const complex&); @@ -244,8 +234,10 @@ template complex tanh (const complex&); // clang-format on -#define _SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD namespace sycl::ext::cplx { -#define _SYCL_EXT_CPLX_END_NAMESPACE_STD } +#define _SYCL_EXT_CPLX_BEGIN_NAMESPACE namespace sycl::ext::cplx { +#define _SYCL_BEGIN_NAMESPACE namespace sycl { +#define _SYCL_EXT_CPLX_END_NAMESPACE } +#define _SYCL_END_NAMESPACE } #define _SYCL_EXT_CPLX_INLINE_VISIBILITY \ inline __attribute__((__visibility__("hidden"), __always_inline__)) @@ -254,7 +246,7 @@ template complex tanh (const complex&); #include #include -_SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD +_SYCL_EXT_CPLX_BEGIN_NAMESPACE template struct __numeric_type { static void __test(...); @@ -314,8 +306,16 @@ template class __promote_imp<_A1, void, void, true> { template class __promote : public __promote_imp<_A1, _A2, _A3> {}; +// Forward declarations template class complex; +_SYCL_EXT_CPLX_END_NAMESPACE +_SYCL_BEGIN_NAMESPACE +template +class marray, NumElements>; +_SYCL_END_NAMESPACE +_SYCL_EXT_CPLX_BEGIN_NAMESPACE + template struct is_gencomplex : std::integral_constant &__y) { // 26.3.7 values: template ::value, - bool = std::is_floating_point<_Tp>::value> + bool = is_genfloat<_Tp>::value> struct __libcpp_complex_overload_traits {}; // Integral Types @@ -1091,7 +1091,7 @@ proj(const complex<_Tp> &__c) { template SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY typename std::enable_if< - std::is_floating_point<_Tp>::value, + is_genfloat<_Tp>::value, typename __libcpp_complex_overload_traits<_Tp>::_ComplexType>::type proj(_Tp __re) { if (sycl::isinf(__re)) @@ -1518,10 +1518,559 @@ operator<<(const sycl::stream &__ss, const complex<_Tp> &_x) { return __ss << "(" << _x.real() << "," << _x.imag() << ")"; } -_SYCL_EXT_CPLX_END_NAMESPACE_STD +_SYCL_EXT_CPLX_END_NAMESPACE + +// ---------------------------------------------------------------------------- +// -------------------------------complex-marray------------------------------- +// ---------------------------------------------------------------------------- + +_SYCL_BEGIN_NAMESPACE + +// marray of complex class specialisation +template +class marray, NumElements> { +private: + using DataT = sycl::ext::cplx::complex; + +public: + using value_type = DataT; + using reference = DataT &; + using const_reference = const DataT &; + using iterator = DataT *; + using const_iterator = const DataT *; + +private: + value_type MData[NumElements]; + +public: + marray() : MData{} {}; + + explicit constexpr marray(const DataT &arg) { + for (size_t i = 0; i < NumElements; ++i) + MData[i] = arg; + } + + template + constexpr marray(const ArgTN &... args) : MData{args...} {}; + + constexpr marray(const marray &rhs) = default; + constexpr marray(marray &&rhs) = default; + + // Available only when: NumElements == 1 + template > + operator DataT() const { + return MData[0]; + } + + static constexpr std::size_t size() noexcept { return NumElements; } + + // subscript operator + reference operator[](std::size_t index) { return MData[index]; } + const_reference operator[](std::size_t index) const { return MData[index]; } + + marray &operator=(const marray &rhs) = default; + marray &operator=(const DataT &rhs) { + for (std::size_t i = 0; i < NumElements; ++i) + MData[i] = rhs; + + return *this; + } + + // iterator functions + iterator begin() { return MData; } + const_iterator begin() const { return MData; } + + iterator end() { return MData + NumElements; } + const_iterator end() const { return MData + NumElements; } + + // OP is: +, -, *, / +#define OP(op) \ + friend marray operator op(const marray &lhs, const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs[i] op rhs[i]; \ + \ + return rtn; \ + } \ + \ + friend marray operator op(const marray &lhs, const DataT &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs[i] op rhs; \ + \ + return rtn; \ + } \ + \ + friend marray operator op(const DataT &lhs, const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs op rhs[i]; \ + \ + return rtn; \ + } + + OP(+) + OP(-) + OP(*) + OP(/) + +#undef OP + + // OP is: % + friend marray operator%(const marray &lhs, const marray &rhs) = delete; + friend marray operator%(const marray &lhs, const DataT &rhs) = delete; + friend marray operator%(const DataT &lhs, const marray &rhs) = delete; + + // OP is: +=, -=, *=, /= +#define OP(op) \ + friend marray &operator op(marray &lhs, const marray &rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) \ + lhs[i] op rhs[i]; \ + \ + return lhs; \ + } \ + \ + friend marray &operator op(marray &lhs, const DataT &rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) \ + lhs[i] op rhs; \ + \ + return lhs; \ + } \ + friend marray &operator op(DataT &lhs, const marray &rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) \ + lhs[i] op rhs; \ + \ + return lhs; \ + } + + OP(+=) + OP(-=) + OP(*=) + OP(/=) + +#undef OP + + // OP is: %= + friend marray &operator%=(marray &lhs, const marray &rhs) = delete; + friend marray &operator%=(marray &lhs, const DataT &rhs) = delete; + friend marray &operator%=(DataT &lhs, const marray &rhs) = delete; + +// OP is: ++, -- +#define OP(op) \ + friend marray operator op(marray &lhs, int) = delete; \ + friend marray &operator op(marray &rhs) = delete; + + OP(++) + OP(--) + +#undef OP + + // OP is unary +, - + friend marray operator+(marray &rhs) = delete; + friend marray operator-(marray &rhs) = delete; + +// OP is: &, |, ^ +#define OP(op) \ + friend marray operator op(const marray &lhs, const marray &rhs) = delete; \ + friend marray operator op(const marray &lhs, const DataT &rhs) = delete; + + OP(&) + OP(|) + OP(^) + +#undef OP + +// OP is: &=, |=, ^= +#define OP(op) \ + friend marray &operator op(marray &lhs, const marray &rhs) = delete; \ + friend marray &operator op(marray &lhs, const DataT &rhs) = delete; \ + friend marray &operator op(DataT &lhs, const marray &rhs) = delete; + + OP(&=) + OP(|=) + OP(^=) + +#undef OP + +// OP is: &&, || +#define OP(op) \ + friend marray operator op(const marray &lhs, \ + const marray &rhs) = delete; \ + friend marray operator op(const marray &lhs, \ + const DataT &rhs) = delete; \ + friend marray operator op(const DataT &lhs, \ + const marray &rhs) = delete; + + OP(&&) + OP(||) + +#undef OP + +// OP is: <<, >> +#define OP(op) \ + friend marray operator op(const marray &lhs, const marray &rhs) = delete; \ + friend marray operator op(const marray &lhs, const DataT &rhs) = delete; \ + friend marray operator op(const DataT &lhs, const marray &rhs) = delete; + + OP(<<) + OP(>>) + +#undef OP + +// OP is: <<=, >>= +#define OP(op) \ + friend marray &operator op(marray &lhs, const marray &rhs) = delete; \ + friend marray &operator op(marray &lhs, const DataT &rhs) = delete; + + OP(<<=) + OP(>>=) + +#undef OP + + // OP is: ==, != +#define OP(op) \ + friend marray operator op(const marray &lhs, \ + const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs[i] op rhs[i]; \ + \ + return rtn; \ + } \ + \ + friend marray operator op(const marray &lhs, \ + const DataT &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs[i] op rhs; \ + \ + return rtn; \ + } \ + \ + friend marray operator op(const DataT &lhs, \ + const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs op rhs[i]; \ + \ + return rtn; \ + } + + OP(==) + OP(!=) + +#undef OP + + // OP is: <, >, <=, >= +#define OP(op) \ + friend marray operator op(const marray &lhs, \ + const marray &rhs) = delete; \ + friend marray operator op(const marray &lhs, \ + const DataT &rhs) = delete; \ + friend marray operator op(const DataT &lhs, \ + const marray &rhs) = delete; + + OP(<); + OP(>); + OP(<=); + OP(>=); + +#undef OP + + friend marray operator~(const marray &v) = delete; + + friend marray operator!(const marray &v) = delete; +}; + +_SYCL_END_NAMESPACE + +_SYCL_EXT_CPLX_BEGIN_NAMESPACE + +// Make_complex_marray + +template +sycl::marray, NumElements> +make_complex_marray(const sycl::marray &real, + const sycl::marray &imag) { + sycl::marray, NumElements> rtn; + + for (std::size_t i = 0; i < NumElements; ++i) { + rtn[i].real(real[i]); + rtn[i].imag(imag[i]); + } + + return rtn; +} + +template +sycl::marray, NumElements> +make_complex_marray(const sycl::marray &real, const T &imag) { + sycl::marray, NumElements> rtn; + + for (std::size_t i = 0; i < NumElements; ++i) { + rtn[i].real(real[i]); + rtn[i].imag(imag); + } + + return rtn; +} + +template +sycl::marray, NumElements> +make_complex_marray(const T &real, const sycl::marray &imag) { + sycl::marray, NumElements> rtn; + + for (std::size_t i = 0; i < NumElements; ++i) { + rtn[i].real(real); + rtn[i].imag(imag[i]); + } + + return rtn; +} + +template +auto constexpr make_complex_marray( + const sycl::marray, NumElements> &cmplx, + std::integer_sequence int_seq) { + return sycl::marray, int_seq.size()>{cmplx[I]...}; +} + +template +auto constexpr make_complex_marray( + const sycl::marray &real, + const sycl::marray &imag, + std::integer_sequence int_seq) { + return sycl::marray, int_seq.size()>{ + complex(real[I], imag[I])...}; +} + +// Get + +template +SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray +get_real(const sycl::marray, NumElements> &input) { + sycl::marray rtn; + + for (std::size_t i = 0; i < NumElements; ++i) + rtn[i] = input[i].real(); + + return rtn; +} + +template +SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray +get_imag(const sycl::marray, NumElements> &input) { + sycl::marray rtn; + + for (std::size_t i = 0; i < NumElements; ++i) + rtn[i] = input[i].imag(); + + return rtn; +} + +template +SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY T get_real( + const sycl::marray, NumElements> &input, std::size_t index) { + return input[index].real(); +} + +template +SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY T get_imag( + const sycl::marray, NumElements> &input, std::size_t index) { + return input[index].imag(); +} + +template +auto constexpr get_real(const sycl::marray, NumElements> &input, + std::integer_sequence int_seq) { + return sycl::marray{input[I].real()...}; +} + +template +auto constexpr get_imag(const sycl::marray, NumElements> &input, + std::integer_sequence int_seq) { + return sycl::marray{input[I].imag()...}; +} + +// Set + +template +SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY void +set_real(sycl::marray, NumElements> &input, + const sycl::marray &values) { + for (std::size_t i = 0; i < NumElements; ++i) + input[i].real(values[i]); +} + +template +SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY void +set_imag(sycl::marray, NumElements> &input, + const sycl::marray &values) { + for (std::size_t i = 0; i < NumElements; ++i) + input[i].imag(values[i]); +} + +template +SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY void +set_real(sycl::marray, NumElements> &input, const T &value) { + for (std::size_t i = 0; i < NumElements; ++i) + input[i].real(value); +} + +template +SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY void +set_imag(sycl::marray, NumElements> &input, const T &value) { + for (std::size_t i = 0; i < NumElements; ++i) + input[i].imag(value); +} + +template +SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY void +set_real(sycl::marray, NumElements> &input, std::size_t index, + const T &value) { + input[index].real(value); +} + +template +SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY void +set_imag(sycl::marray, NumElements> &input, std::size_t index, + const T &value) { + input[index].imag(value); +} + +// Math marray overloads + +#define MATH_OP_ONE_PARAM(math_func, rtn_type, arg_type) \ + template ::value || \ + is_gencomplex::value>> \ + SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY \ + sycl::marray \ + math_func(const sycl::marray &x) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = sycl::ext::cplx::math_func(x[i]); \ + \ + return rtn; \ + } + +#define MATH_OP_TWO_PARAM(math_func, rtn_type, arg_type1, arg_type2) \ + template ::value || \ + is_gencomplex::value>> \ + SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY \ + sycl::marray \ + math_func(const sycl::marray &x, \ + const sycl::marray &y) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = sycl::ext::cplx::math_func(x[i], y[i]); \ + \ + return rtn; \ + } \ + \ + template ::value || \ + is_gencomplex::value>> \ + SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY \ + sycl::marray \ + math_func(const sycl::marray &x, \ + const arg_type2 &y) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = sycl::ext::cplx::math_func(x[i], y); \ + \ + return rtn; \ + } \ + \ + template ::value || \ + is_gencomplex::value>> \ + SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY \ + sycl::marray \ + math_func(const arg_type1 &x, \ + const sycl::marray &y) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = math_func(x, y[i]); \ + \ + return rtn; \ + } + +MATH_OP_ONE_PARAM(abs, T, complex); +MATH_OP_ONE_PARAM(acos, complex, complex); +MATH_OP_ONE_PARAM(asin, complex, complex); +MATH_OP_ONE_PARAM(atan, complex, complex); +MATH_OP_ONE_PARAM(acosh, complex, complex); +MATH_OP_ONE_PARAM(asinh, complex, complex); +MATH_OP_ONE_PARAM(atanh, complex, complex); +MATH_OP_ONE_PARAM(arg, T, complex); +MATH_OP_ONE_PARAM(conj, complex, complex); +MATH_OP_ONE_PARAM(cos, complex, complex); +MATH_OP_ONE_PARAM(cosh, complex, complex); +MATH_OP_ONE_PARAM(exp, complex, complex); +MATH_OP_ONE_PARAM(log, complex, complex); +MATH_OP_ONE_PARAM(log10, complex, complex); +MATH_OP_ONE_PARAM(norm, T, complex); +MATH_OP_TWO_PARAM(pow, complex, complex, T); +MATH_OP_TWO_PARAM(pow, complex, complex, complex); +MATH_OP_TWO_PARAM(pow, complex, T, complex); +MATH_OP_ONE_PARAM(proj, complex, complex); +MATH_OP_ONE_PARAM(proj, complex, T); +MATH_OP_ONE_PARAM(sin, complex, complex); +MATH_OP_ONE_PARAM(sinh, complex, complex); +MATH_OP_ONE_PARAM(sqrt, complex, complex); +MATH_OP_ONE_PARAM(tan, complex, complex); +MATH_OP_ONE_PARAM(tanh, complex, complex); + +// Special definition as polar requires default argument + +template ::value>> +SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY + sycl::marray, NumElements> + polar(const sycl::marray &rho, + const sycl::marray &theta) { + sycl::marray, NumElements> rtn; + for (std::size_t i = 0; i < NumElements; ++i) + rtn[i] = sycl::ext::cplx::polar(rho[i], theta[i]); + + return rtn; +} + +template ::value>> +SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY + sycl::marray, NumElements> + polar(const sycl::marray &rho, const T &theta = 0) { + sycl::marray, NumElements> rtn; + for (std::size_t i = 0; i < NumElements; ++i) + rtn[i] = sycl::ext::cplx::polar(rho[i], theta); + + return rtn; +} + +template ::value>> +SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY + sycl::marray, NumElements> + polar(const T &rho, const sycl::marray &theta) { + sycl::marray, NumElements> rtn; + for (std::size_t i = 0; i < NumElements; ++i) + rtn[i] = sycl::ext::cplx::polar(rho, theta[i]); + + return rtn; +} + +#undef MATH_OP_ONE_PARAM +#undef MATH_OP_TWO_PARAM + +_SYCL_EXT_CPLX_END_NAMESPACE -#undef _SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD -#undef _SYCL_EXT_CPLX_END_NAMESPACE_STD +#undef _SYCL_EXT_CPLX_BEGIN_NAMESPACE +#undef _SYCL_BEGIN_NAMESPACE +#undef _SYCL_EXT_CPLX_END_NAMESPACE +#undef _SYCL_END_NAMESPACE #undef _SYCL_EXT_CPLX_INLINE_VISIBILITY #endif // _SYCL_EXT_CPLX_COMPLEX diff --git a/tests/abs_complex.cpp b/tests/abs_complex.cpp index 11243e5..187158a 100644 --- a/tests/abs_complex.cpp +++ b/tests/abs_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex abs", "[abs]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,48 @@ TEMPLATE_TEST_CASE("Test complex abs", "[abs]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex abs", "[abs]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::abs(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::abs(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::abs(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/acos_complex.cpp b/tests/acos_complex.cpp index 95d7209..6ca7a9a 100644 --- a/tests/acos_complex.cpp +++ b/tests/acos_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex acos", "[acos]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -56,3 +60,90 @@ TEMPLATE_TEST_CASE("Test complex acos", "[acos]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +template +auto test(sycl::queue &Q, test_marray init_re, + test_marray init_im, bool is_error_checking) { + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + if (is_error_checking) { + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::acos(std_in[i]); + } else { + // Need to manually copy to handle as for for halfs, std_in is of value + // type float and std_out is of value type half + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std_in[i]; + } + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + if (is_error_checking) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::acos(cplx_input); + }).wait(); + } else { + Q.single_task([=]() { + *cplx_out = + sycl::ext::cplx::cos(sycl::ext::cplx::acos(cplx_input)); + }).wait(); + } + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 4); + + // Check cplx::complex output from host + if (is_error_checking) { + *cplx_out = sycl::ext::cplx::acos(cplx_input); + } else { + *cplx_out = sycl::ext::cplx::cos(sycl::ext::cplx::acos(cplx_input)); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 2); + + sycl::free(cplx_out, Q); +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex acos (check error: false)", + "[acos]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{1.0, 4.42, -3, 4.0}); + auto init_im = GENERATE(test_marray{1.0, 2.02, 3.5, -4.0}); + + test(Q, init_re, init_im, false); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex acos (check error: true)", "[acos]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 2.02, inf_val, inf_val, 2.02, nan_val, nan_val, inf_val, + nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + inf_val, 4.42, nan_val, 4.42, nan_val, nan_val, inf_val, + nan_val, inf_val, nan_val}); + + test(Q, init_re, init_im, true); +} diff --git a/tests/acosh_complex.cpp b/tests/acosh_complex.cpp index ed9eafd..e4eef59 100644 --- a/tests/acosh_complex.cpp +++ b/tests/acosh_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex acosh", "[acosh]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -56,4 +60,92 @@ TEMPLATE_TEST_CASE("Test complex acosh", "[acosh]", double, float, sycl::half) { check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); sycl::free(cplx_out, Q); -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +template +auto test(sycl::queue &Q, test_marray init_re, + test_marray init_im, bool is_error_checking) { + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + if (is_error_checking) { + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::acosh(std_in[i]); + } else { + // Need to manually copy to handle as for for halfs, std_in is of value + // type float and std_out is of value type half + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std_in[i]; + } + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + if (is_error_checking) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::acosh(cplx_input); + }).wait(); + } else { + Q.single_task([=]() { + *cplx_out = + sycl::ext::cplx::cosh(sycl::ext::cplx::acosh(cplx_input)); + }).wait(); + } + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 4); + + // Check cplx::complex output from host + if (is_error_checking) { + *cplx_out = sycl::ext::cplx::acosh(cplx_input); + } else { + *cplx_out = sycl::ext::cplx::cosh(sycl::ext::cplx::acosh(cplx_input)); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 2); + + sycl::free(cplx_out, Q); +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex acosh (check error: false)", + "[acosh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{1.0, 4.42, -3, 4.0}); + auto init_im = GENERATE(test_marray{1.0, 2.02, 3.5, -4.0}); + + test(Q, init_re, init_im, false); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex acosh (check error: true)", + "[acosh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 2.02, inf_val, inf_val, 2.02, nan_val, nan_val, inf_val, + nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + inf_val, 4.42, nan_val, 4.42, nan_val, nan_val, inf_val, + nan_val, inf_val, nan_val}); + + test(Q, init_re, init_im, true); +} diff --git a/tests/arg_complex.cpp b/tests/arg_complex.cpp index efb2a25..3f0c14d 100644 --- a/tests/arg_complex.cpp +++ b/tests/arg_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex arg", "[arg]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,48 @@ TEMPLATE_TEST_CASE("Test complex arg", "[arg]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex arg", "[arg]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::arg(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::arg(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::arg(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/asin_complex.cpp b/tests/asin_complex.cpp index 185d634..e7a51d7 100644 --- a/tests/asin_complex.cpp +++ b/tests/asin_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex asin", "[asin]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -55,4 +59,91 @@ TEMPLATE_TEST_CASE("Test complex asin", "[asin]", double, float, sycl::half) { check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); sycl::free(cplx_out, Q); -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +template +auto test(sycl::queue &Q, test_marray init_re, + test_marray init_im, bool is_error_checking) { + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + if (is_error_checking) { + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::asin(std_in[i]); + } else { + // Need to manually copy to handle as for for halfs, std_in is of value + // type float and std_out is of value type half + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std_in[i]; + } + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + if (is_error_checking) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::asin(cplx_input); + }).wait(); + } else { + Q.single_task([=]() { + *cplx_out = + sycl::ext::cplx::sin(sycl::ext::cplx::asin(cplx_input)); + }).wait(); + } + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 6); + + // Check cplx::complex output from host + if (is_error_checking) { + *cplx_out = sycl::ext::cplx::asin(cplx_input); + } else { + *cplx_out = sycl::ext::cplx::sin(sycl::ext::cplx::asin(cplx_input)); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 6); + + sycl::free(cplx_out, Q); +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex asin (check error: false)", + "[asin]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{1.0, 4.42, -3, 4.0}); + auto init_im = GENERATE(test_marray{1.0, 2.02, 3.5, -4.0}); + + test(Q, init_re, init_im, false); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex asin (check error: true)", "[asin]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 2.02, inf_val, inf_val, 2.02, nan_val, nan_val, inf_val, + nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + inf_val, 4.42, nan_val, 4.42, nan_val, nan_val, inf_val, + nan_val, inf_val, nan_val}); + + test(Q, init_re, init_im, true); +} diff --git a/tests/asinh_complex.cpp b/tests/asinh_complex.cpp index cc1babc..1228679 100644 --- a/tests/asinh_complex.cpp +++ b/tests/asinh_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex asinh", "[asinh]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -56,4 +60,92 @@ TEMPLATE_TEST_CASE("Test complex asinh", "[asinh]", double, float, sycl::half) { check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); sycl::free(cplx_out, Q); -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +template +auto test(sycl::queue &Q, test_marray init_re, + test_marray init_im, bool is_error_checking) { + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + if (is_error_checking) { + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::asinh(std_in[i]); + } else { + // Need to manually copy to handle as for for halfs, std_in is of value + // type float and std_out is of value type half + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std_in[i]; + } + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + if (is_error_checking) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::asinh(cplx_input); + }).wait(); + } else { + Q.single_task([=]() { + *cplx_out = + sycl::ext::cplx::sinh(sycl::ext::cplx::asinh(cplx_input)); + }).wait(); + } + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 4); + + // Check cplx::complex output from host + if (is_error_checking) { + *cplx_out = sycl::ext::cplx::asinh(cplx_input); + } else { + *cplx_out = sycl::ext::cplx::sinh(sycl::ext::cplx::asinh(cplx_input)); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 3); + + sycl::free(cplx_out, Q); +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex asinh (check error: false)", + "[asinh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{1.0, 4.42, -3, 4.0}); + auto init_im = GENERATE(test_marray{1.0, 2.02, 3.5, -4.0}); + + test(Q, init_re, init_im, false); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex asinh (check error: true)", + "[asinh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 2.02, inf_val, inf_val, 2.02, nan_val, nan_val, inf_val, + nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + inf_val, 4.42, nan_val, 4.42, nan_val, nan_val, inf_val, + nan_val, inf_val, nan_val}); + + test(Q, init_re, init_im, true); +} diff --git a/tests/atan_complex.cpp b/tests/atan_complex.cpp index e0283b1..3e803a6 100644 --- a/tests/atan_complex.cpp +++ b/tests/atan_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex atan", "[atan]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -55,4 +59,90 @@ TEMPLATE_TEST_CASE("Test complex atan", "[atan]", double, float, sycl::half) { check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); sycl::free(cplx_out, Q); -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +template +auto test(sycl::queue &Q, test_marray init_re, + test_marray init_im, bool is_error_checking) { + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + if (is_error_checking) { + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::atan(std_in[i]); + } else { + // Need to manually copy to handle as for for halfs, std_in is of value + // type float and std_out is of value type half + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std_in[i]; + } + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + if (is_error_checking) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::atan(cplx_input); + }).wait(); + } else { + Q.single_task([=]() { + *cplx_out = + sycl::ext::cplx::tan(sycl::ext::cplx::atan(cplx_input)); + }).wait(); + } + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 2); + + // Check cplx::complex output from host + if (is_error_checking) { + *cplx_out = sycl::ext::cplx::atan(cplx_input); + } else { + *cplx_out = sycl::ext::cplx::tan(sycl::ext::cplx::atan(cplx_input)); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 2); + + sycl::free(cplx_out, Q); +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex atan (check error: false)", "[atan]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{1.0, 4.42, -3, 4.0}); + auto init_im = GENERATE(test_marray{1.0, 2.02, 3.5, -4.0}); + + test(Q, init_re, init_im, false); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex atan (check error: true)", "[atan]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 2.02, inf_val, inf_val, 2.02, nan_val, nan_val, inf_val, + nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + inf_val, 4.42, nan_val, 4.42, nan_val, nan_val, inf_val, + nan_val, inf_val, nan_val}); + + test(Q, init_re, init_im, true); +} diff --git a/tests/atanh_complex.cpp b/tests/atanh_complex.cpp index f20bb38..4be1d36 100644 --- a/tests/atanh_complex.cpp +++ b/tests/atanh_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex atanh", "[atanh]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -56,4 +60,93 @@ TEMPLATE_TEST_CASE("Test complex atanh", "[atanh]", double, float, sycl::half) { check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); sycl::free(cplx_out, Q); -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +template +auto test(sycl::queue &Q, test_marray init_re, + test_marray init_im, bool is_error_checking) { + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + if (is_error_checking) { + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::atanh(std_in[i]); + } else { + // Need to manually copy to handle as for for halfs, std_in is of value + // type float and std_out is of value type half + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std_in[i]; + } + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + if (is_error_checking) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::atanh(cplx_input); + }).wait(); + } else { + Q.single_task([=]() { + *cplx_out = + sycl::ext::cplx::tanh(sycl::ext::cplx::atanh(cplx_input)); + }).wait(); + } + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 2); + + // Check cplx::complex output from host + if (is_error_checking) { + *cplx_out = sycl::ext::cplx::atanh(cplx_input); + } else { + *cplx_out = sycl::ext::cplx::tanh(sycl::ext::cplx::atanh(cplx_input)); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 2); + + sycl::free(cplx_out, Q); +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex atanh (check error: false)", + "[atanh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{1.0, 4.42, -3, 4.0}); + auto init_im = GENERATE(test_marray{1.0, 2.02, 3.5, -4.0}); + + test(Q, init_re, init_im, false); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex atanh (check error: true)", + "[atanh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 2.02, inf_val, inf_val, 2.02, nan_val, nan_val, inf_val, + nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + inf_val, 4.42, nan_val, 4.42, nan_val, nan_val, inf_val, + nan_val, inf_val, nan_val}); + + test(Q, init_re, init_im, true); +} diff --git a/tests/cos_complex.cpp b/tests/cos_complex.cpp index c5c5d55..f50d1e4 100644 --- a/tests/cos_complex.cpp +++ b/tests/cos_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex cos", "[cos]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,49 @@ TEMPLATE_TEST_CASE("Test complex cos", "[cos]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex cos", "[cos]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::cos(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::cos(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::cos(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/cosh_complex.cpp b/tests/cosh_complex.cpp index a4a7ee9..9db933f 100644 --- a/tests/cosh_complex.cpp +++ b/tests/cosh_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex cosh", "[cosh]", double, float, sycl::half) { using T = TestType; @@ -37,4 +41,50 @@ TEMPLATE_TEST_CASE("Test complex cosh", "[cosh]", double, float, sycl::half) { check_results(cplx_out[0], std_out); sycl::free(cplx_out, Q); -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex cosh", "[cosh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::cosh(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::cosh(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::cosh(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/exp_complex.cpp b/tests/exp_complex.cpp index 19bea14..427017b 100644 --- a/tests/exp_complex.cpp +++ b/tests/exp_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex exp", "[exp]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,49 @@ TEMPLATE_TEST_CASE("Test complex exp", "[exp]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex exp", "[exp]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::exp(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::exp(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::exp(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/log10_complex.cpp b/tests/log10_complex.cpp index 4affaa2..a076b74 100644 --- a/tests/log10_complex.cpp +++ b/tests/log10_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex log10", "[log10]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,49 @@ TEMPLATE_TEST_CASE("Test complex log10", "[log10]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex log10", "[log10]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::log10(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::log10(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::log10(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/log_complex.cpp b/tests/log_complex.cpp index 8907838..f4465ba 100644 --- a/tests/log_complex.cpp +++ b/tests/log_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex log", "[log]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,49 @@ TEMPLATE_TEST_CASE("Test complex log", "[log]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex log", "[log]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::log(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::log(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::log(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/norm_complex.cpp b/tests/norm_complex.cpp index 1778425..918e665 100644 --- a/tests/norm_complex.cpp +++ b/tests/norm_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex norm", "[norm]", double, float, sycl::half) { using T = TestType; @@ -37,3 +41,48 @@ TEMPLATE_TEST_CASE("Test complex norm", "[norm]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex norm", "[norm]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE( + test_marray{1.0, 4.42, -3, 4.0, 2.02, inf_val, + inf_val, 2.02, nan_val, nan_val}); + auto init_im = GENERATE( + test_marray{1.0, 2.02, 3.5, -4.0, inf_val, 4.42, + nan_val, 4.42, nan_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::norm(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::norm(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::norm(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/polar_complex.cpp b/tests/polar_complex.cpp index 4d2198f..489a76a 100644 --- a/tests/polar_complex.cpp +++ b/tests/polar_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex polar", "[polar]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -34,3 +38,45 @@ TEMPLATE_TEST_CASE("Test complex polar", "[polar]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex polar", "[polar]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // Test cases + auto init_rho = GENERATE(test_marray{1.0, 4.42, 3, 3.14}); + auto init_theta = + GENERATE(test_marray{1.0, 2.02, 3.5, -3.14}); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + sycl::marray rho = init_rho.get(); + sycl::marray theta = init_theta.get(); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::polar(rho[i], theta[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::polar(rho, theta); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::polar(rho, theta); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/pow_complex.cpp b/tests/pow_complex.cpp index aa1bf9f..149a374 100644 --- a/tests/pow_complex.cpp +++ b/tests/pow_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex pow cplx-cplx overload", "[pow]", double, float, sycl::half) { using T = TestType; @@ -137,3 +141,152 @@ TEMPLATE_TEST_CASE("Test complex pow deci-cplx overload", "[pow]", double, sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex pow cplx-cplx overload", "[pow]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re1 = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im1 = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_re2 = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im2 = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in1 = init_std_complex(init_re1.get(), init_im1.get()); + auto std_in2 = init_std_complex(init_re2.get(), init_im2.get()); + sycl::marray, NumElements> cplx_input1 = + sycl::ext::cplx::make_complex_marray(init_re1.get(), init_im1.get()); + sycl::marray, NumElements> cplx_input2 = + sycl::ext::cplx::make_complex_marray(init_re2.get(), init_im2.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::pow(std_in1[i], std_in2[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::pow(cplx_input1, cplx_input2); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::pow(cplx_input1, cplx_input2); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex pow cplx-deci overload", "[pow]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re1 = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im1 = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_re2 = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in1 = init_std_complex(init_re1.get(), init_im1.get()); + auto std_in2 = init_deci(init_re2.get()); + sycl::marray, NumElements> cplx_input1 = + sycl::ext::cplx::make_complex_marray(init_re1.get(), init_im1.get()); + sycl::marray cplx_input2 = init_re2.get(); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::pow(std_in1[i], std_in2[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::pow(cplx_input1, cplx_input2); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::pow(cplx_input1, cplx_input2); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex pow deci-cplx overload", "[pow]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re1 = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_re2 = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im2 = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in1 = init_deci(init_re1.get()); + auto std_in2 = init_std_complex(init_re2.get(), init_im2.get()); + sycl::marray cplx_input1 = init_re1.get(); + sycl::marray, NumElements> cplx_input2 = + sycl::ext::cplx::make_complex_marray(init_re2.get(), init_im2.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::pow(std_in1[i], std_in2[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::pow(cplx_input1, cplx_input2); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::pow(cplx_input1, cplx_input2); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/proj_complex.cpp b/tests/proj_complex.cpp new file mode 100644 index 0000000..09a7870 --- /dev/null +++ b/tests/proj_complex.cpp @@ -0,0 +1,162 @@ +#include "test_helper.hpp" + +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE("Test complex proj cplx overload", "[proj]", double, float, + sycl::half) { + using T = TestType; + + sycl::queue Q; + + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; + + std::complex std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); + + // Get std::complex output + std_out = std::proj(std_in); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + cplx_out[0] = sycl::ext::cplx::proj(cplx_input); + }).wait(); + } + + check_results(cplx_out[0], std_out); + + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::proj(cplx_input); + + check_results(cplx_out[0], std_out); + + sycl::free(cplx_out, Q); +} + +TEMPLATE_TEST_CASE("Test complex proj deci overload", "[proj]", double, float, + sycl::half) { + using T = TestType; + + sycl::queue Q; + + T input = GENERATE(4.42, inf_val, 4.42, inf_val, nan_val, 4.42, + nan_val, nan_val, inf_val); + + auto std_in = init_deci(input); + + std::complex std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); + + // Get std::complex output + std_out = std::proj(std_in); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + cplx_out[0] = sycl::ext::cplx::proj(input); + }).wait(); + } + + check_results(cplx_out[0], std_out); + + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::proj(input); + + check_results(cplx_out[0], std_out); + + sycl::free(cplx_out, Q); +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex proj cplx overload", "[proj]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::proj(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::proj(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::proj(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex proj deci overload", "[proj]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto input = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_deci(input.get()); + sycl::marray cplx_input = input.get(); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::proj(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::proj(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::proj(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/sin_complex.cpp b/tests/sin_complex.cpp index b6c8c38..4817e50 100644 --- a/tests/sin_complex.cpp +++ b/tests/sin_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex sin", "[sin]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,49 @@ TEMPLATE_TEST_CASE("Test complex sin", "[sin]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex sin", "[sin]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::sin(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::sin(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::sin(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/sinh_complex.cpp b/tests/sinh_complex.cpp index a11e1c2..4d78dd1 100644 --- a/tests/sinh_complex.cpp +++ b/tests/sinh_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex sinh", "[sinh]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,49 @@ TEMPLATE_TEST_CASE("Test complex sinh", "[sinh]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex sinh", "[sinh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::sinh(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::sinh(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::sinh(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/sqrt_complex.cpp b/tests/sqrt_complex.cpp index 32601bb..69c811d 100644 --- a/tests/sqrt_complex.cpp +++ b/tests/sqrt_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex sqrt", "[sqrt]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,49 @@ TEMPLATE_TEST_CASE("Test complex sqrt", "[sqrt]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex sqrt", "[sqrt]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::sqrt(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::sqrt(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::sqrt(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/tan_complex.cpp b/tests/tan_complex.cpp index cc46157..243648f 100644 --- a/tests/tan_complex.cpp +++ b/tests/tan_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex tan", "[tan]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,49 @@ TEMPLATE_TEST_CASE("Test complex tan", "[tan]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex tan", "[tan]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::tan(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::tan(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::tan(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/tanh_complex.cpp b/tests/tanh_complex.cpp index bfe934d..98d8d01 100644 --- a/tests/tanh_complex.cpp +++ b/tests/tanh_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex tanh", "[tanh]", double, float, sycl::half) { using T = TestType; @@ -37,4 +41,50 @@ TEMPLATE_TEST_CASE("Test complex tanh", "[tanh]", double, float, sycl::half) { check_results(cplx_out[0], std_out); sycl::free(cplx_out, Q); -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex tanh", "[tanh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // Test cases + auto init_re = GENERATE(test_marray{ + 1.0, 4.42, -3, 4.0, 2.02, inf_val, inf_val, 2.02, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + auto init_im = GENERATE(test_marray{ + 1.0, 2.02, 3.5, -4.0, inf_val, 4.42, nan_val, 4.42, nan_val, + nan_val, inf_val, nan_val, inf_val, nan_val}); + + auto std_in = init_std_complex(init_re.get(), init_im.get()); + sycl::marray, NumElements> cplx_input = + sycl::ext::cplx::make_complex_marray(init_re.get(), init_im.get()); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::tanh(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::tanh(cplx_input); + }).wait(); + } + + check_results(*cplx_out, std_out); + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::tanh(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/test_complex_types.cpp b/tests/test_complex_types.cpp index ccaf774..7acef09 100644 --- a/tests/test_complex_types.cpp +++ b/tests/test_complex_types.cpp @@ -2,6 +2,10 @@ using namespace sycl::ext::cplx; +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + // Test checks user interface return types // Compile time only tests, will fail during compilation due to static asserts @@ -9,6 +13,11 @@ using namespace sycl::ext::cplx; #define TEST_MATH_OP_TYPE(test_name, label, op) \ TEMPLATE_TEST_CASE(test_name, label, double, float, sycl::half) { \ \ + static_assert( \ + std::is_same_v, \ + decltype(std::declval>() \ + op std::declval>())>); \ + \ static_assert(std::is_same_v, \ decltype(std::declval>() \ op std::declval())>); \ diff --git a/tests/test_gencomplex.cpp b/tests/test_gencomplex.cpp index 64b0bd6..0474911 100644 --- a/tests/test_gencomplex.cpp +++ b/tests/test_gencomplex.cpp @@ -2,6 +2,10 @@ using namespace sycl::ext::cplx; +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + // Check is_gencomplex TEST_CASE("Test is_gencomplex", "[gencomplex]") { static_assert(is_gencomplex>::value == true); @@ -15,3 +19,17 @@ TEST_CASE("Test is_gencomplex", "[gencomplex]") { static_assert(is_gencomplex>::value == false); static_assert(is_gencomplex>::value == false); } + +// Check is_genfloat +TEST_CASE("Test is_genfloat", "[genfloat]") { + static_assert(is_genfloat::value == true); + static_assert(is_genfloat::value == true); + static_assert(is_genfloat::value == true); + + static_assert(is_genfloat::value == false); + static_assert(is_genfloat::value == false); + static_assert(is_genfloat::value == false); + static_assert(is_genfloat::value == false); + static_assert(is_genfloat::value == false); + static_assert(is_genfloat::value == false); +} diff --git a/tests/test_helper.hpp b/tests/test_helper.hpp index 18cf0e3..03ec3fe 100644 --- a/tests/test_helper.hpp +++ b/tests/test_helper.hpp @@ -10,7 +10,9 @@ #define SYCL_CPLX_TOL_ULP 5 // Helpers for check if type is supported -template inline bool is_type_supported(sycl::queue &Q) { return false; } +template inline bool is_type_supported(sycl::queue &Q) { + return false; +} template <> inline bool is_type_supported(sycl::queue &Q) { return Q.get_device().has(sycl::aspect::fp64); @@ -46,35 +48,33 @@ template struct cmplx { T im; }; -// Helper for testing each decimal type +// Helper classes to handle implicit conversion for passing marray types -template