From b57b04f0af0e9fe9ff1f7bb26d537ba982c7bde5 Mon Sep 17 00:00:00 2001 From: jle-quel Date: Mon, 6 Feb 2023 10:49:34 +0100 Subject: [PATCH 1/8] introduction of marray's complex specialization --- include/sycl_ext_complex.hpp | 438 +++++++++++++++++++++++++++++++++++ 1 file changed, 438 insertions(+) diff --git a/include/sycl_ext_complex.hpp b/include/sycl_ext_complex.hpp index 36603fa..f6844bc 100644 --- a/include/sycl_ext_complex.hpp +++ b/include/sycl_ext_complex.hpp @@ -256,8 +256,12 @@ template complex tanh (const complex&); #define _SYCL_EXT_CPLX_FAST_MATH #endif +#define _SYCL_BEGIN_NAMESPACE namespace sycl { +#define _SYCL_END_NAMESPACE } + #define _SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD namespace _SYCL_CPLX_NAMESPACE { #define _SYCL_EXT_CPLX_END_NAMESPACE_STD } + #define _SYCL_EXT_CPLX_INLINE_VISIBILITY \ [[gnu::always_inline]] [[clang::always_inline]] inline @@ -365,6 +369,10 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr bool isinf(const T a) { } } // namespace cplex::detail +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX IMPLEMENTATION +//////////////////////////////////////////////////////////////////////////////// + template class complex; template @@ -1232,6 +1240,436 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> tan(const complex<_Tp> &__x) { _SYCL_EXT_CPLX_END_NAMESPACE_STD +//////////////////////////////////////////////////////////////////////////////// +// MARRAY IMPLEMENTATION +//////////////////////////////////////////////////////////////////////////////// + +_SYCL_BEGIN_NAMESPACE + +// marray of complex class specialisation +template +class marray, NumElements> { +private: + using DataT = sycl::ext::cplx::complex; + +public: + using value_type = DataT; + using reference = DataT &; + using const_reference = const DataT &; + using iterator = DataT *; + using const_iterator = const DataT *; + +private: + value_type MData[NumElements]; + +public: + constexpr marray() : MData{} {}; + + explicit constexpr marray(const DataT &arg) { + for (size_t i = 0; i < NumElements; ++i) + MData[i] = arg; + } + + template + constexpr marray(const ArgTN &... args) : MData{args...} {}; + + constexpr marray(const marray &rhs) = default; + constexpr marray(marray &&rhs) = default; + + // Available only when: NumElements == 1 + template > + operator DataT() const { + return MData[0]; + } + + static constexpr std::size_t size() noexcept { return NumElements; } + + marray real() const { + sycl::marray rtn; + + for (std::size_t i = 0; i < NumElements; ++i) { + rtn[i] = MData[i].real(); + } + + return rtn; + } + + marray imag() const { + sycl::marray rtn; + + for (std::size_t i = 0; i < NumElements; ++i) { + rtn[i] = MData[i].imag(); + } + + return rtn; + } + + // subscript operator + reference operator[](std::size_t index) { return MData[index]; } + const_reference operator[](std::size_t index) const { return MData[index]; } + + marray &operator=(const marray &rhs) = default; + marray &operator=(const DataT &rhs) { + for (std::size_t i = 0; i < NumElements; ++i) + MData[i] = rhs; + + return *this; + } + + // iterator functions + iterator begin() { return MData; } + const_iterator begin() const { return MData; } + + iterator end() { return MData + NumElements; } + const_iterator end() const { return MData + NumElements; } + + // OP is: +, -, *, / +#define OP(op) \ + friend marray operator op(const marray &lhs, const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs[i] op rhs[i]; \ + \ + return rtn; \ + } \ + \ + friend marray operator op(const marray &lhs, const DataT &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs[i] op rhs; \ + \ + return rtn; \ + } \ + \ + friend marray operator op(const DataT &lhs, const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs op rhs[i]; \ + \ + return rtn; \ + } + + OP(+) + OP(-) + OP(*) + OP(/) + +#undef OP + + // OP is: % + friend marray operator%(const marray &lhs, const marray &rhs) = delete; + friend marray operator%(const marray &lhs, const DataT &rhs) = delete; + friend marray operator%(const DataT &lhs, const marray &rhs) = delete; + + // OP is: +=, -=, *=, /= +#define OP(op) \ + friend marray &operator op(marray &lhs, const marray &rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) \ + lhs[i] op rhs[i]; \ + \ + return lhs; \ + } \ + \ + friend marray &operator op(marray &lhs, const DataT &rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) \ + lhs[i] op rhs; \ + \ + return lhs; \ + } \ + friend marray &operator op(DataT &lhs, const marray &rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) \ + lhs[i] op rhs; \ + \ + return lhs; \ + } + + OP(+=) + OP(-=) + OP(*=) + OP(/=) + +#undef OP + + // OP is: %= + friend marray &operator%=(marray &lhs, const marray &rhs) = delete; + friend marray &operator%=(marray &lhs, const DataT &rhs) = delete; + friend marray &operator%=(DataT &lhs, const marray &rhs) = delete; + +// OP is: ++, -- +#define OP(op) \ + friend marray operator op(marray &lhs, int) = delete; \ + friend marray &operator op(marray &rhs) = delete; + + OP(++) + OP(--) + +#undef OP + +// OP is: unary +, unary - +#define OP(op) \ + friend marray operator op( \ + const marray &rhs) { \ + marray rtn; \ + \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = op rhs[i]; \ + } \ + \ + return rtn; \ + } + + OP(+) + OP(-) + +#undef OP + +// OP is: &, |, ^ +#define OP(op) \ + friend marray operator op(const marray &lhs, const marray &rhs) = delete; \ + friend marray operator op(const marray &lhs, const DataT &rhs) = delete; + + OP(&) + OP(|) + OP(^) + +#undef OP + +// OP is: &=, |=, ^= +#define OP(op) \ + friend marray &operator op(marray &lhs, const marray &rhs) = delete; \ + friend marray &operator op(marray &lhs, const DataT &rhs) = delete; \ + friend marray &operator op(DataT &lhs, const marray &rhs) = delete; + + OP(&=) + OP(|=) + OP(^=) + +#undef OP + +// OP is: &&, || +#define OP(op) \ + friend marray operator op(const marray &lhs, \ + const marray &rhs) = delete; \ + friend marray operator op(const marray &lhs, \ + const DataT &rhs) = delete; \ + friend marray operator op(const DataT &lhs, \ + const marray &rhs) = delete; + + OP(&&) + OP(||) + +#undef OP + +// OP is: <<, >> +#define OP(op) \ + friend marray operator op(const marray &lhs, const marray &rhs) = delete; \ + friend marray operator op(const marray &lhs, const DataT &rhs) = delete; \ + friend marray operator op(const DataT &lhs, const marray &rhs) = delete; + + OP(<<) + OP(>>) + +#undef OP + +// OP is: <<=, >>= +#define OP(op) \ + friend marray &operator op(marray &lhs, const marray &rhs) = delete; \ + friend marray &operator op(marray &lhs, const DataT &rhs) = delete; + + OP(<<=) + OP(>>=) + +#undef OP + + // OP is: ==, != +#define OP(op) \ + friend marray operator op(const marray &lhs, \ + const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs[i] op rhs[i]; \ + \ + return rtn; \ + } \ + \ + friend marray operator op(const marray &lhs, \ + const DataT &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs[i] op rhs; \ + \ + return rtn; \ + } \ + \ + friend marray operator op(const DataT &lhs, \ + const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs op rhs[i]; \ + \ + return rtn; \ + } + + OP(==) + OP(!=) + +#undef OP + + // OP is: <, >, <=, >= +#define OP(op) \ + friend marray operator op(const marray &lhs, \ + const marray &rhs) = delete; \ + friend marray operator op(const marray &lhs, \ + const DataT &rhs) = delete; \ + friend marray operator op(const DataT &lhs, \ + const marray &rhs) = delete; + + OP(<); + OP(>); + OP(<=); + OP(>=); + +#undef OP + + friend marray operator~(const marray &v) = delete; + + friend marray operator!(const marray &v) = delete; +}; + +_SYCL_END_NAMESPACE + +_SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD + +// Math marray overloads + +#define MATH_OP_ONE_PARAM(math_func, rtn_type, arg_type) \ + template ::value || \ + is_gencomplex::value>> \ + _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray \ + math_func(const sycl::marray &x) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = sycl::ext::cplx::math_func(x[i]); \ + \ + return rtn; \ + } + +MATH_OP_ONE_PARAM(abs, T, complex); +MATH_OP_ONE_PARAM(acos, complex, complex); +MATH_OP_ONE_PARAM(asin, complex, complex); +MATH_OP_ONE_PARAM(atan, complex, complex); +MATH_OP_ONE_PARAM(acosh, complex, complex); +MATH_OP_ONE_PARAM(asinh, complex, complex); +MATH_OP_ONE_PARAM(atanh, complex, complex); +MATH_OP_ONE_PARAM(arg, T, complex); +MATH_OP_ONE_PARAM(conj, complex, complex); +MATH_OP_ONE_PARAM(cos, complex, complex); +MATH_OP_ONE_PARAM(cosh, complex, complex); +MATH_OP_ONE_PARAM(exp, complex, complex); +MATH_OP_ONE_PARAM(log, complex, complex); +MATH_OP_ONE_PARAM(log10, complex, complex); +MATH_OP_ONE_PARAM(norm, T, complex); +MATH_OP_ONE_PARAM(proj, complex, complex); +MATH_OP_ONE_PARAM(proj, complex, T); +MATH_OP_ONE_PARAM(sin, complex, complex); +MATH_OP_ONE_PARAM(sinh, complex, complex); +MATH_OP_ONE_PARAM(sqrt, complex, complex); +MATH_OP_ONE_PARAM(tan, complex, complex); +MATH_OP_ONE_PARAM(tanh, complex, complex); + +#undef MATH_OP_ONE_PARAM + +#define MATH_OP_TWO_PARAM(math_func, rtn_type, arg_type1, arg_type2) \ + template ::value || \ + is_gencomplex::value>> \ + _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray \ + math_func(const sycl::marray &x, \ + const sycl::marray &y) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = sycl::ext::cplx::math_func(x[i], y[i]); \ + \ + return rtn; \ + } \ + \ + template ::value || \ + is_gencomplex::value>> \ + _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray \ + math_func(const sycl::marray &x, \ + const arg_type2 &y) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = sycl::ext::cplx::math_func(x[i], y); \ + \ + return rtn; \ + } \ + \ + template ::value || \ + is_gencomplex::value>> \ + _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray \ + math_func(const arg_type1 &x, \ + const sycl::marray &y) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = math_func(x, y[i]); \ + \ + return rtn; \ + } + +MATH_OP_TWO_PARAM(pow, complex, complex, T); +MATH_OP_TWO_PARAM(pow, complex, complex, complex); +MATH_OP_TWO_PARAM(pow, complex, T, complex); + +#undef MATH_OP_TWO_PARAM + +// Special definition as polar requires default argument + +template ::value>> +_SYCL_EXT_CPLX_INLINE_VISIBILITY + sycl::marray, NumElements> + polar(const sycl::marray &rho, + const sycl::marray &theta) { + sycl::marray, NumElements> rtn; + for (std::size_t i = 0; i < NumElements; ++i) + rtn[i] = sycl::ext::cplx::polar(rho[i], theta[i]); + + return rtn; +} + +template ::value>> +_SYCL_EXT_CPLX_INLINE_VISIBILITY + sycl::marray, NumElements> + polar(const sycl::marray &rho, const T &theta = 0) { + sycl::marray, NumElements> rtn; + for (std::size_t i = 0; i < NumElements; ++i) + rtn[i] = sycl::ext::cplx::polar(rho[i], theta); + + return rtn; +} + +template ::value>> +_SYCL_EXT_CPLX_INLINE_VISIBILITY + sycl::marray, NumElements> + polar(const T &rho, const sycl::marray &theta) { + sycl::marray, NumElements> rtn; + for (std::size_t i = 0; i < NumElements; ++i) + rtn[i] = sycl::ext::cplx::polar(rho, theta[i]); + + return rtn; +} + +_SYCL_EXT_CPLX_END_NAMESPACE_STD + +#undef _SYCL_BEGIN_NAMESPACE +#undef _SYCL_END_NAMESPACE + #undef _SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD #undef _SYCL_EXT_CPLX_END_NAMESPACE_STD #undef _SYCL_EXT_CPLX_INLINE_VISIBILITY From 390cda622f8b3247da818e82258ebac5eab5b379 Mon Sep 17 00:00:00 2001 From: jle-quel Date: Mon, 6 Feb 2023 10:54:19 +0100 Subject: [PATCH 2/8] update complex tests with marray test cases --- tests/abs_complex.cpp | 63 +++++++++++ tests/acos_complex.cpp | 122 ++++++++++++++++++++ tests/acosh_complex.cpp | 125 ++++++++++++++++++++- tests/arg_complex.cpp | 63 +++++++++++ tests/asin_complex.cpp | 121 ++++++++++++++++++++ tests/asinh_complex.cpp | 124 ++++++++++++++++++++- tests/atan_complex.cpp | 123 +++++++++++++++++++- tests/atanh_complex.cpp | 124 ++++++++++++++++++++- tests/conj_complex.cpp | 64 +++++++++++ tests/cos_complex.cpp | 64 +++++++++++ tests/cosh_complex.cpp | 66 ++++++++++- tests/exp_complex.cpp | 64 +++++++++++ tests/log10_complex.cpp | 64 +++++++++++ tests/log_complex.cpp | 64 +++++++++++ tests/norm_complex.cpp | 59 ++++++++++ tests/polar_complex.cpp | 52 +++++++++ tests/pow_complex.cpp | 241 ++++++++++++++++++++++++++++++++++++++++ tests/proj_complex.cpp | 118 ++++++++++++++++++++ tests/sin_complex.cpp | 64 +++++++++++ tests/sinh_complex.cpp | 64 +++++++++++ tests/sqrt_complex.cpp | 64 +++++++++++ tests/tan_complex.cpp | 64 +++++++++++ tests/tanh_complex.cpp | 66 ++++++++++- tests/test_helper.hpp | 56 ++++++++++ 24 files changed, 2093 insertions(+), 6 deletions(-) diff --git a/tests/abs_complex.cpp b/tests/abs_complex.cpp index 11243e5..8eb91a0 100644 --- a/tests/abs_complex.cpp +++ b/tests/abs_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex abs", "[abs]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,62 @@ TEMPLATE_TEST_CASE("Test complex abs", "[abs]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex abs", "[abs]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::abs(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::abs(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::abs(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/acos_complex.cpp b/tests/acos_complex.cpp index 6987ffe..f7bf666 100644 --- a/tests/acos_complex.cpp +++ b/tests/acos_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex acos", "[acos]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -56,3 +60,121 @@ TEMPLATE_TEST_CASE("Test complex acos", "[acos]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +template +auto test( + sycl::queue &Q, const sycl::marray, NumElements> &std_in, + const sycl::marray, NumElements> &cplx_input, + bool is_error_checking) { + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + if (is_error_checking) { + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::acos(std_in[i]); + } else { + // Need to manually copy to handle as for for halfs, std_in is of value + // type float and std_out is of value type half + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std_in[i]; + } + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + if (is_error_checking) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::acos(cplx_input); + }).wait(); + } else { + Q.single_task([=]() { + *cplx_out = + sycl::ext::cplx::cos(sycl::ext::cplx::acos(cplx_input)); + }).wait(); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 4); + } + + // Check cplx::complex output from host + if (is_error_checking) { + *cplx_out = sycl::ext::cplx::acos(cplx_input); + } else { + *cplx_out = sycl::ext::cplx::cos(sycl::ext::cplx::acos(cplx_input)); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 2); + + sycl::free(cplx_out, Q); +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex acos (check error: false)", + "[acos]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 2.5}, + std::complex{4.0, -4.0}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + // sycl::half cases are emulated with float for std::complex class (std_in) + using X = typename std::conditional::value, float, + T>::type; + test(Q, std_in, cplx_input, false); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex acos (check error: true)", "[acos]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + // sycl::half cases are emulated with float for std::complex class (std_in) + using X = typename std::conditional::value, float, + T>::type; + test(Q, std_in, cplx_input, true); +} diff --git a/tests/acosh_complex.cpp b/tests/acosh_complex.cpp index 72796e1..bc0a9e3 100644 --- a/tests/acosh_complex.cpp +++ b/tests/acosh_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex acosh", "[acosh]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -56,4 +60,123 @@ TEMPLATE_TEST_CASE("Test complex acosh", "[acosh]", double, float, sycl::half) { check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); sycl::free(cplx_out, Q); -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +template +auto test( + sycl::queue &Q, const sycl::marray, NumElements> &std_in, + const sycl::marray, NumElements> &cplx_input, + bool is_error_checking) { + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + if (is_error_checking) { + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::acosh(std_in[i]); + } else { + // Need to manually copy to handle as for for halfs, std_in is of value + // type float and std_out is of value type half + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std_in[i]; + } + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + if (is_error_checking) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::acosh(cplx_input); + }).wait(); + } else { + Q.single_task([=]() { + *cplx_out = + sycl::ext::cplx::cosh(sycl::ext::cplx::acosh(cplx_input)); + }).wait(); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 4); + } + + // Check cplx::complex output from host + if (is_error_checking) { + *cplx_out = sycl::ext::cplx::acosh(cplx_input); + } else { + *cplx_out = sycl::ext::cplx::cosh(sycl::ext::cplx::acosh(cplx_input)); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 2); + + sycl::free(cplx_out, Q); +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex acosh (check error: false)", + "[acosh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 2.5}, + std::complex{4.0, -4.0}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + // sycl::half cases are emulated with float for std::complex class (std_in) + using X = typename std::conditional::value, float, + T>::type; + test(Q, std_in, cplx_input, false); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex acosh (check error: true)", + "[acosh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + // sycl::half cases are emulated with float for std::complex class (std_in) + using X = typename std::conditional::value, float, + T>::type; + test(Q, std_in, cplx_input, true); +} diff --git a/tests/arg_complex.cpp b/tests/arg_complex.cpp index fc3c190..09a6202 100644 --- a/tests/arg_complex.cpp +++ b/tests/arg_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex arg cmplx", "[arg]", double, float, sycl::half) { using T = TestType; @@ -78,3 +82,62 @@ TEMPLATE_TEST_CASE("Test complex arg deci", "[arg]", (std::pair), sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex arg", "[arg]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::arg(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::arg(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::arg(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/asin_complex.cpp b/tests/asin_complex.cpp index 9898819..6c161b0 100644 --- a/tests/asin_complex.cpp +++ b/tests/asin_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex asin", "[asin]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -56,3 +60,120 @@ TEMPLATE_TEST_CASE("Test complex asin", "[asin]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +template +auto test( + sycl::queue &Q, const sycl::marray, NumElements> &std_in, + const sycl::marray, NumElements> &cplx_input, + bool is_error_checking) { + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + if (is_error_checking) { + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::asin(std_in[i]); + } else { + // Need to manually copy to handle as for for halfs, std_in is of value + // type float and std_out is of value type half + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std_in[i]; + } + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + if (is_error_checking) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::asin(cplx_input); + }).wait(); + } else { + Q.single_task([=]() { + *cplx_out = + sycl::ext::cplx::sin(sycl::ext::cplx::asin(cplx_input)); + }).wait(); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 6); + } + + // Check cplx::complex output from host + if (is_error_checking) { + *cplx_out = sycl::ext::cplx::asin(cplx_input); + } else { + *cplx_out = sycl::ext::cplx::sin(sycl::ext::cplx::asin(cplx_input)); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 6); + + sycl::free(cplx_out, Q); +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex asin (check error: false)", + "[asin]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 2.5}, + std::complex{4.0, -4.0}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + // sycl::half cases are emulated with float for std::complex class (std_in) + using X = typename std::conditional::value, float, + T>::type; + test(Q, std_in, cplx_input, false); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex asin (check error: true)", "[asin]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + // sycl::half cases are emulated with float for std::complex class (std_in) + using X = typename std::conditional::value, float, + T>::type; + test(Q, std_in, cplx_input, true); +} diff --git a/tests/asinh_complex.cpp b/tests/asinh_complex.cpp index f638709..adfd8a9 100644 --- a/tests/asinh_complex.cpp +++ b/tests/asinh_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex asinh", "[asinh]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -56,4 +60,122 @@ TEMPLATE_TEST_CASE("Test complex asinh", "[asinh]", double, float, sycl::half) { check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); sycl::free(cplx_out, Q); -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +template +auto test( + sycl::queue &Q, const sycl::marray, NumElements> &std_in, + const sycl::marray, NumElements> &cplx_input, + bool is_error_checking) { + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + if (is_error_checking) { + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::asinh(std_in[i]); + } else { + // Need to manually copy to handle as for for halfs, std_in is of value + // type float and std_out is of value type half + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std_in[i]; + } + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + if (is_error_checking) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::asinh(cplx_input); + }).wait(); + } else { + Q.single_task([=]() { + *cplx_out = + sycl::ext::cplx::sinh(sycl::ext::cplx::asinh(cplx_input)); + }).wait(); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 4); + } + + // Check cplx::complex output from host + if (is_error_checking) { + *cplx_out = sycl::ext::cplx::asinh(cplx_input); + } else { + *cplx_out = sycl::ext::cplx::sinh(sycl::ext::cplx::asinh(cplx_input)); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 3); + + sycl::free(cplx_out, Q); +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex asinh (check error: false)", + "[asinh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 2.5}, + std::complex{4.0, -4.0}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + // sycl::half cases are emulated with float for std::complex class (std_in) + using X = typename std::conditional::value, float, + T>::type; + test(Q, std_in, cplx_input, false); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex asinh (check error: true)", + "[asinh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + // sycl::half cases are emulated with float for std::complex class (std_in) + using X = typename std::conditional::value, float, + T>::type; + test(Q, std_in, cplx_input, true); +} diff --git a/tests/atan_complex.cpp b/tests/atan_complex.cpp index 693827a..b528fcb 100644 --- a/tests/atan_complex.cpp +++ b/tests/atan_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex atan", "[atan]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -55,4 +59,121 @@ TEMPLATE_TEST_CASE("Test complex atan", "[atan]", double, float, sycl::half) { check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); sycl::free(cplx_out, Q); -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +template +auto test( + sycl::queue &Q, const sycl::marray, NumElements> &std_in, + const sycl::marray, NumElements> &cplx_input, + bool is_error_checking) { + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + if (is_error_checking) { + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::atan(std_in[i]); + } else { + // Need to manually copy to handle as for for halfs, std_in is of value + // type float and std_out is of value type half + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std_in[i]; + } + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + if (is_error_checking) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::atan(cplx_input); + }).wait(); + } else { + Q.single_task([=]() { + *cplx_out = + sycl::ext::cplx::tan(sycl::ext::cplx::atan(cplx_input)); + }).wait(); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 2); + } + + // Check cplx::complex output from host + if (is_error_checking) { + *cplx_out = sycl::ext::cplx::atan(cplx_input); + } else { + *cplx_out = sycl::ext::cplx::tan(sycl::ext::cplx::atan(cplx_input)); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 2); + + sycl::free(cplx_out, Q); +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex atan (check error: false)", + "[atan]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 2.5}, + std::complex{4.0, -4.0}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + // sycl::half cases are emulated with float for std::complex class (std_in) + using X = typename std::conditional::value, float, + T>::type; + test(Q, std_in, cplx_input, false); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex atan (check error: true)", "[atan]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + // sycl::half cases are emulated with float for std::complex class (std_in) + using X = typename std::conditional::value, float, + T>::type; + test(Q, std_in, cplx_input, true); +} diff --git a/tests/atanh_complex.cpp b/tests/atanh_complex.cpp index 03a8536..a7906b1 100644 --- a/tests/atanh_complex.cpp +++ b/tests/atanh_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex atanh", "[atanh]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -56,4 +60,122 @@ TEMPLATE_TEST_CASE("Test complex atanh", "[atanh]", double, float, sycl::half) { check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); sycl::free(cplx_out, Q); -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +template +auto test( + sycl::queue &Q, const sycl::marray, NumElements> &std_in, + const sycl::marray, NumElements> &cplx_input, + bool is_error_checking) { + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + if (is_error_checking) { + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::atanh(std_in[i]); + } else { + // Need to manually copy to handle as for for halfs, std_in is of value + // type float and std_out is of value type half + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std_in[i]; + } + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + if (is_error_checking) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::atanh(cplx_input); + }).wait(); + } else { + Q.single_task([=]() { + *cplx_out = + sycl::ext::cplx::tanh(sycl::ext::cplx::atanh(cplx_input)); + }).wait(); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 2); + } + + // Check cplx::complex output from host + if (is_error_checking) { + *cplx_out = sycl::ext::cplx::atanh(cplx_input); + } else { + *cplx_out = sycl::ext::cplx::tanh(sycl::ext::cplx::atanh(cplx_input)); + } + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 2); + + sycl::free(cplx_out, Q); +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex atanh (check error: false)", + "[atanh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 2.5}, + std::complex{4.0, -4.0}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + // sycl::half cases are emulated with float for std::complex class (std_in) + using X = typename std::conditional::value, float, + T>::type; + test(Q, std_in, cplx_input, false); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex atanh (check error: true)", + "[atanh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + // sycl::half cases are emulated with float for std::complex class (std_in) + using X = typename std::conditional::value, float, + T>::type; + test(Q, std_in, cplx_input, true); +} diff --git a/tests/conj_complex.cpp b/tests/conj_complex.cpp index bbfcece..fdf79a5 100644 --- a/tests/conj_complex.cpp +++ b/tests/conj_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex conj cmplx", "[conj]", double, float, sycl::half) { using T = TestType; @@ -78,3 +82,63 @@ TEMPLATE_TEST_CASE("Test complex conj deci", "[conj]", sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex conj", "[conj]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::conj(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::conj(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::conj(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/cos_complex.cpp b/tests/cos_complex.cpp index c5c5d55..33d24d5 100644 --- a/tests/cos_complex.cpp +++ b/tests/cos_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex cos", "[cos]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,63 @@ TEMPLATE_TEST_CASE("Test complex cos", "[cos]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex cos", "[cos]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::cos(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::cos(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::cos(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/cosh_complex.cpp b/tests/cosh_complex.cpp index a4a7ee9..c44417b 100644 --- a/tests/cosh_complex.cpp +++ b/tests/cosh_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex cosh", "[cosh]", double, float, sycl::half) { using T = TestType; @@ -37,4 +41,64 @@ TEMPLATE_TEST_CASE("Test complex cosh", "[cosh]", double, float, sycl::half) { check_results(cplx_out[0], std_out); sycl::free(cplx_out, Q); -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex cosh", "[cosh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::cosh(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::cosh(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::cosh(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/exp_complex.cpp b/tests/exp_complex.cpp index 19bea14..751e33c 100644 --- a/tests/exp_complex.cpp +++ b/tests/exp_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex exp", "[exp]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,63 @@ TEMPLATE_TEST_CASE("Test complex exp", "[exp]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex exp", "[exp]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::exp(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::exp(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::exp(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/log10_complex.cpp b/tests/log10_complex.cpp index 4affaa2..6260111 100644 --- a/tests/log10_complex.cpp +++ b/tests/log10_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex log10", "[log10]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,63 @@ TEMPLATE_TEST_CASE("Test complex log10", "[log10]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex log10", "[log10]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::log10(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::log10(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::log10(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/log_complex.cpp b/tests/log_complex.cpp index 8907838..04ab884 100644 --- a/tests/log_complex.cpp +++ b/tests/log_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex log", "[log]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,63 @@ TEMPLATE_TEST_CASE("Test complex log", "[log]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex log", "[log]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::log(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::log(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::log(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/norm_complex.cpp b/tests/norm_complex.cpp index ea2941b..e671531 100644 --- a/tests/norm_complex.cpp +++ b/tests/norm_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex norm cmplx", "[norm]", double, float, sycl::half) { using T = TestType; @@ -77,3 +81,58 @@ TEMPLATE_TEST_CASE("Test complex norm deci", "[norm]", sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex norm", "[norm]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::norm(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::norm(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::norm(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/polar_complex.cpp b/tests/polar_complex.cpp index 4d2198f..ed7eca1 100644 --- a/tests/polar_complex.cpp +++ b/tests/polar_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex polar", "[polar]", double, float, sycl::half) { using T = TestType; using std::make_tuple; @@ -34,3 +38,51 @@ TEMPLATE_TEST_CASE("Test complex polar", "[polar]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex polar", "[polar]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 4), (float, 4), (sycl::half, 4)) { + sycl::queue Q; + + // Test cases + const auto rho = GENERATE(sycl::marray{ + 1.0, + 4.42, + 3, + 3.14, + }); + const auto theta = GENERATE(sycl::marray{ + 1.0, + 2.02, + 3.5, + -3.14, + }); + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::polar(rho[i], theta[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::polar(rho, theta); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::polar(rho, theta); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/pow_complex.cpp b/tests/pow_complex.cpp index 6e329a9..f7c6596 100644 --- a/tests/pow_complex.cpp +++ b/tests/pow_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex pow cplx-cplx overload", "[pow]", double, float, sycl::half) { using T = TestType; @@ -190,3 +194,240 @@ TEMPLATE_TEST_CASE("Test complex pow deci-cplx overload", "[pow]", double, sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex pow cplx-cplx overload", "[pow]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + /* sycl::half cases are emulated with float for std::complex class (std_in) */ + using X = typename std::conditional::value, float, + T>::type; + + // std::complex test cases + const auto std_in1 = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + sycl::marray, NumElements> std_in2; + for (std::size_t i = 0; i < NumElements; i++) { + std_in2[i] = std::complex{std_in1[i].real(), std_in1[i].imag()}; + } + + // sycl::complex test cases + const auto cplx_input1 = + GENERATE(sycl::marray, NumElements>{ + sycl::ext::cplx::complex{1.0, 1.0}, + sycl::ext::cplx::complex{4.42, 2.02}, + sycl::ext::cplx::complex{-3, 3.5}, + sycl::ext::cplx::complex{4.0, -4.0}, + sycl::ext::cplx::complex{2.02, inf_val}, + sycl::ext::cplx::complex{inf_val, 4.42}, + sycl::ext::cplx::complex{inf_val, nan_val}, + sycl::ext::cplx::complex{2.02, 4.42}, + sycl::ext::cplx::complex{nan_val, nan_val}, + sycl::ext::cplx::complex{nan_val, nan_val}, + sycl::ext::cplx::complex{inf_val, inf_val}, + sycl::ext::cplx::complex{nan_val, nan_val}, + sycl::ext::cplx::complex{inf_val, inf_val}, + sycl::ext::cplx::complex{nan_val, nan_val}, + }); + sycl::marray, NumElements> cplx_input2; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input2[i] = sycl::ext::cplx::complex{cplx_input1[i].real(), + cplx_input1[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::pow(std_in1[i], std_in2[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::pow(cplx_input1, cplx_input2); + }).wait(); + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 3); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::pow(cplx_input1, cplx_input2); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex pow cplx-deci overload", "[pow]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in1 = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + const auto std_in2 = GENERATE(init_deci(sycl::marray{ + 1.0, + 4.42, + -3, + 4.0, + 2.02, + inf_val, + inf_val, + 2.02, + nan_val, + nan_val, + inf_val, + nan_val, + inf_val, + nan_val, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input1; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input1[i] = + sycl::ext::cplx::complex{std_in1[i].real(), std_in1[i].imag()}; + } + sycl::marray cplx_input2; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input2[i] = std_in2[i]; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::pow(std_in1[i], std_in2[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::pow(cplx_input1, cplx_input2); + }).wait(); + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 3); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::pow(cplx_input1, cplx_input2); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex pow deci-cplx overload", "[pow]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in1 = GENERATE(init_deci(sycl::marray{ + 1.0, + 4.42, + -3, + 4.0, + 2.02, + inf_val, + inf_val, + 2.02, + nan_val, + nan_val, + inf_val, + nan_val, + inf_val, + nan_val, + })); + const auto std_in2 = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray cplx_input1; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input1[i] = std_in1[i]; + } + sycl::marray, NumElements> cplx_input2; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input2[i] = + sycl::ext::cplx::complex{std_in2[i].real(), std_in2[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::pow(std_in1[i], std_in2[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::pow(cplx_input1, cplx_input2); + }).wait(); + + check_results(*cplx_out, std_out, /*tol_multiplier*/ 3); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::pow(cplx_input1, cplx_input2); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/proj_complex.cpp b/tests/proj_complex.cpp index ba40254..54d4037 100644 --- a/tests/proj_complex.cpp +++ b/tests/proj_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex proj cmplx", "[proj]", double, float, sycl::half) { using T = TestType; @@ -78,3 +82,117 @@ TEMPLATE_TEST_CASE("Test complex proj deci", "[proj]", sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex proj cplx overload", "[proj]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::proj(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::proj(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::proj(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex proj deci overload", "[proj]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = GENERATE(init_deci(sycl::marray{ + 1.0, + 4.42, + -3, + 4.0, + 2.02, + inf_val, + inf_val, + 2.02, + nan_val, + nan_val, + inf_val, + nan_val, + inf_val, + nan_val, + })); + + // sycl::complex test cases + sycl::marray cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = std_in[i]; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::proj(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::proj(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::proj(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/sin_complex.cpp b/tests/sin_complex.cpp index b6c8c38..057ae06 100644 --- a/tests/sin_complex.cpp +++ b/tests/sin_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex sin", "[sin]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,63 @@ TEMPLATE_TEST_CASE("Test complex sin", "[sin]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex sin", "[sin]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::sin(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::sin(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::sin(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/sinh_complex.cpp b/tests/sinh_complex.cpp index a11e1c2..6aa40fe 100644 --- a/tests/sinh_complex.cpp +++ b/tests/sinh_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex sinh", "[sinh]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,63 @@ TEMPLATE_TEST_CASE("Test complex sinh", "[sinh]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex sinh", "[sinh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::sinh(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::sinh(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::sinh(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/sqrt_complex.cpp b/tests/sqrt_complex.cpp index 32601bb..38a6116 100644 --- a/tests/sqrt_complex.cpp +++ b/tests/sqrt_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex sqrt", "[sqrt]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,63 @@ TEMPLATE_TEST_CASE("Test complex sqrt", "[sqrt]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex sqrt", "[sqrt]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::sqrt(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::sqrt(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::sqrt(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/tan_complex.cpp b/tests/tan_complex.cpp index cc46157..23df8a2 100644 --- a/tests/tan_complex.cpp +++ b/tests/tan_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex tan", "[tan]", double, float, sycl::half) { using T = TestType; @@ -38,3 +42,63 @@ TEMPLATE_TEST_CASE("Test complex tan", "[tan]", double, float, sycl::half) { sycl::free(cplx_out, Q); } + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex tan", "[tan]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::tan(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::tan(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::tan(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/tanh_complex.cpp b/tests/tanh_complex.cpp index bfe934d..21691e0 100644 --- a/tests/tanh_complex.cpp +++ b/tests/tanh_complex.cpp @@ -1,5 +1,9 @@ #include "test_helper.hpp" +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX TESTS +//////////////////////////////////////////////////////////////////////////////// + TEMPLATE_TEST_CASE("Test complex tanh", "[tanh]", double, float, sycl::half) { using T = TestType; @@ -37,4 +41,64 @@ TEMPLATE_TEST_CASE("Test complex tanh", "[tanh]", double, float, sycl::half) { check_results(cplx_out[0], std_out); sycl::free(cplx_out, Q); -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex tanh", "[tanh]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 14), (float, 14), (sycl::half, 14)) { + sycl::queue Q; + + // std::complex test cases + const auto std_in = + GENERATE(init_std_complex(sycl::marray, NumElements>{ + std::complex{1.0, 1.0}, + std::complex{4.42, 2.02}, + std::complex{-3, 3.5}, + std::complex{4.0, -4.0}, + std::complex{2.02, inf_val}, + std::complex{inf_val, 4.42}, + std::complex{inf_val, nan_val}, + std::complex{2.02, 4.42}, + std::complex{nan_val, nan_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + std::complex{inf_val, inf_val}, + std::complex{nan_val, nan_val}, + })); + + // sycl::complex test cases + sycl::marray, NumElements> cplx_input; + for (std::size_t i = 0; i < NumElements; ++i) { + cplx_input[i] = + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; + } + + sycl::marray, NumElements> std_out{}; + auto *cplx_out = sycl::malloc_shared< + sycl::marray, NumElements>>(1, Q); + + // Get std::complex output + for (std::size_t i = 0; i < NumElements; ++i) + std_out[i] = std::tanh(std_in[i]); + + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + *cplx_out = sycl::ext::cplx::tanh(cplx_input); + }).wait(); + + check_results(*cplx_out, std_out); + } + + // Check cplx::complex output from host + *cplx_out = sycl::ext::cplx::tanh(cplx_input); + + check_results(*cplx_out, std_out); + + sycl::free(cplx_out, Q); +} diff --git a/tests/test_helper.hpp b/tests/test_helper.hpp index fdf3d82..773e980 100644 --- a/tests/test_helper.hpp +++ b/tests/test_helper.hpp @@ -119,12 +119,49 @@ template <> auto constexpr init_std_complex(cmplx c) { return detail::trunc_float(std::complex(c.re, c.im)); } +template +auto constexpr init_std_complex( + sycl::marray, NumElements> input) { + return input; +} + +template +auto constexpr init_std_complex( + sycl::marray, NumElements> input) { + sycl::marray, NumElements> rtn; + for (std::size_t i = 0; i < rtn.size(); ++i) { + rtn[i] = detail::trunc_float( + std::complex(input[i].real(), input[i].imag())); + } + return rtn; +} + template auto constexpr init_deci(T_in re) { return re; } template <> auto constexpr init_deci(sycl::half re) { return static_cast(re); } +template +auto constexpr init_deci(sycl::marray re) { + sycl::marray rtn; + for (std::size_t i = 0; i < NumElements; ++i) + rtn[i] = static_cast(re[i]); + return rtn; +} + +// Helper to change marray of std::complex value type + +template +auto constexpr convert_marray(sycl::marray, NumElements> c) { + sycl::marray, NumElements> rtn; + + for (std::size_t i = 0; i < NumElements; ++i) + rtn[i] = c[i]; + + return rtn; +} + // Helpers for comparing SyclCPLX and standard c++ results template @@ -139,3 +176,22 @@ void check_results(T output, T reference, int tol_multiplier = 1) { CHECK(detail::almost_equal(output, reference, tol_multiplier * SYCL_CPLX_TOL_ULP)); } + +template +void check_results( + sycl::marray, NumElements> output, + sycl::marray, NumElements> reference, + int tol_multiplier = 1) { + for (std::size_t i = 0; i < NumElements; ++i) { + check_results(output[i], reference[i], tol_multiplier); + } +} + +template +void check_results(sycl::marray output, + sycl::marray reference, + int tol_multiplier = 1) { + for (std::size_t i = 0; i < NumElements; ++i) { + check_results(output[i], reference[i], tol_multiplier); + } +} From 16fa809f6ba3cf0827234ee77a3da275d850c814 Mon Sep 17 00:00:00 2001 From: jle-quel Date: Mon, 6 Feb 2023 10:55:21 +0100 Subject: [PATCH 3/8] introduce new tests for marray --- tests/test_marray_complex_getters.cpp | 77 ++++++ tests/test_marray_complex_types.cpp | 353 +++++++++++++++++++++++++ tests/test_operator_marray_complex.cpp | 226 ++++++++++++++++ 3 files changed, 656 insertions(+) create mode 100644 tests/test_marray_complex_getters.cpp create mode 100644 tests/test_marray_complex_types.cpp create mode 100644 tests/test_operator_marray_complex.cpp diff --git a/tests/test_marray_complex_getters.cpp b/tests/test_marray_complex_getters.cpp new file mode 100644 index 0000000..256afc9 --- /dev/null +++ b/tests/test_marray_complex_getters.cpp @@ -0,0 +1,77 @@ +#include "test_helper.hpp" + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex real component marray", "[getter]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 8), (float, 8), (sycl::half, 8)) { + sycl::queue Q; + + // Test cases + const auto init = GENERATE(sycl::marray{ + 0.0, + 1.0, + 4.42, + -0.0, + -1.0, + -4.42, + nan_val, + inf_val, + }); + sycl::marray, NumElements> input; + for (std::size_t i = 0; i < NumElements; ++i) { + input[i] = sycl::ext::cplx::complex{init[i], (T)0}; + } + + auto *out = sycl::malloc_shared>(1, Q); + + if (is_type_supported(Q)) { + Q.single_task([=]() { *out = input.real(); }).wait(); + + check_results(*out, init); + } + + *out = input.real(); + + check_results(*out, init); + + sycl::free(out, Q); +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex imag component marray", "[getter]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 8), (float, 8), (sycl::half, 8)) { + sycl::queue Q; + + // Test cases + const auto init = GENERATE(sycl::marray{ + 0.0, + 1.0, + 4.42, + -0.0, + -1.0, + -4.42, + nan_val, + inf_val, + }); + sycl::marray, NumElements> input; + for (std::size_t i = 0; i < NumElements; ++i) { + input[i] = sycl::ext::cplx::complex{(T)0, init[i]}; + } + + auto *out = sycl::malloc_shared>(1, Q); + + if (is_type_supported(Q)) { + Q.single_task([=]() { *out = input.imag(); }).wait(); + + check_results(*out, init); + } + + *out = input.imag(); + + check_results(*out, init); + + sycl::free(out, Q); +} diff --git a/tests/test_marray_complex_types.cpp b/tests/test_marray_complex_types.cpp new file mode 100644 index 0000000..4c833f8 --- /dev/null +++ b/tests/test_marray_complex_types.cpp @@ -0,0 +1,353 @@ +#include "test_helper.hpp" + +using namespace sycl::ext::cplx; + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +// Define math function tests +#define TEST_MATH_FUNC_TYPE(func) \ + template \ + auto test##_##func##_##types(sycl::queue &Q) { \ + /* Check host code */ \ + static_assert(std::is_same_v< \ + sycl::marray, NumElements>, \ + decltype(func(sycl::marray, NumElements>()))>); \ + \ + /* Check device code */ \ + if (is_type_supported(Q)) { \ + Q.single_task([=]() { \ + static_assert( \ + std::is_same_v, NumElements>, \ + decltype(func( \ + sycl::marray, NumElements>()))>); \ + }).wait(); \ + } \ + } + +TEST_MATH_FUNC_TYPE(acos) +TEST_MATH_FUNC_TYPE(asin) +TEST_MATH_FUNC_TYPE(atan) +TEST_MATH_FUNC_TYPE(acosh) +TEST_MATH_FUNC_TYPE(asinh) +TEST_MATH_FUNC_TYPE(atanh) +TEST_MATH_FUNC_TYPE(conj) +TEST_MATH_FUNC_TYPE(cos) +TEST_MATH_FUNC_TYPE(cosh) +TEST_MATH_FUNC_TYPE(exp) +TEST_MATH_FUNC_TYPE(log) +TEST_MATH_FUNC_TYPE(log10) +TEST_MATH_FUNC_TYPE(proj) +TEST_MATH_FUNC_TYPE(sin) +TEST_MATH_FUNC_TYPE(sinh) +TEST_MATH_FUNC_TYPE(sqrt) +TEST_MATH_FUNC_TYPE(tan) +TEST_MATH_FUNC_TYPE(tanh) +#undef TEST_MATH_FUNC_TYPE + +// Define math operations tests +#define TEST_MATH_OP_TYPE(op_name, op) \ + template \ + auto test##_##op_name##_##types(sycl::queue &Q) { \ + /* Check host code */ \ + static_assert( \ + std::is_same_v< \ + sycl::marray, NumElements>, \ + decltype(std::declval, NumElements>>() \ + op std::declval< \ + sycl::marray, NumElements>>())>); \ + \ + static_assert( \ + std::is_same_v< \ + sycl::marray, NumElements>, \ + decltype(std::declval, NumElements>>() \ + op std::declval>())>); \ + \ + static_assert( \ + std::is_same_v< \ + sycl::marray, NumElements>, \ + decltype(std::declval>() op std:: \ + declval, NumElements>>())>); \ + \ + /* Check device code */ \ + if (is_type_supported(Q)) { \ + Q.single_task([=]() { \ + static_assert( \ + std::is_same_v, NumElements>, \ + decltype(std::declval< \ + sycl::marray, NumElements>>() \ + op std::declval, NumElements>>())>); \ + \ + static_assert( \ + std::is_same_v< \ + sycl::marray, NumElements>, \ + decltype(std::declval< \ + sycl::marray, NumElements>>() op \ + std::declval>())>); \ + \ + static_assert( \ + std::is_same_v< \ + sycl::marray, NumElements>, \ + decltype(std::declval>() \ + op std::declval< \ + sycl::marray, NumElements>>())>); \ + }).wait(); \ + } \ + } + +TEST_MATH_OP_TYPE(add, +) +TEST_MATH_OP_TYPE(sub, -) +TEST_MATH_OP_TYPE(mul, *) +TEST_MATH_OP_TYPE(div, /) +#undef TEST_MATH_FUNC_TYPE + +// Define math operations tests +#define TEST_LOGIC_OP_TYPE(op_name, op) \ + template \ + auto test##_##op_name##_##types(sycl::queue &Q) { \ + /* Check host code */ \ + static_assert( \ + std::is_same_v< \ + sycl::marray, \ + decltype(std::declval, NumElements>>() \ + op std::declval< \ + sycl::marray, NumElements>>())>); \ + \ + static_assert( \ + std::is_same_v< \ + sycl::marray, \ + decltype(std::declval, NumElements>>() \ + op std::declval>())>); \ + \ + static_assert( \ + std::is_same_v< \ + sycl::marray, \ + decltype(std::declval>() op std:: \ + declval, NumElements>>())>); \ + \ + /* Check device code */ \ + if (is_type_supported(Q)) { \ + Q.single_task([=]() { \ + static_assert( \ + std::is_same_v, \ + decltype(std::declval< \ + sycl::marray, NumElements>>() \ + op std::declval, NumElements>>())>); \ + \ + static_assert( \ + std::is_same_v< \ + sycl::marray, \ + decltype(std::declval< \ + sycl::marray, NumElements>>() op \ + std::declval>())>); \ + \ + static_assert( \ + std::is_same_v< \ + sycl::marray, \ + decltype(std::declval>() \ + op std::declval< \ + sycl::marray, NumElements>>())>); \ + }).wait(); \ + } \ + } + +TEST_LOGIC_OP_TYPE(eq, ==) +TEST_LOGIC_OP_TYPE(inv_eq, !=) +#undef TEST_LOGIC_OP_TYPE + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +TEMPLATE_TEST_CASE_SIG("Test marray complex abs function return types", + "[marray types]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + static_assert( + std::is_same_v, + decltype(abs(sycl::marray, NumElements>()))>); + + /* Check device code */ + if (is_type_supported(Q)) { + Q.single_task([=]() { + static_assert(std::is_same_v< + sycl::marray, + decltype(abs(sycl::marray, NumElements>()))>); + }).wait(); + } +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex polar function return types", + "[marray types]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + /* Check host code */ + static_assert( + std::is_same_v, NumElements>, + decltype(polar(sycl::marray()))>); + static_assert( + std::is_same_v, NumElements>, + decltype(polar(sycl::marray(), + sycl::marray()))>); + static_assert( + std::is_same_v, NumElements>, + decltype(polar(sycl::marray(), T()))>); + static_assert( + std::is_same_v, NumElements>, + decltype(polar(T(), sycl::marray()))>); + + /* Check device code */ + if (is_type_supported(Q)) { + Q.single_task([=]() { + static_assert( + std::is_same_v, NumElements>, + decltype(polar(sycl::marray()))>); + static_assert( + std::is_same_v, NumElements>, + decltype(polar(sycl::marray(), + sycl::marray()))>); + static_assert(std::is_same_v, NumElements>, + decltype(polar( + sycl::marray(), T()))>); + static_assert(std::is_same_v, NumElements>, + decltype(polar( + T(), sycl::marray()))>); + }).wait(); + } +} + +TEMPLATE_TEST_CASE_SIG("Test marray complex pow function return types", + "[marray types]", + ((typename T, std::size_t NumElements), T, NumElements), + (double, 10), (float, 10), (sycl::half, 10)) { + sycl::queue Q; + + /* Check host code */ + // complex-deci + static_assert( + std::is_same_v, NumElements>, + decltype(pow(sycl::marray, NumElements>(), + sycl::marray()))>); + static_assert(std::is_same_v< + sycl::marray, NumElements>, + decltype(pow(sycl::marray, NumElements>(), T()))>); + static_assert(std::is_same_v, NumElements>, + decltype(pow(complex(), + sycl::marray()))>); + + // complex-complex + static_assert( + std::is_same_v, NumElements>, + decltype(pow(sycl::marray, NumElements>(), + sycl::marray, NumElements>()))>); + static_assert( + std::is_same_v, NumElements>, + decltype(pow(sycl::marray, NumElements>(), + complex()))>); + static_assert( + std::is_same_v, NumElements>, + decltype(pow(complex(), + sycl::marray, NumElements>()))>); + + // deci-complx + static_assert( + std::is_same_v, NumElements>, + decltype(pow(sycl::marray(), + sycl::marray, NumElements>()))>); + static_assert(std::is_same_v, NumElements>, + decltype(pow(sycl::marray(), + complex()))>); + static_assert(std::is_same_v< + sycl::marray, NumElements>, + decltype(pow(T(), sycl::marray, NumElements>()))>); + + /* Check device code */ + if (is_type_supported(Q)) { + Q.single_task([=]() { + // complex-deci + static_assert( + std::is_same_v, NumElements>, + decltype(pow(sycl::marray, NumElements>(), + sycl::marray()))>); + static_assert( + std::is_same_v, NumElements>, + decltype(pow(sycl::marray, NumElements>(), + T()))>); + static_assert( + std::is_same_v, NumElements>, + decltype(pow(complex(), + sycl::marray()))>); + + // complex-complex + static_assert(std::is_same_v< + sycl::marray, NumElements>, + decltype(pow(sycl::marray, NumElements>(), + sycl::marray, NumElements>()))>); + static_assert( + std::is_same_v, NumElements>, + decltype(pow(sycl::marray, NumElements>(), + complex()))>); + static_assert(std::is_same_v< + sycl::marray, NumElements>, + decltype(pow(complex(), + sycl::marray, NumElements>()))>); + + // deci-complx + static_assert(std::is_same_v< + sycl::marray, NumElements>, + decltype(pow(sycl::marray(), + sycl::marray, NumElements>()))>); + static_assert(std::is_same_v, NumElements>, + decltype(pow(sycl::marray(), + complex()))>); + static_assert( + std::is_same_v, NumElements>, + decltype(pow( + T(), sycl::marray, NumElements>()))>); + }).wait(); + } +} + +#define TEST(func) \ + TEMPLATE_TEST_CASE_SIG( \ + "Test marray complex " #func " function return types", "[marray types]", \ + ((typename T, std::size_t NumElements), T, NumElements), (double, 10), \ + (float, 10), (sycl::half, 10)) { \ + sycl::queue Q; \ + test##_##func##_##types(Q); \ + } + +TEST(acos) +TEST(asin) +TEST(atan) +TEST(acosh) +TEST(asinh) +TEST(atanh) +TEST(conj) +TEST(cos) +TEST(cosh) +TEST(exp) +TEST(log) +TEST(log10) +TEST(proj) +TEST(sin) +TEST(sinh) +TEST(sqrt) +TEST(tan) +TEST(tanh) + +TEST(add) +TEST(sub) +TEST(mul) +TEST(div) + +TEST(eq) +TEST(inv_eq) +#undef TEST diff --git a/tests/test_operator_marray_complex.cpp b/tests/test_operator_marray_complex.cpp new file mode 100644 index 0000000..42b0d60 --- /dev/null +++ b/tests/test_operator_marray_complex.cpp @@ -0,0 +1,226 @@ +#include "test_helper.hpp" + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS'S UTILITIES +//////////////////////////////////////////////////////////////////////////////// + +#define test_op(name, op) \ + template \ + auto test##_##name( \ + sycl::queue &Q, \ + const sycl::marray, NumElements> &std_in1, \ + const sycl::marray, NumElements> &std_in2, \ + const sycl::marray, NumElements> \ + &cplx_input1, \ + const sycl::marray, NumElements> \ + &cplx_input2) { \ + \ + sycl::marray, NumElements> std_out{}; \ + auto *cplx_out = sycl::malloc_shared< \ + sycl::marray, NumElements>>(1, Q); \ + \ + /* Get std::complex output */ \ + for (std::size_t i = 0; i < NumElements; ++i) \ + std_out[i] = std_in1[i] op std_in2[i]; \ + \ + /* Check cplx::complex output from device */ \ + if (is_type_supported(Q)) { \ + Q.single_task([=]() { *cplx_out = cplx_input1 op cplx_input2; }).wait(); \ + check_results(*cplx_out, std_out); \ + } \ + \ + /* Check cplx::complex output from host */ \ + *cplx_out = cplx_input1 op cplx_input2; \ + \ + check_results(*cplx_out, std_out); \ + \ + sycl::free(cplx_out, Q); \ + } + +test_op(add, +); +test_op(sub, -); +test_op(mul, /); +test_op(div, *); + +#undef test_op + +#define test_op_assign(name, op_assign) \ + template \ + auto test##_##name( \ + sycl::queue &Q, \ + const sycl::marray, NumElements> &std_in, \ + sycl::marray, NumElements> &std_inout, \ + const sycl::marray, NumElements> \ + &cplx_input1, \ + sycl::marray, NumElements> &cplx_input2) { \ + \ + auto *cplx_inout = sycl::malloc_shared< \ + sycl::marray, NumElements>>(1, Q); \ + *cplx_inout = cplx_input2; \ + \ + /* Get std::complex output */ \ + for (std::size_t i = 0; i < NumElements; ++i) \ + std_inout[i] op_assign std_in[i]; \ + \ + /* Check cplx::complex output from device */ \ + if (is_type_supported(Q)) { \ + Q.single_task([=]() { *cplx_inout op_assign cplx_input1; }).wait(); \ + check_results(*cplx_inout, convert_marray(std_inout)); \ + } \ + \ + *cplx_inout = cplx_input2; \ + \ + /* Check cplx::complex output from host */ \ + *cplx_inout op_assign cplx_input1; \ + \ + check_results(*cplx_inout, convert_marray(std_inout)); \ + \ + sycl::free(cplx_inout, Q); \ + } + +test_op_assign(add_assign, +=); +test_op_assign(sub_assign, -=); +test_op_assign(mul_assign, /=); +test_op_assign(div_assign, *=); + +#undef test_op_assign + +//////////////////////////////////////////////////////////////////////////////// +// MARRAY TESTS +//////////////////////////////////////////////////////////////////////////////// + +#define TEST(func) \ + TEMPLATE_TEST_CASE_SIG( \ + "Test marray complex operator " #func, "[" #func "]", \ + ((typename T, std::size_t NumElements), T, NumElements), (double, 14), \ + (float, 14), (sycl::half, 14)) { \ + sycl::queue Q; \ + \ + /* sycl::half cases are emulated with float for std::complex class \ + * (std_in) */ \ + using X = typename std::conditional::value, \ + float, T>::type; \ + \ + /* std::complex test cases */ \ + const auto std_in1 = \ + GENERATE(init_std_complex(sycl::marray, NumElements>{ \ + std::complex{1.0, 1.0}, \ + std::complex{4.42, 2.02}, \ + std::complex{-3, 3.5}, \ + std::complex{4.0, -4.0}, \ + std::complex{2.02, inf_val}, \ + std::complex{inf_val, 4.42}, \ + std::complex{inf_val, nan_val}, \ + std::complex{2.02, 4.42}, \ + std::complex{nan_val, nan_val}, \ + std::complex{nan_val, nan_val}, \ + std::complex{inf_val, inf_val}, \ + std::complex{nan_val, nan_val}, \ + std::complex{inf_val, inf_val}, \ + std::complex{nan_val, nan_val}, \ + })); \ + sycl::marray, NumElements> std_in2; \ + for (std::size_t i = 0; i < NumElements; i++) { \ + std_in2[i] = std::complex{std_in1[i].real(), std_in1[i].imag()}; \ + } \ + \ + /* sycl::complex test cases */ \ + const auto cplx_input1 = \ + GENERATE(sycl::marray, NumElements>{ \ + sycl::ext::cplx::complex{1.0, 1.0}, \ + sycl::ext::cplx::complex{4.42, 2.02}, \ + sycl::ext::cplx::complex{-3, 3.5}, \ + sycl::ext::cplx::complex{4.0, -4.0}, \ + sycl::ext::cplx::complex{2.02, inf_val}, \ + sycl::ext::cplx::complex{inf_val, 4.42}, \ + sycl::ext::cplx::complex{inf_val, nan_val}, \ + sycl::ext::cplx::complex{2.02, 4.42}, \ + sycl::ext::cplx::complex{nan_val, nan_val}, \ + sycl::ext::cplx::complex{nan_val, nan_val}, \ + sycl::ext::cplx::complex{inf_val, inf_val}, \ + sycl::ext::cplx::complex{nan_val, nan_val}, \ + sycl::ext::cplx::complex{inf_val, inf_val}, \ + sycl::ext::cplx::complex{nan_val, nan_val}, \ + }); \ + sycl::marray, NumElements> cplx_input2; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + cplx_input2[i] = sycl::ext::cplx::complex{cplx_input1[i].real(), \ + cplx_input1[i].imag()}; \ + } \ + \ + test##_##func(Q, std_in1, std_in2, cplx_input1, \ + cplx_input2); \ + } + +TEST(add) +TEST(sub) +TEST(mul) +TEST(div) + +TEST(add_assign) +TEST(sub_assign) +TEST(mul_assign) +TEST(div_assign) + +#undef TEST + +#define test_unary_op(test_name, label, op) \ + TEMPLATE_TEST_CASE_SIG( \ + test_name, label, \ + ((typename T, std::size_t NumElements), T, NumElements), (double, 14), \ + (float, 14), (sycl::half, 14)) { \ + sycl::queue Q; \ + \ + /* std::complex test cases */ \ + const auto std_in = \ + GENERATE(init_std_complex(sycl::marray, NumElements>{ \ + std::complex{1.0, 1.0}, \ + std::complex{4.42, 2.02}, \ + std::complex{-3, 3.5}, \ + std::complex{4.0, -4.0}, \ + std::complex{2.02, inf_val}, \ + std::complex{inf_val, 4.42}, \ + std::complex{inf_val, nan_val}, \ + std::complex{2.02, 4.42}, \ + std::complex{nan_val, nan_val}, \ + std::complex{nan_val, nan_val}, \ + std::complex{inf_val, inf_val}, \ + std::complex{nan_val, nan_val}, \ + std::complex{inf_val, inf_val}, \ + std::complex{nan_val, nan_val}, \ + })); \ + \ + /* sycl::complex test cases */ \ + sycl::marray, NumElements> cplx_input; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + cplx_input[i] = \ + sycl::ext::cplx::complex{std_in[i].real(), std_in[i].imag()}; \ + } \ + \ + sycl::marray, NumElements> std_out{}; \ + auto *cplx_out = sycl::malloc_shared< \ + sycl::marray, NumElements>>(1, Q); \ + \ + /* Get std::complex output */ \ + for (std::size_t i = 0; i < NumElements; ++i) \ + std_out[i] = op std_in[i]; \ + \ + /* Check cplx::complex output from device */ \ + if (is_type_supported(Q)) { \ + Q.single_task([=]() { *cplx_out = op cplx_input; }).wait(); \ + \ + check_results(*cplx_out, std_out); \ + } \ + \ + /* Check cplx::complex output from host */ \ + *cplx_out = op cplx_input; \ + \ + check_results(*cplx_out, std_out); \ + \ + sycl::free(cplx_out, Q); \ + } + +test_unary_op("Test marray complex addition unary operator", "[add]", +); +test_unary_op("Test marray complex subtraction unary operator", "[sub]", -); + +#undef test_unary_op From 262f227692b81ef577ac19e434e4f4f5cc9c0d83 Mon Sep 17 00:00:00 2001 From: jle-quel Date: Mon, 6 Feb 2023 11:17:53 +0100 Subject: [PATCH 4/8] Fix clang-format CI error --- tests/test_marray_complex_types.cpp | 46 +++++++++++++++-------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/tests/test_marray_complex_types.cpp b/tests/test_marray_complex_types.cpp index 4c833f8..6441c11 100644 --- a/tests/test_marray_complex_types.cpp +++ b/tests/test_marray_complex_types.cpp @@ -54,9 +54,9 @@ TEST_MATH_FUNC_TYPE(tanh) static_assert( \ std::is_same_v< \ sycl::marray, NumElements>, \ - decltype(std::declval, NumElements>>() \ - op std::declval< \ - sycl::marray, NumElements>>())>); \ + decltype( \ + std::declval, NumElements>>() op \ + std::declval, NumElements>>())>); \ \ static_assert( \ std::is_same_v< \ @@ -74,18 +74,19 @@ TEST_MATH_FUNC_TYPE(tanh) if (is_type_supported(Q)) { \ Q.single_task([=]() { \ static_assert( \ - std::is_same_v, NumElements>, \ - decltype(std::declval< \ - sycl::marray, NumElements>>() \ - op std::declval, NumElements>>())>); \ + std::is_same_v< \ + sycl::marray, NumElements>, \ + decltype( \ + std::declval, NumElements>>() \ + op std::declval< \ + sycl::marray, NumElements>>())>); \ \ static_assert( \ std::is_same_v< \ sycl::marray, NumElements>, \ - decltype(std::declval< \ - sycl::marray, NumElements>>() op \ - std::declval>())>); \ + decltype( \ + std::declval, NumElements>>() \ + op std::declval>())>); \ \ static_assert( \ std::is_same_v< \ @@ -111,9 +112,9 @@ TEST_MATH_OP_TYPE(div, /) static_assert( \ std::is_same_v< \ sycl::marray, \ - decltype(std::declval, NumElements>>() \ - op std::declval< \ - sycl::marray, NumElements>>())>); \ + decltype( \ + std::declval, NumElements>>() op \ + std::declval, NumElements>>())>); \ \ static_assert( \ std::is_same_v< \ @@ -131,18 +132,19 @@ TEST_MATH_OP_TYPE(div, /) if (is_type_supported(Q)) { \ Q.single_task([=]() { \ static_assert( \ - std::is_same_v, \ - decltype(std::declval< \ - sycl::marray, NumElements>>() \ - op std::declval, NumElements>>())>); \ + std::is_same_v< \ + sycl::marray, \ + decltype( \ + std::declval, NumElements>>() \ + op std::declval< \ + sycl::marray, NumElements>>())>); \ \ static_assert( \ std::is_same_v< \ sycl::marray, \ - decltype(std::declval< \ - sycl::marray, NumElements>>() op \ - std::declval>())>); \ + decltype( \ + std::declval, NumElements>>() \ + op std::declval>())>); \ \ static_assert( \ std::is_same_v< \ From 83b1ee95cadfc9629f0475ea48ec1508f2510c11 Mon Sep 17 00:00:00 2001 From: jle-quel Date: Mon, 6 Feb 2023 11:22:41 +0100 Subject: [PATCH 5/8] Fix clang-format-14 CI error --- include/sycl_ext_complex.hpp | 2 +- tests/test_marray_complex_types.cpp | 46 ++++++++++++++--------------- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/include/sycl_ext_complex.hpp b/include/sycl_ext_complex.hpp index f6844bc..ce3ebab 100644 --- a/include/sycl_ext_complex.hpp +++ b/include/sycl_ext_complex.hpp @@ -1271,7 +1271,7 @@ class marray, NumElements> { } template - constexpr marray(const ArgTN &... args) : MData{args...} {}; + constexpr marray(const ArgTN &...args) : MData{args...} {}; constexpr marray(const marray &rhs) = default; constexpr marray(marray &&rhs) = default; diff --git a/tests/test_marray_complex_types.cpp b/tests/test_marray_complex_types.cpp index 6441c11..4c833f8 100644 --- a/tests/test_marray_complex_types.cpp +++ b/tests/test_marray_complex_types.cpp @@ -54,9 +54,9 @@ TEST_MATH_FUNC_TYPE(tanh) static_assert( \ std::is_same_v< \ sycl::marray, NumElements>, \ - decltype( \ - std::declval, NumElements>>() op \ - std::declval, NumElements>>())>); \ + decltype(std::declval, NumElements>>() \ + op std::declval< \ + sycl::marray, NumElements>>())>); \ \ static_assert( \ std::is_same_v< \ @@ -74,19 +74,18 @@ TEST_MATH_FUNC_TYPE(tanh) if (is_type_supported(Q)) { \ Q.single_task([=]() { \ static_assert( \ - std::is_same_v< \ - sycl::marray, NumElements>, \ - decltype( \ - std::declval, NumElements>>() \ - op std::declval< \ - sycl::marray, NumElements>>())>); \ + std::is_same_v, NumElements>, \ + decltype(std::declval< \ + sycl::marray, NumElements>>() \ + op std::declval, NumElements>>())>); \ \ static_assert( \ std::is_same_v< \ sycl::marray, NumElements>, \ - decltype( \ - std::declval, NumElements>>() \ - op std::declval>())>); \ + decltype(std::declval< \ + sycl::marray, NumElements>>() op \ + std::declval>())>); \ \ static_assert( \ std::is_same_v< \ @@ -112,9 +111,9 @@ TEST_MATH_OP_TYPE(div, /) static_assert( \ std::is_same_v< \ sycl::marray, \ - decltype( \ - std::declval, NumElements>>() op \ - std::declval, NumElements>>())>); \ + decltype(std::declval, NumElements>>() \ + op std::declval< \ + sycl::marray, NumElements>>())>); \ \ static_assert( \ std::is_same_v< \ @@ -132,19 +131,18 @@ TEST_MATH_OP_TYPE(div, /) if (is_type_supported(Q)) { \ Q.single_task([=]() { \ static_assert( \ - std::is_same_v< \ - sycl::marray, \ - decltype( \ - std::declval, NumElements>>() \ - op std::declval< \ - sycl::marray, NumElements>>())>); \ + std::is_same_v, \ + decltype(std::declval< \ + sycl::marray, NumElements>>() \ + op std::declval, NumElements>>())>); \ \ static_assert( \ std::is_same_v< \ sycl::marray, \ - decltype( \ - std::declval, NumElements>>() \ - op std::declval>())>); \ + decltype(std::declval< \ + sycl::marray, NumElements>>() op \ + std::declval>())>); \ \ static_assert( \ std::is_same_v< \ From 22b23cbfd13fcedd3d906db28ad78b72ac96e698 Mon Sep 17 00:00:00 2001 From: jle-quel Date: Fri, 24 Feb 2023 17:29:44 +0100 Subject: [PATCH 6/8] rename 'index' variable to 'i' to match the overall interface convention --- include/sycl_ext_complex.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/sycl_ext_complex.hpp b/include/sycl_ext_complex.hpp index ce3ebab..ae3f46b 100644 --- a/include/sycl_ext_complex.hpp +++ b/include/sycl_ext_complex.hpp @@ -1305,8 +1305,8 @@ class marray, NumElements> { } // subscript operator - reference operator[](std::size_t index) { return MData[index]; } - const_reference operator[](std::size_t index) const { return MData[index]; } + reference operator[](std::size_t i) { return MData[i]; } + const_reference operator[](std::size_t i) const { return MData[i]; } marray &operator=(const marray &rhs) = default; marray &operator=(const DataT &rhs) { From fc69738700bc37226305d1c3291f1ec8c6b74e24 Mon Sep 17 00:00:00 2001 From: jle-quel Date: Wed, 8 Mar 2023 11:37:19 +0100 Subject: [PATCH 7/8] rename DataT to ComplexDataT --- include/sycl_ext_complex.hpp | 81 +++++++++++++++++++----------------- 1 file changed, 43 insertions(+), 38 deletions(-) diff --git a/include/sycl_ext_complex.hpp b/include/sycl_ext_complex.hpp index ae3f46b..3330bad 100644 --- a/include/sycl_ext_complex.hpp +++ b/include/sycl_ext_complex.hpp @@ -297,7 +297,9 @@ template struct __numeric_type { static const bool value = !std::is_same::value; }; -template <> struct __numeric_type { static const bool value = true; }; +template <> struct __numeric_type { + static const bool value = true; +}; template ::value &&__numeric_type<_A2>::value @@ -1250,14 +1252,14 @@ _SYCL_BEGIN_NAMESPACE template class marray, NumElements> { private: - using DataT = sycl::ext::cplx::complex; + using ComplexDataT = sycl::ext::cplx::complex; public: - using value_type = DataT; - using reference = DataT &; - using const_reference = const DataT &; - using iterator = DataT *; - using const_iterator = const DataT *; + using value_type = ComplexDataT; + using reference = ComplexDataT &; + using const_reference = const ComplexDataT &; + using iterator = ComplexDataT *; + using const_iterator = const ComplexDataT *; private: value_type MData[NumElements]; @@ -1265,7 +1267,7 @@ class marray, NumElements> { public: constexpr marray() : MData{} {}; - explicit constexpr marray(const DataT &arg) { + explicit constexpr marray(const ComplexDataT &arg) { for (size_t i = 0; i < NumElements; ++i) MData[i] = arg; } @@ -1273,12 +1275,12 @@ class marray, NumElements> { template constexpr marray(const ArgTN &...args) : MData{args...} {}; - constexpr marray(const marray &rhs) = default; - constexpr marray(marray &&rhs) = default; + constexpr marray(const marray &rhs) = default; + constexpr marray(marray &&rhs) = default; // Available only when: NumElements == 1 template > - operator DataT() const { + operator ComplexDataT() const { return MData[0]; } @@ -1308,8 +1310,8 @@ class marray, NumElements> { reference operator[](std::size_t i) { return MData[i]; } const_reference operator[](std::size_t i) const { return MData[i]; } - marray &operator=(const marray &rhs) = default; - marray &operator=(const DataT &rhs) { + marray &operator=(const marray &rhs) = default; + marray &operator=(const ComplexDataT &rhs) { for (std::size_t i = 0; i < NumElements; ++i) MData[i] = rhs; @@ -1333,7 +1335,7 @@ class marray, NumElements> { return rtn; \ } \ \ - friend marray operator op(const marray &lhs, const DataT &rhs) { \ + friend marray operator op(const marray &lhs, const ComplexDataT &rhs) { \ marray rtn; \ for (std::size_t i = 0; i < NumElements; ++i) \ rtn[i] = lhs[i] op rhs; \ @@ -1341,7 +1343,7 @@ class marray, NumElements> { return rtn; \ } \ \ - friend marray operator op(const DataT &lhs, const marray &rhs) { \ + friend marray operator op(const ComplexDataT &lhs, const marray &rhs) { \ marray rtn; \ for (std::size_t i = 0; i < NumElements; ++i) \ rtn[i] = lhs op rhs[i]; \ @@ -1358,8 +1360,8 @@ class marray, NumElements> { // OP is: % friend marray operator%(const marray &lhs, const marray &rhs) = delete; - friend marray operator%(const marray &lhs, const DataT &rhs) = delete; - friend marray operator%(const DataT &lhs, const marray &rhs) = delete; + friend marray operator%(const marray &lhs, const ComplexDataT &rhs) = delete; + friend marray operator%(const ComplexDataT &lhs, const marray &rhs) = delete; // OP is: +=, -=, *=, /= #define OP(op) \ @@ -1370,13 +1372,13 @@ class marray, NumElements> { return lhs; \ } \ \ - friend marray &operator op(marray &lhs, const DataT &rhs) { \ + friend marray &operator op(marray &lhs, const ComplexDataT &rhs) { \ for (std::size_t i = 0; i < NumElements; ++i) \ lhs[i] op rhs; \ \ return lhs; \ } \ - friend marray &operator op(DataT &lhs, const marray &rhs) { \ + friend marray &operator op(ComplexDataT &lhs, const marray &rhs) { \ for (std::size_t i = 0; i < NumElements; ++i) \ lhs[i] op rhs; \ \ @@ -1392,8 +1394,8 @@ class marray, NumElements> { // OP is: %= friend marray &operator%=(marray &lhs, const marray &rhs) = delete; - friend marray &operator%=(marray &lhs, const DataT &rhs) = delete; - friend marray &operator%=(DataT &lhs, const marray &rhs) = delete; + friend marray &operator%=(marray &lhs, const ComplexDataT &rhs) = delete; + friend marray &operator%=(ComplexDataT &lhs, const marray &rhs) = delete; // OP is: ++, -- #define OP(op) \ @@ -1407,9 +1409,9 @@ class marray, NumElements> { // OP is: unary +, unary - #define OP(op) \ - friend marray operator op( \ - const marray &rhs) { \ - marray rtn; \ + friend marray operator op( \ + const marray &rhs) { \ + marray rtn; \ \ for (std::size_t i = 0; i < NumElements; ++i) { \ rtn[i] = op rhs[i]; \ @@ -1426,7 +1428,8 @@ class marray, NumElements> { // OP is: &, |, ^ #define OP(op) \ friend marray operator op(const marray &lhs, const marray &rhs) = delete; \ - friend marray operator op(const marray &lhs, const DataT &rhs) = delete; + friend marray operator op(const marray &lhs, const ComplexDataT &rhs) = \ + delete; OP(&) OP(|) @@ -1437,8 +1440,8 @@ class marray, NumElements> { // OP is: &=, |=, ^= #define OP(op) \ friend marray &operator op(marray &lhs, const marray &rhs) = delete; \ - friend marray &operator op(marray &lhs, const DataT &rhs) = delete; \ - friend marray &operator op(DataT &lhs, const marray &rhs) = delete; + friend marray &operator op(marray &lhs, const ComplexDataT &rhs) = delete; \ + friend marray &operator op(ComplexDataT &lhs, const marray &rhs) = delete; OP(&=) OP(|=) @@ -1450,9 +1453,9 @@ class marray, NumElements> { #define OP(op) \ friend marray operator op(const marray &lhs, \ const marray &rhs) = delete; \ - friend marray operator op(const marray &lhs, \ - const DataT &rhs) = delete; \ - friend marray operator op(const DataT &lhs, \ + friend marray operator op( \ + const marray &lhs, const ComplexDataT &rhs) = delete; \ + friend marray operator op(const ComplexDataT &lhs, \ const marray &rhs) = delete; OP(&&) @@ -1463,8 +1466,10 @@ class marray, NumElements> { // OP is: <<, >> #define OP(op) \ friend marray operator op(const marray &lhs, const marray &rhs) = delete; \ - friend marray operator op(const marray &lhs, const DataT &rhs) = delete; \ - friend marray operator op(const DataT &lhs, const marray &rhs) = delete; + friend marray operator op(const marray &lhs, const ComplexDataT &rhs) = \ + delete; \ + friend marray operator op(const ComplexDataT &lhs, const marray &rhs) = \ + delete; OP(<<) OP(>>) @@ -1474,7 +1479,7 @@ class marray, NumElements> { // OP is: <<=, >>= #define OP(op) \ friend marray &operator op(marray &lhs, const marray &rhs) = delete; \ - friend marray &operator op(marray &lhs, const DataT &rhs) = delete; + friend marray &operator op(marray &lhs, const ComplexDataT &rhs) = delete; OP(<<=) OP(>>=) @@ -1493,7 +1498,7 @@ class marray, NumElements> { } \ \ friend marray operator op(const marray &lhs, \ - const DataT &rhs) { \ + const ComplexDataT &rhs) { \ marray rtn; \ for (std::size_t i = 0; i < NumElements; ++i) \ rtn[i] = lhs[i] op rhs; \ @@ -1501,7 +1506,7 @@ class marray, NumElements> { return rtn; \ } \ \ - friend marray operator op(const DataT &lhs, \ + friend marray operator op(const ComplexDataT &lhs, \ const marray &rhs) { \ marray rtn; \ for (std::size_t i = 0; i < NumElements; ++i) \ @@ -1519,9 +1524,9 @@ class marray, NumElements> { #define OP(op) \ friend marray operator op(const marray &lhs, \ const marray &rhs) = delete; \ - friend marray operator op(const marray &lhs, \ - const DataT &rhs) = delete; \ - friend marray operator op(const DataT &lhs, \ + friend marray operator op( \ + const marray &lhs, const ComplexDataT &rhs) = delete; \ + friend marray operator op(const ComplexDataT &lhs, \ const marray &rhs) = delete; OP(<); From 78007db53670d9a25da429f2add0f9ec9fb688b7 Mon Sep 17 00:00:00 2001 From: jle-quel Date: Wed, 8 Mar 2023 12:06:03 +0100 Subject: [PATCH 8/8] restructure define and handle hypsycl and sycl namespace for marray --- include/sycl_ext_complex.hpp | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/include/sycl_ext_complex.hpp b/include/sycl_ext_complex.hpp index 3330bad..1e6b430 100644 --- a/include/sycl_ext_complex.hpp +++ b/include/sycl_ext_complex.hpp @@ -252,15 +252,23 @@ template complex tanh (const complex&); #endif #endif -#if defined(__FAST_MATH__) || defined(_M_FP_FAST) -#define _SYCL_EXT_CPLX_FAST_MATH +#define _SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD namespace _SYCL_CPLX_NAMESPACE { +#define _SYCL_EXT_CPLX_END_NAMESPACE_STD } + +#ifndef _SYCL_MARRAY_NAMESPACE +#ifdef __HIPSYCL__ +#define _SYCL_MARRAY_NAMESPACE hipsycl::sycl +#else +#define _SYCL_MARRAY_NAMESPACE sycl +#endif #endif -#define _SYCL_BEGIN_NAMESPACE namespace sycl { -#define _SYCL_END_NAMESPACE } +#define _SYCL_MARRAY_BEGIN_NAMESPACE namespace _SYCL_MARRAY_NAMESPACE { +#define _SYCL_MARRAY_END_NAMESPACE } -#define _SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD namespace _SYCL_CPLX_NAMESPACE { -#define _SYCL_EXT_CPLX_END_NAMESPACE_STD } +#if defined(__FAST_MATH__) || defined(_M_FP_FAST) +#define _SYCL_EXT_CPLX_FAST_MATH +#endif #define _SYCL_EXT_CPLX_INLINE_VISIBILITY \ [[gnu::always_inline]] [[clang::always_inline]] inline @@ -1246,7 +1254,7 @@ _SYCL_EXT_CPLX_END_NAMESPACE_STD // MARRAY IMPLEMENTATION //////////////////////////////////////////////////////////////////////////////// -_SYCL_BEGIN_NAMESPACE +_SYCL_MARRAY_BEGIN_NAMESPACE // marray of complex class specialisation template @@ -1541,7 +1549,7 @@ class marray, NumElements> { friend marray operator!(const marray &v) = delete; }; -_SYCL_END_NAMESPACE +_SYCL_MARRAY_END_NAMESPACE _SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD @@ -1672,8 +1680,8 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY _SYCL_EXT_CPLX_END_NAMESPACE_STD -#undef _SYCL_BEGIN_NAMESPACE -#undef _SYCL_END_NAMESPACE +#undef _SYCL_MARRAY_BEGIN_NAMESPACE +#undef _SYCL_MARRAY_END_NAMESPACE #undef _SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD #undef _SYCL_EXT_CPLX_END_NAMESPACE_STD