diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..31b091e --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "vendor/Catch2"] + path = vendor/Catch2 + url = https://github.com/catchorg/Catch2 diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..177f523 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,18 @@ +cmake_minimum_required(VERSION 3.15) +project(sandbox LANGUAGES CXX) + +# Setup Catch2 +if(NOT EXISTS "${PROJECT_SOURCE_DIR}/vendor/Catch2/CMakeLists.txt") + message(FATAL_ERROR "The git submodule vendor/Catch2 is missing.\nTry running `git submodule update --init`.") +endif() + +add_subdirectory(vendor/Catch2) + +list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") + +# Load catch2 cmake module +include(CTest) +include(Catch) + +add_subdirectory(tests) + diff --git a/README.md b/README.md index e32a1d2..8f4a16a 100644 --- a/README.md +++ b/README.md @@ -44,4 +44,13 @@ int main() { ## Tests -Crude tests for all C++ complex math functions are provided in `/tests/`. Just defined a correct `CXX` and then `make` , `make run` +Testing is implemented with Catch2 and CMake. Catch2 is added as a git submodule inside the `vendor` directory. + +Instructions to build and run tests for DPCPP are below: +``` +mkdir build +cd build +cmake -DCMAKE_CXX_COMPILER=$CXX_PATH -DCMAKE_CXX_FLAGS=-fsycl .. +make -j 8 +ctest +``` diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 0000000..86f9467 --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,17 @@ +file(GLOB test_cases CONFIGURE_DEPENDS "*.cpp") + +foreach(test_file IN LISTS test_cases) + if(EXISTS "${test_file}") + get_filename_component(exe_name "${test_file}" NAME_WE) + + add_executable(${exe_name} ${test_file}) + target_include_directories(${exe_name} PUBLIC ../include/) + target_link_libraries(${exe_name} PRIVATE + Catch2::Catch2WithMain + ) + + catch_discover_tests(${exe_name}) + else() + message(FATAL_ERROR "No file named ${test_file}") + endif() +endforeach() diff --git a/tests/Makefile b/tests/Makefile deleted file mode 100644 index 319eb67..0000000 --- a/tests/Makefile +++ /dev/null @@ -1,26 +0,0 @@ - - -TIMEOUT = $(shell command -v timeout 2> /dev/null) -OVO_TIMEOUT ?= 10s -ifdef TIMEOUT - TIMEOUT = timeout -k 5s $(OVO_TIMEOUT) -endif - -SRC = $(wildcard *.cpp) -.PHONY: exe -exe: $(SRC:%.cpp=%.exe) - -pEXE = $(wildcard *.exe) -.PHONY: run -run: $(addprefix run_, $(basename $(pEXE))) - -%.exe: %.cpp - -$(TIMEOUT) $(CXX) $(CXXFLAGS) $(CURDIR)/$< -o $(CURDIR)/$@ - - -run_%: %.exe - -$(TIMEOUT) $(CURDIR)/$< - -.PHONY: clean -clean: - rm -f -- $(pEXE) diff --git a/tests/abs_complex.cpp b/tests/abs_complex.cpp index e57c301..11243e5 100644 --- a/tests/abs_complex.cpp +++ b/tests/abs_complex.cpp @@ -1,57 +1,40 @@ #include "test_helper.hpp" -template struct test_abs { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex abs", "[abs]", double, float, sycl::half) { + using T = TestType; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + // Test cases + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - T std_out{}; - auto *cplx_out = sycl::malloc_shared(1, Q); + T std_out{}; + auto *cplx_out = sycl::malloc_shared(1, Q); - // Get std::complex output - std_out = std::abs(std_in); + // Get std::complex output + std_out = std::abs(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { Q.single_task([=]() { cplx_out[0] = sycl::ext::cplx::abs(cplx_input); }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); - - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::abs(cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - sycl::free(cplx_out, Q); - - return pass; + check_results(cplx_out[0], std_out); } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::abs(cplx_input); - if (!test_passes) - std::cerr << "acos complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/acos_complex.cpp b/tests/acos_complex.cpp index babd107..95d7209 100644 --- a/tests/acos_complex.cpp +++ b/tests/acos_complex.cpp @@ -1,21 +1,37 @@ #include "test_helper.hpp" -template struct test_acos { - bool operator()(sycl::queue &Q, T init_re, T init_im, - bool is_error_checking = false) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex acos", "[acos]", double, float, sycl::half) { + using T = TestType; + using std::make_tuple; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + cmplx input; + bool is_error_checking; + + std::tie(input, is_error_checking) = GENERATE(table, bool>( + {make_tuple(cmplx{4.42, 2.02}, false), + make_tuple(cmplx{inf_val, 2.02}, true), + make_tuple(cmplx{4.42, inf_val}, true), + make_tuple(cmplx{inf_val, inf_val}, true), + make_tuple(cmplx{nan_val, 2.02}, true), + make_tuple(cmplx{4.42, nan_val}, true), + make_tuple(cmplx{nan_val, nan_val}, true), + make_tuple(cmplx{nan_val, inf_val}, true), + make_tuple(cmplx{inf_val, nan_val}, true)})); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{init_re, init_im}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{input.re, input.im}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - if (is_error_checking) - std_out = std::acos(std_in); + // Get std::complex output + if (is_error_checking) + std_out = std::acos(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { if (is_error_checking) { Q.single_task( [=]() { cplx_out[0] = sycl::ext::cplx::acos(cplx_input); }); @@ -26,47 +42,17 @@ template struct test_acos { }); } Q.wait(); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true, - /*tol_multiplier*/ 2); - - // Check cplx::complex output from host - if (is_error_checking) - cplx_out[0] = sycl::ext::cplx::acos(cplx_input); - else - cplx_out[0] = - sycl::ext::cplx::cos(sycl::ext::cplx::acos(cplx_input)); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false, - /*tol_multiplier*/ 2); - - sycl::free(cplx_out, Q); - - return pass; } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02, true); - test_passes &= test_valid_types(Q, 4.42, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, INFINITY, true); - test_passes &= test_valid_types(Q, NAN, 2.02, true); - test_passes &= test_valid_types(Q, 4.42, NAN, true); - test_passes &= test_valid_types(Q, NAN, NAN, true); + check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); - test_passes &= test_valid_types(Q, NAN, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, NAN, true); - test_passes &= test_valid_types(Q, NAN, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, NAN, true); + // Check cplx::complex output from host + if (is_error_checking) + cplx_out[0] = sycl::ext::cplx::acos(cplx_input); + else + cplx_out[0] = sycl::ext::cplx::cos(sycl::ext::cplx::acos(cplx_input)); - if (!test_passes) - std::cerr << "acos complex test fails\n"; + check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/acosh_complex.cpp b/tests/acosh_complex.cpp index 072827b..ed9eafd 100644 --- a/tests/acosh_complex.cpp +++ b/tests/acosh_complex.cpp @@ -1,21 +1,37 @@ #include "test_helper.hpp" -template struct test_acosh { - bool operator()(sycl::queue &Q, T init_re, T init_im, - bool is_error_checking = false) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex acosh", "[acosh]", double, float, sycl::half) { + using T = TestType; + using std::make_tuple; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + cmplx input; + bool is_error_checking; + + std::tie(input, is_error_checking) = GENERATE(table, bool>( + {make_tuple(cmplx{4.42, 2.02}, false), + make_tuple(cmplx{inf_val, 2.02}, true), + make_tuple(cmplx{4.42, inf_val}, true), + make_tuple(cmplx{inf_val, inf_val}, true), + make_tuple(cmplx{nan_val, 2.02}, true), + make_tuple(cmplx{4.42, nan_val}, true), + make_tuple(cmplx{nan_val, nan_val}, true), + make_tuple(cmplx{nan_val, inf_val}, true), + make_tuple(cmplx{inf_val, nan_val}, true)})); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{init_re, init_im}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{input.re, input.im}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - if (is_error_checking) - std_out = std::acosh(std_in); + // Get std::complex output + if (is_error_checking) + std_out = std::acosh(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { if (is_error_checking) { Q.single_task( [=]() { cplx_out[0] = sycl::ext::cplx::acosh(cplx_input); }); @@ -26,47 +42,18 @@ template struct test_acosh { }); } Q.wait(); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true, - /*tol_multiplier*/ 2); - - // Check cplx::complex output from host - if (is_error_checking) - cplx_out[0] = sycl::ext::cplx::acosh(cplx_input); - else - cplx_out[0] = - sycl::ext::cplx::cosh(sycl::ext::cplx::acosh(cplx_input)); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false, - /*tol_multiplier*/ 2); - - sycl::free(cplx_out, Q); - - return pass; } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02, true); - test_passes &= test_valid_types(Q, 4.42, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, INFINITY, true); - test_passes &= test_valid_types(Q, NAN, 2.02, true); - test_passes &= test_valid_types(Q, 4.42, NAN, true); - test_passes &= test_valid_types(Q, NAN, NAN, true); + check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); - test_passes &= test_valid_types(Q, NAN, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, NAN, true); - test_passes &= test_valid_types(Q, NAN, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, NAN, true); + // Check cplx::complex output from host + if (is_error_checking) + cplx_out[0] = sycl::ext::cplx::acosh(cplx_input); + else + cplx_out[0] = + sycl::ext::cplx::cosh(sycl::ext::cplx::acosh(cplx_input)); - if (!test_passes) - std::cerr << "acosh complex test fails\n"; + check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); - return !test_passes; -} + sycl::free(cplx_out, Q); +} \ No newline at end of file diff --git a/tests/arg_complex.cpp b/tests/arg_complex.cpp index 9e43935..efb2a25 100644 --- a/tests/arg_complex.cpp +++ b/tests/arg_complex.cpp @@ -1,57 +1,40 @@ #include "test_helper.hpp" -template struct test_arg { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex arg", "[arg]", double, float, sycl::half) { + using T = TestType; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + // Test cases + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - T std_out{}; - auto *cplx_out = sycl::malloc_shared(1, Q); + T std_out{}; + auto *cplx_out = sycl::malloc_shared(1, Q); - // Get std::complex output - std_out = std::arg(std_in); + // Get std::complex output + std_out = std::arg(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { Q.single_task([=]() { cplx_out[0] = sycl::ext::cplx::arg(cplx_input); }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); - - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::arg(cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - sycl::free(cplx_out, Q); - - return pass; + check_results(cplx_out[0], std_out); } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::arg(cplx_input); - if (!test_passes) - std::cerr << "acos complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/asin_complex.cpp b/tests/asin_complex.cpp index 4fd0ae3..185d634 100644 --- a/tests/asin_complex.cpp +++ b/tests/asin_complex.cpp @@ -1,21 +1,37 @@ #include "test_helper.hpp" -template struct test_asin { - bool operator()(sycl::queue &Q, T init_re, T init_im, - bool is_error_checking = false) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex asin", "[asin]", double, float, sycl::half) { + using T = TestType; + using std::make_tuple; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + cmplx input; + bool is_error_checking; + + std::tie(input, is_error_checking) = GENERATE(table, bool>( + {make_tuple(cmplx{4.42, 2.02}, false), + make_tuple(cmplx{inf_val, 2.02}, true), + make_tuple(cmplx{4.42, inf_val}, true), + make_tuple(cmplx{inf_val, inf_val}, true), + make_tuple(cmplx{nan_val, 2.02}, true), + make_tuple(cmplx{4.42, nan_val}, true), + make_tuple(cmplx{nan_val, nan_val}, true), + make_tuple(cmplx{nan_val, inf_val}, true), + make_tuple(cmplx{inf_val, nan_val}, true)})); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{init_re, init_im}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{input.re, input.im}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - if (is_error_checking) - std_out = std::asin(std_in); + // Get std::complex output + if (is_error_checking) + std_out = std::asin(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { if (is_error_checking) { Q.single_task( [=]() { cplx_out[0] = sycl::ext::cplx::asin(cplx_input); }); @@ -26,47 +42,17 @@ template struct test_asin { }); } Q.wait(); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true, - /*tol_multiplier*/ 2); - - // Check cplx::complex output from host - if (is_error_checking) - cplx_out[0] = sycl::ext::cplx::asin(cplx_input); - else - cplx_out[0] = - sycl::ext::cplx::sin(sycl::ext::cplx::asin(cplx_input)); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false, - /*tol_multiplier*/ 2); - - sycl::free(cplx_out, Q); - - return pass; } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 0.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02, true); - test_passes &= test_valid_types(Q, 4.42, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, INFINITY, true); - test_passes &= test_valid_types(Q, NAN, 2.02, true); - test_passes &= test_valid_types(Q, 4.42, NAN, true); - test_passes &= test_valid_types(Q, NAN, NAN, true); + check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); - test_passes &= test_valid_types(Q, NAN, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, NAN, true); - test_passes &= test_valid_types(Q, NAN, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, NAN, true); + // Check cplx::complex output from host + if (is_error_checking) + cplx_out[0] = sycl::ext::cplx::asin(cplx_input); + else + cplx_out[0] = sycl::ext::cplx::sin(sycl::ext::cplx::asin(cplx_input)); - if (!test_passes) - std::cerr << "asin complex test fails\n"; + check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); - return !test_passes; -} + sycl::free(cplx_out, Q); +} \ No newline at end of file diff --git a/tests/asinh_complex.cpp b/tests/asinh_complex.cpp index 82d29e7..cc1babc 100644 --- a/tests/asinh_complex.cpp +++ b/tests/asinh_complex.cpp @@ -1,21 +1,37 @@ #include "test_helper.hpp" -template struct test_asinh { - bool operator()(sycl::queue &Q, T init_re, T init_im, - bool is_error_checking = false) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex asinh", "[asinh]", double, float, sycl::half) { + using T = TestType; + using std::make_tuple; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + cmplx input; + bool is_error_checking; + + std::tie(input, is_error_checking) = GENERATE(table, bool>( + {make_tuple(cmplx{4.42, 2.02}, false), + make_tuple(cmplx{inf_val, 2.02}, true), + make_tuple(cmplx{4.42, inf_val}, true), + make_tuple(cmplx{inf_val, inf_val}, true), + make_tuple(cmplx{nan_val, 2.02}, true), + make_tuple(cmplx{4.42, nan_val}, true), + make_tuple(cmplx{nan_val, nan_val}, true), + make_tuple(cmplx{nan_val, inf_val}, true), + make_tuple(cmplx{inf_val, nan_val}, true)})); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{init_re, init_im}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{input.re, input.im}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - if (is_error_checking) - std_out = std::asinh(std_in); + // Get std::complex output + if (is_error_checking) + std_out = std::asinh(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { if (is_error_checking) { Q.single_task( [=]() { cplx_out[0] = sycl::ext::cplx::asinh(cplx_input); }); @@ -26,47 +42,18 @@ template struct test_asinh { }); } Q.wait(); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true, - /*tol_multiplier*/ 2); - - // Check cplx::complex output from host - if (is_error_checking) - cplx_out[0] = sycl::ext::cplx::asinh(cplx_input); - else - cplx_out[0] = - sycl::ext::cplx::sinh(sycl::ext::cplx::asinh(cplx_input)); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false, - /*tol_multiplier*/ 2); - - sycl::free(cplx_out, Q); - - return pass; } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 0.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02, true); - test_passes &= test_valid_types(Q, 4.42, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, INFINITY, true); - test_passes &= test_valid_types(Q, NAN, 2.02, true); - test_passes &= test_valid_types(Q, 4.42, NAN, true); - test_passes &= test_valid_types(Q, NAN, NAN, true); + check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); - test_passes &= test_valid_types(Q, NAN, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, NAN, true); - test_passes &= test_valid_types(Q, NAN, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, NAN, true); + // Check cplx::complex output from host + if (is_error_checking) + cplx_out[0] = sycl::ext::cplx::asinh(cplx_input); + else + cplx_out[0] = + sycl::ext::cplx::sinh(sycl::ext::cplx::asinh(cplx_input)); - if (!test_passes) - std::cerr << "asinh complex test fails\n"; + check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); - return !test_passes; -} + sycl::free(cplx_out, Q); +} \ No newline at end of file diff --git a/tests/atan_complex.cpp b/tests/atan_complex.cpp index ce08477..e0283b1 100644 --- a/tests/atan_complex.cpp +++ b/tests/atan_complex.cpp @@ -1,21 +1,37 @@ #include "test_helper.hpp" -template struct test_atan { - bool operator()(sycl::queue &Q, T init_re, T init_im, - bool is_error_checking = false) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex atan", "[atan]", double, float, sycl::half) { + using T = TestType; + using std::make_tuple; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + cmplx input; + bool is_error_checking; + + std::tie(input, is_error_checking) = GENERATE(table, bool>( + {make_tuple(cmplx{4.42, 2.02}, false), + make_tuple(cmplx{inf_val, 2.02}, true), + make_tuple(cmplx{4.42, inf_val}, true), + make_tuple(cmplx{inf_val, inf_val}, true), + make_tuple(cmplx{nan_val, 2.02}, true), + make_tuple(cmplx{4.42, nan_val}, true), + make_tuple(cmplx{nan_val, nan_val}, true), + make_tuple(cmplx{nan_val, inf_val}, true), + make_tuple(cmplx{inf_val, nan_val}, true)})); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{init_re, init_im}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{input.re, input.im}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - if (is_error_checking) - std_out = std::atan(std_in); + // Get std::complex output + if (is_error_checking) + std_out = std::atan(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { if (is_error_checking) { Q.single_task( [=]() { cplx_out[0] = sycl::ext::cplx::atan(cplx_input); }); @@ -26,47 +42,17 @@ template struct test_atan { }); } Q.wait(); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true, - /*tol_multiplier*/ 2); - - // Check cplx::complex output from host - if (is_error_checking) - cplx_out[0] = sycl::ext::cplx::atan(cplx_input); - else - cplx_out[0] = - sycl::ext::cplx::tan(sycl::ext::cplx::atan(cplx_input)); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false, - /*tol_multiplier*/ 2); - - sycl::free(cplx_out, Q); - - return pass; } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 0.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02, true); - test_passes &= test_valid_types(Q, 4.42, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, INFINITY, true); - test_passes &= test_valid_types(Q, NAN, 2.02, true); - test_passes &= test_valid_types(Q, 4.42, NAN, true); - test_passes &= test_valid_types(Q, NAN, NAN, true); + check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); - test_passes &= test_valid_types(Q, NAN, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, NAN, true); - test_passes &= test_valid_types(Q, NAN, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, NAN, true); + // Check cplx::complex output from host + if (is_error_checking) + cplx_out[0] = sycl::ext::cplx::atan(cplx_input); + else + cplx_out[0] = sycl::ext::cplx::tan(sycl::ext::cplx::atan(cplx_input)); - if (!test_passes) - std::cerr << "atan complex test fails\n"; + check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); - return !test_passes; -} + sycl::free(cplx_out, Q); +} \ No newline at end of file diff --git a/tests/atanh_complex.cpp b/tests/atanh_complex.cpp index c7a15f5..f20bb38 100644 --- a/tests/atanh_complex.cpp +++ b/tests/atanh_complex.cpp @@ -1,21 +1,37 @@ #include "test_helper.hpp" -template struct test_atanh { - bool operator()(sycl::queue &Q, T init_re, T init_im, - bool is_error_checking = false) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex atanh", "[atanh]", double, float, sycl::half) { + using T = TestType; + using std::make_tuple; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + cmplx input; + bool is_error_checking; + + std::tie(input, is_error_checking) = GENERATE(table, bool>( + {make_tuple(cmplx{4.42, 2.02}, false), + make_tuple(cmplx{inf_val, 2.02}, true), + make_tuple(cmplx{4.42, inf_val}, true), + make_tuple(cmplx{inf_val, inf_val}, true), + make_tuple(cmplx{nan_val, 2.02}, true), + make_tuple(cmplx{4.42, nan_val}, true), + make_tuple(cmplx{nan_val, nan_val}, true), + make_tuple(cmplx{nan_val, inf_val}, true), + make_tuple(cmplx{inf_val, nan_val}, true)})); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{init_re, init_im}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{input.re, input.im}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - if (is_error_checking) - std_out = std::atanh(std_in); + // Get std::complex output + if (is_error_checking) + std_out = std::atanh(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { if (is_error_checking) { Q.single_task( [=]() { cplx_out[0] = sycl::ext::cplx::atanh(cplx_input); }); @@ -26,47 +42,18 @@ template struct test_atanh { }); } Q.wait(); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true, - /*tol_multiplier*/ 2); - - // Check cplx::complex output from host - if (is_error_checking) - cplx_out[0] = sycl::ext::cplx::atanh(cplx_input); - else - cplx_out[0] = - sycl::ext::cplx::tanh(sycl::ext::cplx::atanh(cplx_input)); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false, - /*tol_multiplier*/ 2); - - sycl::free(cplx_out, Q); - - return pass; } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 0.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02, true); - test_passes &= test_valid_types(Q, 4.42, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, INFINITY, true); - test_passes &= test_valid_types(Q, NAN, 2.02, true); - test_passes &= test_valid_types(Q, 4.42, NAN, true); - test_passes &= test_valid_types(Q, NAN, NAN, true); + check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); - test_passes &= test_valid_types(Q, NAN, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, NAN, true); - test_passes &= test_valid_types(Q, NAN, INFINITY, true); - test_passes &= test_valid_types(Q, INFINITY, NAN, true); + // Check cplx::complex output from host + if (is_error_checking) + cplx_out[0] = sycl::ext::cplx::atanh(cplx_input); + else + cplx_out[0] = + sycl::ext::cplx::tanh(sycl::ext::cplx::atanh(cplx_input)); - if (!test_passes) - std::cerr << "atanh complex test fails\n"; + check_results(cplx_out[0], std_out, /*tol_multiplier*/ 2); - return !test_passes; -} + sycl::free(cplx_out, Q); +} \ No newline at end of file diff --git a/tests/cos_complex.cpp b/tests/cos_complex.cpp index 09bebe2..c5c5d55 100644 --- a/tests/cos_complex.cpp +++ b/tests/cos_complex.cpp @@ -1,57 +1,40 @@ #include "test_helper.hpp" -template struct test_cos { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex cos", "[cos]", double, float, sycl::half) { + using T = TestType; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + // Test cases + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - std_out = std::cos(std_in); + // Get std::complex output + std_out = std::cos(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { Q.single_task([=]() { cplx_out[0] = sycl::ext::cplx::cos(cplx_input); }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); - - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::cos(cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - sycl::free(cplx_out, Q); - - return pass; + check_results(cplx_out[0], std_out); } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::cos(cplx_input); - if (!test_passes) - std::cerr << "cos complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/cosh_complex.cpp b/tests/cosh_complex.cpp index 0853ee0..a4a7ee9 100644 --- a/tests/cosh_complex.cpp +++ b/tests/cosh_complex.cpp @@ -1,57 +1,40 @@ #include "test_helper.hpp" -template struct test_cosh { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex cosh", "[cosh]", double, float, sycl::half) { + using T = TestType; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + // Test cases + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - std_out = std::cosh(std_in); + // Get std::complex output + std_out = std::cosh(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { Q.single_task([=]() { cplx_out[0] = sycl::ext::cplx::cosh(cplx_input); }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); - - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::cosh(cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - sycl::free(cplx_out, Q); - - return pass; + check_results(cplx_out[0], std_out); } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::cosh(cplx_input); - if (!test_passes) - std::cerr << "cosh complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; -} + sycl::free(cplx_out, Q); +} \ No newline at end of file diff --git a/tests/exp_complex.cpp b/tests/exp_complex.cpp index aecccb9..19bea14 100644 --- a/tests/exp_complex.cpp +++ b/tests/exp_complex.cpp @@ -1,57 +1,40 @@ #include "test_helper.hpp" -template struct test_exp { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex exp", "[exp]", double, float, sycl::half) { + using T = TestType; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + // Test cases + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - std_out = std::exp(std_in); + // Get std::complex output + std_out = std::exp(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { Q.single_task([=]() { cplx_out[0] = sycl::ext::cplx::exp(cplx_input); }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); - - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::exp(cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - sycl::free(cplx_out, Q); - - return pass; + check_results(cplx_out[0], std_out); } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::exp(cplx_input); - if (!test_passes) - std::cerr << "exp complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/log10_complex.cpp b/tests/log10_complex.cpp index 1f822fd..4affaa2 100644 --- a/tests/log10_complex.cpp +++ b/tests/log10_complex.cpp @@ -1,57 +1,40 @@ #include "test_helper.hpp" -template struct test_log10 { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex log10", "[log10]", double, float, sycl::half) { + using T = TestType; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + // Test cases + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - std_out = std::log10(std_in); + // Get std::complex output + std_out = std::log10(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { Q.single_task([=]() { cplx_out[0] = sycl::ext::cplx::log10(cplx_input); }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); - - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::log10(cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - sycl::free(cplx_out, Q); - - return pass; + check_results(cplx_out[0], std_out); } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::log10(cplx_input); - if (!test_passes) - std::cerr << "log10 complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/log_complex.cpp b/tests/log_complex.cpp index ed59a86..8907838 100644 --- a/tests/log_complex.cpp +++ b/tests/log_complex.cpp @@ -1,57 +1,40 @@ #include "test_helper.hpp" -template struct test_log { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex log", "[log]", double, float, sycl::half) { + using T = TestType; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + // Test cases + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - std_out = std::log(std_in); + // Get std::complex output + std_out = std::log(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { Q.single_task([=]() { cplx_out[0] = sycl::ext::cplx::log(cplx_input); }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); - - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::log(cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - sycl::free(cplx_out, Q); - - return pass; + check_results(cplx_out[0], std_out); } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::log(cplx_input); - if (!test_passes) - std::cerr << "log complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/norm_complex.cpp b/tests/norm_complex.cpp index 7b2570d..1778425 100644 --- a/tests/norm_complex.cpp +++ b/tests/norm_complex.cpp @@ -1,54 +1,39 @@ #include "test_helper.hpp" -template struct test_norm { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex norm", "[norm]", double, float, sycl::half) { + using T = TestType; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + // Test cases + cmplx input = + GENERATE(cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - T std_out{}; - auto *cplx_out = sycl::malloc_shared(1, Q); + T std_out{}; + auto *cplx_out = sycl::malloc_shared(1, Q); - // Get std::complex output - std_out = std::norm(std_in); + // Get std::complex output + std_out = std::norm(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { Q.single_task([=]() { cplx_out[0] = sycl::ext::cplx::norm(cplx_input); }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); - - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::norm(cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - sycl::free(cplx_out, Q); - - return pass; + check_results(cplx_out[0], std_out); } -}; - -// Difference between libstdc++ and libc++ when NaN's and Inf values are -// combined. -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::norm(cplx_input); - if (!test_passes) - std::cerr << "acos complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/polar_complex.cpp b/tests/polar_complex.cpp index 37ce4ad..4d2198f 100644 --- a/tests/polar_complex.cpp +++ b/tests/polar_complex.cpp @@ -1,43 +1,36 @@ #include "test_helper.hpp" -template struct test_polar { - bool operator()(sycl::queue &Q, T init_rho, T init_theta) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex polar", "[polar]", double, float, sycl::half) { + using T = TestType; + using std::make_tuple; - auto *cplx_out = sycl::malloc_shared>(1, Q); - - // Get std::complex output - std::complex std_out = std::polar(init_rho, init_theta); - - // Check cplx::complex output from device - Q.single_task([=]() { - cplx_out[0] = sycl::ext::cplx::polar(init_rho, init_theta); - }).wait(); + sycl::queue Q; - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); + // Test cases + // Note: Output is undefined if rho is negative or Nan, or theta is Inf + T rho, theta; + std::tie(rho, theta) = GENERATE(table( + {make_tuple(4.42, 2.02), make_tuple(1, 3.14), make_tuple(1, -3.14)})); - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::polar(init_rho, init_theta); + std::complex std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); + // Get std::complex output + std_out = std::polar(rho, theta); - sycl::free(cplx_out, Q); + // Check cplx::complex output from device + if (is_type_supported(Q)) { + Q.single_task([=]() { + cplx_out[0] = sycl::ext::cplx::polar(rho, theta); + }).wait(); - return pass; + check_results(cplx_out[0], std_out); } -}; - -// Note: Output is undefined if rho is negative or Nan, or theta is Inf -int main() { - sycl::queue Q; - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - test_passes &= test_valid_types(Q, 1, 3.14); - test_passes &= test_valid_types(Q, 1, -3.14); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::polar(rho, theta); - if (!test_passes) - std::cerr << "acos complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/pow_complex.cpp b/tests/pow_complex.cpp index d952e48..aa1bf9f 100644 --- a/tests/pow_complex.cpp +++ b/tests/pow_complex.cpp @@ -1,13 +1,31 @@ #include "test_helper.hpp" -template -bool test_pow_cplx_cplx(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex pow cplx-cplx overload", "[pow]", double, + float, sycl::half) { + using T = TestType; - auto std_in1 = init_std_complex(init_re, init_im); - auto std_in2 = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input1{init_re, init_im}; - sycl::ext::cplx::complex cplx_input2{init_re, init_im}; + sycl::queue Q; + + // Test cases + // Values are generated as cross product of input1 and input2's GENERATE list + cmplx input1 = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + cmplx input2 = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in1 = init_std_complex(input1); + auto std_in2 = init_std_complex(input2); + sycl::ext::cplx::complex cplx_input1{input1.re, input1.im}; + sycl::ext::cplx::complex cplx_input2{input2.re, input2.im}; std::complex std_out{}; auto *cplx_out = sycl::malloc_shared>(1, Q); @@ -16,30 +34,42 @@ bool test_pow_cplx_cplx(sycl::queue &Q, T init_re, T init_im) { std_out = std::pow(std_in1, std_in2); // Check cplx::complex output from device - Q.single_task([=]() { - cplx_out[0] = sycl::ext::cplx::pow(cplx_input1, cplx_input2); - }).wait(); + if (is_type_supported(Q)) { + Q.single_task([=]() { + cplx_out[0] = sycl::ext::cplx::pow(cplx_input1, cplx_input2); + }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); + check_results(cplx_out[0], std_out); + } // Check cplx::complex output from host cplx_out[0] = sycl::ext::cplx::pow(cplx_input1, cplx_input2); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); + check_results(cplx_out[0], std_out); sycl::free(cplx_out, Q); - - return pass; } -template -bool test_pow_cplx_deci(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex pow cplx-deci overload", "[pow]", double, + float, sycl::half) { + using T = TestType; + + sycl::queue Q; - auto std_in = init_std_complex(init_re, init_im); - auto std_deci_in = init_deci(init_re); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; - T deci_input = init_re; + // Test cases + // Values are generated as cross product of input1 and input2's GENERATE list + cmplx input1 = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + T input2 = GENERATE(4.42, inf_val, nan_val); + + auto std_in = init_std_complex(input1); + auto std_deci_in = init_deci(input2); + sycl::ext::cplx::complex cplx_input{input1.re, input1.im}; + T deci_input = input2; std::complex std_out{}; auto *cplx_out = sycl::malloc_shared>(1, Q); @@ -48,28 +78,42 @@ bool test_pow_cplx_deci(sycl::queue &Q, T init_re, T init_im) { std_out = std::pow(std_in, std_deci_in); // Check cplx::complex output from device - Q.single_task([=]() { - cplx_out[0] = sycl::ext::cplx::pow(cplx_input, deci_input); - }).wait(); + if (is_type_supported(Q)) { + Q.single_task([=]() { + cplx_out[0] = sycl::ext::cplx::pow(cplx_input, deci_input); + }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); + check_results(cplx_out[0], std_out); + } // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::pow(cplx_input, deci_input); + cplx_out[0] = sycl::ext::cplx::pow(cplx_input, deci_input); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); + check_results(cplx_out[0], std_out); - return pass; + sycl::free(cplx_out, Q); } -template -bool test_pow_deci_cplx(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex pow deci-cplx overload", "[pow]", double, + float, sycl::half) { + using T = TestType; + + sycl::queue Q; - auto std_in = init_std_complex(init_re, init_im); - auto std_deci_in = init_deci(init_re); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; - T deci_input = init_re; + // Test cases + // Values are generated as cross product of input1 and input2's GENERATE list + cmplx input1 = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + T input2 = GENERATE(4.42, inf_val, nan_val); + + auto std_in = init_std_complex(input1); + auto std_deci_in = init_deci(input2); + sycl::ext::cplx::complex cplx_input{input1.re, input1.im}; + T deci_input = input2; std::complex std_out{}; auto *cplx_out = sycl::malloc_shared>(1, Q); @@ -78,51 +122,18 @@ bool test_pow_deci_cplx(sycl::queue &Q, T init_re, T init_im) { std_out = std::pow(std_deci_in, std_in); // Check cplx::complex output from device - Q.single_task([=]() { - cplx_out[0] = sycl::ext::cplx::pow(deci_input, cplx_input); - }).wait(); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); + if (is_type_supported(Q)) { + Q.single_task([=]() { + cplx_out[0] = sycl::ext::cplx::pow(deci_input, cplx_input); + }).wait(); - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::pow(deci_input, cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - return pass; -} - -template struct test_pow { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; - pass &= test_pow_cplx_cplx(Q, init_re, init_im); - pass &= test_pow_cplx_deci(Q, init_re, init_im); - pass &= test_pow_deci_cplx(Q, init_re, init_im); - return pass; + check_results(cplx_out[0], std_out); } -}; - -int main() { - sycl::queue Q; - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); - - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::pow(deci_input, cplx_input); - if (!test_passes) - std::cerr << "pow complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/sin_complex.cpp b/tests/sin_complex.cpp index 960ecb0..b6c8c38 100644 --- a/tests/sin_complex.cpp +++ b/tests/sin_complex.cpp @@ -1,57 +1,40 @@ #include "test_helper.hpp" -template struct test_sin { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex sin", "[sin]", double, float, sycl::half) { + using T = TestType; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + // Test cases + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - std_out = std::sin(std_in); + // Get std::complex output + std_out = std::sin(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { Q.single_task([=]() { cplx_out[0] = sycl::ext::cplx::sin(cplx_input); }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); - - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::sin(cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - sycl::free(cplx_out, Q); - - return pass; + check_results(cplx_out[0], std_out); } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::sin(cplx_input); - if (!test_passes) - std::cerr << "sin complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/sinh_complex.cpp b/tests/sinh_complex.cpp index 939f3d9..a11e1c2 100644 --- a/tests/sinh_complex.cpp +++ b/tests/sinh_complex.cpp @@ -1,57 +1,40 @@ #include "test_helper.hpp" -template struct test_sinh { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex sinh", "[sinh]", double, float, sycl::half) { + using T = TestType; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + // Test cases + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - std_out = std::sinh(std_in); + // Get std::complex output + std_out = std::sinh(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { Q.single_task([=]() { cplx_out[0] = sycl::ext::cplx::sinh(cplx_input); }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); - - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::sinh(cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - sycl::free(cplx_out, Q); - - return pass; + check_results(cplx_out[0], std_out); } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::sinh(cplx_input); - if (!test_passes) - std::cerr << "sinh complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/sqrt_complex.cpp b/tests/sqrt_complex.cpp index c29aaa7..32601bb 100644 --- a/tests/sqrt_complex.cpp +++ b/tests/sqrt_complex.cpp @@ -1,57 +1,40 @@ #include "test_helper.hpp" -template struct test_sqrt { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex sqrt", "[sqrt]", double, float, sycl::half) { + using T = TestType; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + // Test cases + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - std_out = std::sqrt(std_in); + // Get std::complex output + std_out = std::sqrt(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { Q.single_task([=]() { cplx_out[0] = sycl::ext::cplx::sqrt(cplx_input); }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); - - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::sqrt(cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - sycl::free(cplx_out, Q); - - return pass; + check_results(cplx_out[0], std_out); } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::sqrt(cplx_input); - if (!test_passes) - std::cerr << "sqrt complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/tan_complex.cpp b/tests/tan_complex.cpp index 4814a71..cc46157 100644 --- a/tests/tan_complex.cpp +++ b/tests/tan_complex.cpp @@ -1,57 +1,40 @@ #include "test_helper.hpp" -template struct test_tan { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex tan", "[tan]", double, float, sycl::half) { + using T = TestType; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + // Test cases + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - std_out = std::tan(std_in); + // Get std::complex output + std_out = std::tan(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { Q.single_task([=]() { cplx_out[0] = sycl::ext::cplx::tan(cplx_input); }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); - - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::tan(cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - sycl::free(cplx_out, Q); - - return pass; + check_results(cplx_out[0], std_out); } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::tan(cplx_input); - if (!test_passes) - std::cerr << "tan complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; + sycl::free(cplx_out, Q); } diff --git a/tests/tanh_complex.cpp b/tests/tanh_complex.cpp index 30625b3..bfe934d 100644 --- a/tests/tanh_complex.cpp +++ b/tests/tanh_complex.cpp @@ -1,57 +1,40 @@ #include "test_helper.hpp" -template struct test_tanh { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - bool pass = true; +TEMPLATE_TEST_CASE("Test complex tanh", "[tanh]", double, float, sycl::half) { + using T = TestType; - auto std_in = init_std_complex(init_re, init_im); - sycl::ext::cplx::complex cplx_input{init_re, init_im}; + sycl::queue Q; + + // Test cases + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); + + auto std_in = init_std_complex(input); + sycl::ext::cplx::complex cplx_input{input.re, input.im}; - std::complex std_out{}; - auto *cplx_out = sycl::malloc_shared>(1, Q); + std::complex std_out{}; + auto *cplx_out = sycl::malloc_shared>(1, Q); - // Get std::complex output - std_out = std::tanh(std_in); + // Get std::complex output + std_out = std::tanh(std_in); - // Check cplx::complex output from device + // Check cplx::complex output from device + if (is_type_supported(Q)) { Q.single_task([=]() { cplx_out[0] = sycl::ext::cplx::tanh(cplx_input); }).wait(); - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); - - // Check cplx::complex output from host - cplx_out[0] = sycl::ext::cplx::tanh(cplx_input); - - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); - - sycl::free(cplx_out, Q); - - return pass; + check_results(cplx_out[0], std_out); } -}; - -int main() { - sycl::queue Q; - - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02); - - test_passes &= test_valid_types(Q, INFINITY, 2.02); - test_passes &= test_valid_types(Q, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN); + // Check cplx::complex output from host + cplx_out[0] = sycl::ext::cplx::tanh(cplx_input); - if (!test_passes) - std::cerr << "tanh complex test fails\n"; + check_results(cplx_out[0], std_out); - return !test_passes; -} + sycl::free(cplx_out, Q); +} \ No newline at end of file diff --git a/tests/test_complex_types.cpp b/tests/test_complex_types.cpp index 73985a0..ccaf774 100644 --- a/tests/test_complex_types.cpp +++ b/tests/test_complex_types.cpp @@ -2,120 +2,75 @@ using namespace sycl::ext::cplx; +// Test checks user interface return types +// Compile time only tests, will fail during compilation due to static asserts + // Define math operations tests -#define TEST_MATH_OP_TYPE(op_name, op) \ - template struct test##_##op_name##_##types { \ - bool operator()() { \ +#define TEST_MATH_OP_TYPE(test_name, label, op) \ + TEMPLATE_TEST_CASE(test_name, label, double, float, sycl::half) { \ \ - static_assert( \ - std::is_same_v, decltype(std::declval>() \ - op std::declval())>); \ + static_assert(std::is_same_v, \ + decltype(std::declval>() \ + op std::declval())>); \ \ - static_assert( \ - std::is_same_v, \ - decltype(std::declval() \ - op std::declval>())>); \ - return true; \ - } \ - }; + static_assert( \ + std::is_same_v, \ + decltype(std::declval() \ + op std::declval>())>); \ + } -TEST_MATH_OP_TYPE(add, +) -TEST_MATH_OP_TYPE(sub, -) -TEST_MATH_OP_TYPE(mul, *) -TEST_MATH_OP_TYPE(div, /) +TEST_MATH_OP_TYPE("Test complex addition types", "[add]", +) +TEST_MATH_OP_TYPE("Test complex subtraction types", "[sub]", -) +TEST_MATH_OP_TYPE("Test complex multiplication types", "[mul]", *) +TEST_MATH_OP_TYPE("Test complex division types", "[div]", /) #undef TEST_MATH_FUNC_TYPE -// Check operations return correct types -void check_math_operator_types() { - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); -} - // Define math function tests -#define TEST_MATH_FUNC_TYPE(func) \ - template struct test##_##func##_##types { \ - bool operator()() { \ - static_assert(std::is_same_v, decltype(func(complex()))>); \ - return true; \ - } \ - }; - -TEST_MATH_FUNC_TYPE(acos) -TEST_MATH_FUNC_TYPE(asin) -TEST_MATH_FUNC_TYPE(atan) -TEST_MATH_FUNC_TYPE(acosh) -TEST_MATH_FUNC_TYPE(asinh) -TEST_MATH_FUNC_TYPE(atanh) -TEST_MATH_FUNC_TYPE(conj) -TEST_MATH_FUNC_TYPE(cos) -TEST_MATH_FUNC_TYPE(cosh) -TEST_MATH_FUNC_TYPE(exp) -TEST_MATH_FUNC_TYPE(log) -TEST_MATH_FUNC_TYPE(log10) -TEST_MATH_FUNC_TYPE(proj) -TEST_MATH_FUNC_TYPE(sin) -TEST_MATH_FUNC_TYPE(sinh) -TEST_MATH_FUNC_TYPE(sqrt) -TEST_MATH_FUNC_TYPE(tan) -TEST_MATH_FUNC_TYPE(tanh) -#undef TEST_MATH_FUNC_TYPE - -template struct test_abs_types { - bool operator()() { - static_assert(std::is_same_v()))>); - return true; +#define TEST_MATH_FUNC_TYPE(test_name, label, func) \ + TEMPLATE_TEST_CASE(test_name, label, double, float, sycl::half) { \ + static_assert(std::is_same_v, \ + decltype(func(complex()))>); \ } -}; -template struct test_polar_types { - bool operator()() { - static_assert(std::is_same_v, decltype(polar(T()))>); - static_assert(std::is_same_v, decltype(polar(T(), T()))>); - return true; - } -}; - -template struct test_pow_types { - bool operator()() { - static_assert(std::is_same_v, decltype(pow(complex(), T()))>); - static_assert( - std::is_same_v, decltype(pow(complex(), complex()))>); - static_assert(std::is_same_v, decltype(pow(T(), complex()))>); - return true; - } -}; - -// Check functions return correct types -void check_math_function_types() { +TEST_MATH_FUNC_TYPE("Test complex acos types", "[acos]", acos) +TEST_MATH_FUNC_TYPE("Test complex asin types", "[asin]", asin) +TEST_MATH_FUNC_TYPE("Test complex atan types", "[atan]", atan) +TEST_MATH_FUNC_TYPE("Test complex acosh types", "[acosh]", acosh) +TEST_MATH_FUNC_TYPE("Test complex asinh types", "[asinh]", asinh) +TEST_MATH_FUNC_TYPE("Test complex atanh types", "[atanh]", atanh) +TEST_MATH_FUNC_TYPE("Test complex conj types", "[conj]", conj) +TEST_MATH_FUNC_TYPE("Test complex cos types", "[cos]", cos) +TEST_MATH_FUNC_TYPE("Test complex cosh types", "[cosh]", cosh) +TEST_MATH_FUNC_TYPE("Test complex exp types", "[exp]", exp) +TEST_MATH_FUNC_TYPE("Test complex log types", "[log]", log) +TEST_MATH_FUNC_TYPE("Test complex log10 types", "[log10]", log10) +TEST_MATH_FUNC_TYPE("Test complex proj types", "[proj]", proj) +TEST_MATH_FUNC_TYPE("Test complex sin types", "[sin]", sin) +TEST_MATH_FUNC_TYPE("Test complex sinh types", "[sinh]", sinh) +TEST_MATH_FUNC_TYPE("Test complex sqrt types", "[sqrt]", sqrt) +TEST_MATH_FUNC_TYPE("Test complex tan types", "[tan]", tan) +TEST_MATH_FUNC_TYPE("Test complex tanh types", "[tanh]", tanh) +#undef TEST_MATH_FUNC_TYPE - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); - test_valid_types(); +TEMPLATE_TEST_CASE("Test complex abs types", "[abs]", double, float, + sycl::half) { + static_assert(std::is_same_v()))>); } -int main() { - check_math_function_types(); - check_math_operator_types(); +TEMPLATE_TEST_CASE("Test complex polar types", "[abs]", double, float, + sycl::half) { + static_assert(std::is_same_v, decltype(polar(TestType()))>); + static_assert(std::is_same_v, + decltype(polar(TestType(), TestType()))>); +} - return 0; +TEMPLATE_TEST_CASE("Test complex pow types", "[abs]", double, float, + sycl::half) { + static_assert(std::is_same_v, + decltype(pow(complex(), TestType()))>); + static_assert( + std::is_same_v, + decltype(pow(complex(), complex()))>); + static_assert(std::is_same_v, + decltype(pow(TestType(), complex()))>); } diff --git a/tests/test_gencomplex.cpp b/tests/test_gencomplex.cpp index ad52eff..64b0bd6 100644 --- a/tests/test_gencomplex.cpp +++ b/tests/test_gencomplex.cpp @@ -1,10 +1,9 @@ -#include "sycl_ext_complex.hpp" -#include +#include "test_helper.hpp" using namespace sycl::ext::cplx; // Check is_gencomplex -void check_is_gencomplex() { +TEST_CASE("Test is_gencomplex", "[gencomplex]") { static_assert(is_gencomplex>::value == true); static_assert(is_gencomplex>::value == true); static_assert(is_gencomplex>::value == true); @@ -16,8 +15,3 @@ void check_is_gencomplex() { static_assert(is_gencomplex>::value == false); static_assert(is_gencomplex>::value == false); } - -int main() { - check_is_gencomplex(); - return 0; -} diff --git a/tests/test_helper.hpp b/tests/test_helper.hpp index 85fa010..6e3fac3 100644 --- a/tests/test_helper.hpp +++ b/tests/test_helper.hpp @@ -1,3 +1,6 @@ +#include +#include +#include #include #include #include @@ -6,12 +9,42 @@ #define SYCL_CPLX_TOL_ULP 5 -// Helpers for displaying results +// Helpers for check if type is supported +template inline bool is_type_supported(sycl::queue &Q) { return false; } -template const char *get_typename() { return "Unknown type"; } -template <> const char *get_typename() { return "double"; } -template <> const char *get_typename() { return "float"; } -template <> const char *get_typename() { return "sycl::half"; } +template <> inline bool is_type_supported(sycl::queue &Q) { + return Q.get_device().has(sycl::aspect::fp64); +} + +template <> inline bool is_type_supported(sycl::queue &Q) { + return true; +} + +template <> inline bool is_type_supported(sycl::queue &Q) { + return Q.get_device().has(sycl::aspect::fp16); +} + +// Helper for passing infinity and nan values +template +inline constexpr T inf_val = std::numeric_limits::infinity(); + +template +inline constexpr T nan_val = std::numeric_limits::quiet_NaN(); + +// Helper class for passing complex arguments + +template struct cmplx { + constexpr cmplx() : re(0), im(0) {} + constexpr cmplx(T real, T imag) : re(real), im(imag) {} + + template cmplx(cmplx c) { + re = c.re; + im = c.im; + } + + T re; + T im; +}; // Helper for testing each decimal type @@ -88,17 +121,13 @@ bool almost_equal(T1 x, T2 y, int ulp) { } // Helpers for testing half -std::complex sycl_half_to_float(sycl::ext::cplx::complex c) { +inline std::complex +sycl_half_to_float(sycl::ext::cplx::complex c) { auto c_sycl_float = static_cast>(c); return static_cast>(c_sycl_float); } -std::complex sycl_float_to_half(std::complex c) { - auto c_sycl_half = static_cast>(c); - return static_cast>(c_sycl_half); -} - -std::complex trunc_float(std::complex c) { +inline std::complex trunc_float(std::complex c) { auto c_sycl_half = static_cast>(c); return sycl_half_to_float(c_sycl_half); } @@ -106,12 +135,12 @@ std::complex trunc_float(std::complex c) { // Helper for initializing std::complex values for tests only needed because // sycl::half cases are emulated with float for std::complex class -template auto constexpr init_std_complex(T_in re, T_in im) { - return std::complex(re, im); +template auto constexpr init_std_complex(cmplx c) { + return std::complex(c.re, c.im); } -template <> auto constexpr init_std_complex(sycl::half re, sycl::half im) { - return trunc_float(std::complex(re, im)); +template <> auto constexpr init_std_complex(cmplx c) { + return trunc_float(std::complex(c.re, c.im)); } template auto constexpr init_deci(T_in re) { return re; } @@ -123,30 +152,12 @@ template <> auto constexpr init_deci(sycl::half re) { // Helpers for comparing SyclCPLX and standard c++ results template -bool check_results(sycl::ext::cplx::complex output, - std::complex reference, bool is_device, - int tol_multiplier = 1) { - if (!almost_equal(output, reference, tol_multiplier * SYCL_CPLX_TOL_ULP)) { - std::cerr << std::setprecision(std::numeric_limits::max_digits10) - << "Test failed with complex_type: " << get_typename() - << " Computed on " << (is_device ? "device" : "host") - << " Output: " << output << " Reference: " << reference - << std::endl; - return false; - } - return true; +void check_results(sycl::ext::cplx::complex output, + std::complex reference, int tol_multiplier = 1) { + CHECK(almost_equal(output, reference, tol_multiplier * SYCL_CPLX_TOL_ULP)); } template -bool check_results(T output, T reference, bool is_device, - int tol_multiplier = 1) { - if (!almost_equal(output, reference, tol_multiplier * SYCL_CPLX_TOL_ULP)) { - std::cerr << std::setprecision(std::numeric_limits::max_digits10) - << "Test failed with complex_type: " << get_typename() - << " Computed on " << (is_device ? "device" : "host") - << " Output: " << output << " Reference: " << reference - << std::endl; - return false; - } - return true; +void check_results(T output, T reference, int tol_multiplier = 1) { + CHECK(almost_equal(output, reference, tol_multiplier * SYCL_CPLX_TOL_ULP)); } diff --git a/tests/test_operator_complex.cpp b/tests/test_operator_complex.cpp index f3f6b16..7b6ca71 100644 --- a/tests/test_operator_complex.cpp +++ b/tests/test_operator_complex.cpp @@ -1,343 +1,273 @@ #include "test_helper.hpp" -#define test_op(name, op) \ - template struct name { \ - bool operator()(sycl::queue &Q, T init_re1, T init_im1, T init_re2, \ - T init_im2) { \ - bool pass = true; \ +#define test_op(test_name, label, op) \ + TEMPLATE_TEST_CASE(test_name, label, double, float, sycl::half) { \ + using T = TestType; \ + using std::make_tuple; \ \ - auto std_in1 = init_std_complex(init_re1, init_im1); \ - auto std_in2 = init_std_complex(init_re2, init_im2); \ - sycl::ext::cplx::complex cplx_input1{init_re1, init_im1}; \ - sycl::ext::cplx::complex cplx_input2{init_re2, init_im2}; \ + sycl::queue Q; \ \ - std::complex std_out{}; \ - auto *cplx_out = sycl::malloc_shared>(1, Q); \ + cmplx input1 = GENERATE( \ + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, \ + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, \ + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, \ + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, \ + cmplx{inf_val, nan_val}); \ \ - /* Check complex-decimal op */ \ - T dec = init_re2; \ - std_out = std_in1 op init_deci(dec); \ + cmplx input2 = GENERATE(cmplx{4.42, 2.02}); \ \ - Q.single_task([=]() { cplx_out[0] = cplx_input1 op dec; }).wait(); \ + auto std_in1 = init_std_complex(input1); \ + auto std_in2 = init_std_complex(input2); \ + sycl::ext::cplx::complex cplx_input1{input1.re, input1.im}; \ + sycl::ext::cplx::complex cplx_input2{input2.re, input2.im}; \ \ - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); \ + std::complex std_out{}; \ + auto *cplx_out = sycl::malloc_shared>(1, Q); \ \ - cplx_out[0] = cplx_input1 op dec; \ + /* Check complex-complex op */ \ + std_out = std_in1 op std_in2; \ \ - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); \ + Q.single_task([=]() { cplx_out[0] = cplx_input1 op cplx_input2; }).wait(); \ \ - /* Check decimal-complex op */ \ - dec = init_re1; \ - std_out = init_deci(dec) op std_in2; \ + check_results(cplx_out[0], std_out); \ \ - Q.single_task([=]() { cplx_out[0] = dec op cplx_input2; }).wait(); \ + cplx_out[0] = cplx_input1 op cplx_input2; \ \ - pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); \ + check_results(cplx_out[0], std_out); \ \ - cplx_out[0] = dec op cplx_input2; \ - \ - pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); \ - \ - sycl::free(cplx_out, Q); \ - \ - return pass; \ - } \ - }; + sycl::free(cplx_out, Q); \ + } -test_op(test_add, +); -test_op(test_sub, -); -test_op(test_mul, *); -test_op(test_div, /); +test_op("Test complex addition cplx-cplx overload", "[add]", +); +test_op("Test complex subtraction cplx-cplx overload", "[sub]", +); +test_op("Test complex multiplication cplx-cplx overload", "[mul]", +); +test_op("Test complex division cplx-cplx overload", "[div]", +); #undef test_op -#define test_op_assign(name, op_assign) \ - template struct name { \ - bool operator()(sycl::queue &Q, T init_re1, T init_im1, T init_re2, \ - T init_im2) { \ - bool pass = true; \ +#define test_op(test_name, label, op) \ + TEMPLATE_TEST_CASE(test_name, label, double, float, sycl::half) { \ + using T = TestType; \ + using std::make_tuple; \ \ - auto std_in = init_std_complex(init_re1, init_im1); \ - sycl::ext::cplx::complex cplx_input{init_re1, init_im1}; \ - auto *cplx_inout = \ - sycl::malloc_shared>(1, Q); \ - /* Check complex-decimal op */ \ - auto std_inout = init_std_complex(init_re2, init_im2); \ - cplx_inout[0].real(init_re2); \ - cplx_inout[0].imag(init_im2); \ + sycl::queue Q; \ \ - std_inout op_assign std_in; \ + cmplx input1 = GENERATE( \ + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, \ + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, \ + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, \ + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, \ + cmplx{inf_val, nan_val}); \ \ - Q.single_task([=]() { cplx_inout[0] op_assign cplx_input; }).wait(); \ + T input2 = GENERATE(4.42, 2.02); \ \ - pass &= check_results( \ - cplx_inout[0], std::complex(std_inout.real(), std_inout.imag()), \ - /*is_device*/ true); \ + auto std_in = init_std_complex(input1); \ + auto std_deci_in = init_deci(input2); \ + sycl::ext::cplx::complex cplx_input{input1.re, input1.im}; \ + T deci_input = input2; \ \ - cplx_inout[0].real(init_re2); \ - cplx_inout[0].imag(init_im2); \ + std::complex std_out{}; \ + auto *cplx_out = sycl::malloc_shared>(1, Q); \ \ - cplx_inout[0] op_assign cplx_input; \ + /* Check complex-decimal op */ \ + std_out = std_in op std_deci_in; \ \ - pass &= check_results( \ - cplx_inout[0], std::complex(std_inout.real(), std_inout.imag()), \ - /*is_device*/ false); \ + Q.single_task([=]() { cplx_out[0] = cplx_input op deci_input; }).wait(); \ \ - /* Check complex-decimal op */ \ - std_inout = init_std_complex(init_re2, init_im2); \ - cplx_inout[0].real(init_re2); \ - cplx_inout[0].imag(init_im2); \ + check_results(cplx_out[0], std_out); \ \ - T dec = init_re1; \ - std_inout op_assign init_deci(dec); \ + cplx_out[0] = cplx_input op deci_input; \ \ - Q.single_task([=]() { cplx_inout[0] op_assign dec; }).wait(); \ + check_results(cplx_out[0], std_out); \ \ - pass &= check_results( \ - cplx_inout[0], std::complex(std_inout.real(), std_inout.imag()), \ - /*is_device*/ true); \ + sycl::free(cplx_out, Q); \ + } + +test_op("Test complex addition cplx-deci overload", "[add]", +); +test_op("Test complex subtraction cplx-deci overload", "[sub]", -); +test_op("Test complex multiplication cplx-deci overload", "[mul]", *); +test_op("Test complex division cplx-deci overload", "[div]", /); + +#undef test_op + +#define test_op(test_name, label, op) \ + TEMPLATE_TEST_CASE(test_name, label, double, float, sycl::half) { \ + using T = TestType; \ + using std::make_tuple; \ \ - cplx_inout[0].real(init_re2); \ - cplx_inout[0].imag(init_im2); \ + sycl::queue Q; \ \ - cplx_inout[0] op_assign dec; \ + cmplx input1 = GENERATE( \ + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, \ + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, \ + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, \ + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, \ + cmplx{inf_val, nan_val}); \ \ - pass &= check_results( \ - cplx_inout[0], std::complex(std_inout.real(), std_inout.imag()), \ - /*is_device*/ false); \ + T input2 = GENERATE(4.42, 2.02); \ \ - sycl::free(cplx_inout, Q); \ - return pass; \ + auto std_in = init_std_complex(input1); \ + auto std_deci_in = init_deci(input2); \ + sycl::ext::cplx::complex cplx_input{input1.re, input1.im}; \ + T deci_input = input2; \ + \ + std::complex std_out{}; \ + auto *cplx_out = sycl::malloc_shared>(1, Q); \ + \ + /* Check complex-decimal op */ \ + std_out = std_deci_in op std_in; \ + \ + if (is_type_supported(Q)) { \ + Q.single_task([=]() { cplx_out[0] = deci_input op cplx_input; }).wait(); \ + \ + check_results(cplx_out[0], std_out); \ } \ - }; - -test_op_assign(test_add_assign, +=); -test_op_assign(test_sub_assign, -=); -test_op_assign(test_mul_assign, *=); -test_op_assign(test_div_assign, /=); - -#undef test_op_assign - -int main() { - sycl::queue Q; - - bool test_failed = false; - - { - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02, -1.5, 3.2); - - test_passes &= - test_valid_types(Q, INFINITY, 2.02, INFINITY, 2.02); - test_passes &= - test_valid_types(Q, 4.42, INFINITY, 4.42, INFINITY); - test_passes &= - test_valid_types(Q, INFINITY, INFINITY, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN, NAN, NAN); - - test_passes &= test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - if (!test_passes) { - std::cerr << "Addition operator complex test fails\n"; - test_failed = true; - } - } - - { - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02, -1.5, 3.2); - - test_passes &= - test_valid_types(Q, INFINITY, 2.02, INFINITY, 2.02); - test_passes &= - test_valid_types(Q, 4.42, INFINITY, 4.42, INFINITY); - test_passes &= - test_valid_types(Q, INFINITY, INFINITY, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN, NAN, NAN); - - test_passes &= test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - if (!test_passes) { - std::cerr << "Subtraction operator complex test fails\n"; - test_failed = true; - } - } - - { - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02, -1.5, 3.2); - - test_passes &= - test_valid_types(Q, INFINITY, 2.02, INFINITY, 2.02); - test_passes &= - test_valid_types(Q, 4.42, INFINITY, 4.42, INFINITY); - test_passes &= - test_valid_types(Q, INFINITY, INFINITY, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN, NAN, NAN); - - test_passes &= test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - if (!test_passes) { - std::cerr << "Multiplication operator complex test fails\n"; - test_failed = true; - } - } - - { - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02, -1.5, 3.2); - - test_passes &= - test_valid_types(Q, INFINITY, 2.02, INFINITY, 2.02); - test_passes &= - test_valid_types(Q, 4.42, INFINITY, 4.42, INFINITY); - test_passes &= - test_valid_types(Q, INFINITY, INFINITY, INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN, NAN, NAN); - - test_passes &= test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - test_passes &= test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - if (!test_passes) { - std::cerr << "Division operator complex test fails\n"; - test_failed = true; - } - } - - { - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02, -1.5, 3.2); - - test_passes &= - test_valid_types(Q, INFINITY, 2.02, INFINITY, 2.02); - test_passes &= - test_valid_types(Q, 4.42, INFINITY, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY, - INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN, NAN, NAN); - - test_passes &= - test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= - test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - test_passes &= - test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= - test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - if (!test_passes) { - std::cerr << "Addition assign operator complex test fails\n"; - test_failed = true; - } + \ + cplx_out[0] = deci_input op cplx_input; \ + \ + check_results(cplx_out[0], std_out); \ + \ + sycl::free(cplx_out, Q); \ } - { - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02, -1.5, 3.2); - - test_passes &= - test_valid_types(Q, INFINITY, 2.02, INFINITY, 2.02); - test_passes &= - test_valid_types(Q, 4.42, INFINITY, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY, - INFINITY, INFINITY); +test_op("Test complex addition deci-cplx overload", "[add]", +); +test_op("Test complex subtraction deci-cplx overload", "[sub]", -); +test_op("Test complex multiplication deci-cplx overload", "[mul]", *); +test_op("Test complex division deci-cplx overload", "[div]", /); - test_passes &= test_valid_types(Q, NAN, 2.02, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN, NAN, NAN); +#undef test_op - test_passes &= - test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= - test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - test_passes &= - test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= - test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - if (!test_passes) { - std::cerr << "Subtraction assign operator complex test fails\n"; - test_failed = true; - } +// OP assign tests are checked for the all possible type combinations as op +// assign supports different types being used. + +#define test_op_assign(test_name, label, op_assign) \ + TEMPLATE_TEST_CASE( \ + test_name, label, (std::tuple), \ + (std::tuple), (std::tuple), \ + (std::tuple), (std::tuple), \ + (std::tuple), (std::tuple), \ + (std::tuple), (std::tuple)) { \ + using T1 = typename std::tuple_element<0, TestType>::type; \ + using T2 = typename std::tuple_element<0, TestType>::type; \ + using std::make_tuple; \ + \ + sycl::queue Q; \ + \ + cmplx input1 = GENERATE( \ + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, \ + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, \ + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, \ + cmplx{nan_val, nan_val}, \ + cmplx{nan_val, inf_val}, \ + cmplx{inf_val, nan_val}); \ + \ + cmplx input2 = GENERATE( \ + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, \ + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, \ + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, \ + cmplx{nan_val, nan_val}, \ + cmplx{nan_val, inf_val}, \ + cmplx{inf_val, nan_val}); \ + \ + auto std_in = init_std_complex(input1); \ + sycl::ext::cplx::complex cplx_input{input1.re, input1.im}; \ + \ + auto *cplx_inout = \ + sycl::malloc_shared>(1, Q); \ + \ + auto std_inout = init_std_complex(input2); \ + cplx_inout[0].real(input2.re); \ + cplx_inout[0].imag(input2.im); \ + \ + std_inout op_assign std_in; \ + \ + SECTION("DEVICE") { \ + if (is_type_supported(Q) && is_type_supported(Q)) { \ + Q.single_task([=]() { cplx_inout[0] op_assign cplx_input; }).wait(); \ + \ + check_results(cplx_inout[0], \ + std::complex(std_inout.real(), std_inout.imag())); \ + } \ + } \ + \ + SECTION("HOST") { \ + cplx_inout[0] op_assign cplx_input; \ + \ + check_results(cplx_inout[0], \ + std::complex(std_inout.real(), std_inout.imag())); \ + } \ + \ + sycl::free(cplx_inout, Q); \ } - { - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02, -1.5, 3.2); - - test_passes &= - test_valid_types(Q, INFINITY, 2.02, INFINITY, 2.02); - test_passes &= - test_valid_types(Q, 4.42, INFINITY, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY, - INFINITY, INFINITY); +test_op_assign("Test complex assign addition cplx-cplx overload", "[add]", +=); +test_op_assign("Test complex assign subtraction cplx-cplx overload", "[sub]", + -=); +test_op_assign("Test complex assign multiplication cplx-cplx overload", "[mul]", + *=); +test_op_assign("Test complex assign division cplx-cplx overload", "[div]", /=); - test_passes &= test_valid_types(Q, NAN, 2.02, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN, NAN, NAN); +#undef test_op_assign - test_passes &= - test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= - test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - test_passes &= - test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= - test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - if (!test_passes) { - std::cerr << "Multiplication assign operator complex test fails\n"; - test_failed = true; - } +#define test_op_assign(test_name, label, op_assign) \ + TEMPLATE_TEST_CASE( \ + test_name, label, (std::tuple), \ + (std::tuple), (std::tuple), \ + (std::tuple), (std::tuple), \ + (std::tuple), (std::tuple), \ + (std::tuple), (std::tuple)) { \ + using T1 = typename std::tuple_element<0, TestType>::type; \ + using T2 = typename std::tuple_element<0, TestType>::type; \ + using std::make_tuple; \ + \ + sycl::queue Q; \ + \ + T1 input1 = GENERATE(4.42, 2.02, inf_val, nan_val); \ + \ + cmplx input2 = GENERATE( \ + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, \ + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, \ + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, \ + cmplx{nan_val, nan_val}, \ + cmplx{nan_val, inf_val}, \ + cmplx{inf_val, nan_val}); \ + \ + auto std_deci_in = init_deci(input1); \ + T1 deci_input = input1; \ + \ + auto *cplx_inout = \ + sycl::malloc_shared>(1, Q); \ + \ + auto std_inout = init_std_complex(input2); \ + cplx_inout[0].real(input2.re); \ + cplx_inout[0].imag(input2.im); \ + \ + std_inout op_assign std_deci_in; \ + \ + SECTION("DEVICE") { \ + if (is_type_supported(Q) && is_type_supported(Q)) { \ + Q.single_task([=]() { cplx_inout[0] op_assign deci_input; }).wait(); \ + \ + check_results(cplx_inout[0], \ + std::complex(std_inout.real(), std_inout.imag())); \ + } \ + } \ + \ + SECTION("HOST") { \ + cplx_inout[0] op_assign deci_input; \ + \ + check_results(cplx_inout[0], \ + std::complex(std_inout.real(), std_inout.imag())); \ + } \ + \ + sycl::free(cplx_inout, Q); \ } - { - bool test_passes = true; - test_passes &= test_valid_types(Q, 4.42, 2.02, -1.5, 3.2); - - test_passes &= - test_valid_types(Q, INFINITY, 2.02, INFINITY, 2.02); - test_passes &= - test_valid_types(Q, 4.42, INFINITY, 4.42, INFINITY); - test_passes &= test_valid_types(Q, INFINITY, INFINITY, - INFINITY, INFINITY); - - test_passes &= test_valid_types(Q, NAN, 2.02, NAN, 2.02); - test_passes &= test_valid_types(Q, 4.42, NAN, 4.42, NAN); - test_passes &= test_valid_types(Q, NAN, NAN, NAN, NAN); - - test_passes &= - test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= - test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - test_passes &= - test_valid_types(Q, NAN, INFINITY, NAN, INFINITY); - test_passes &= - test_valid_types(Q, INFINITY, NAN, INFINITY, NAN); - if (!test_passes) { - std::cerr << "Division assign operator complex test fails\n"; - test_failed = true; - } - } +test_op_assign("Test complex assign addition cplx-deci overload", "[add]", +=); +test_op_assign("Test complex assign subtraction cplx-deci overload", "[sub]", + -=); +test_op_assign("Test complex assign multiplication cplx-deci overload", "[mul]", + *=); +test_op_assign("Test complex assign division cplx-deci overload", "[div]", /=); - return test_failed; -} +#undef test_op_assign \ No newline at end of file diff --git a/tests/test_stream_operator.cpp b/tests/test_stream_operator.cpp index ab3964f..38d618c 100644 --- a/tests/test_stream_operator.cpp +++ b/tests/test_stream_operator.cpp @@ -1,76 +1,77 @@ #include "test_helper.hpp" -template struct test_sycl_stream_operator { - bool operator()(sycl::queue &Q, T init_re, T init_im) { - auto *cplx_out = sycl::malloc_shared>(1, Q); - cplx_out[0] = sycl::ext::cplx::complex(init_re, init_im); - - Q.submit([&](sycl::handler &CGH) { - sycl::stream Out(512, 20, CGH); - CGH.parallel_for<>(sycl::range<1>(1), [=](sycl::id<1> idx) { - Out << cplx_out[idx] << sycl::endl; - }); - }); +TEMPLATE_TEST_CASE("Test complex sycl stream", "[sycl::stream]", double, float, + sycl::half) { + using T = TestType; - sycl::free(cplx_out, Q); + sycl::queue Q; - return true; - } -}; + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); -// Host only tests for std::basic_ostream and std::basic_istream -template struct test_ostream_operator { - bool operator()(T init_re, T init_im) { - sycl::ext::cplx::complex c(init_re, init_im); + auto *cplx_out = sycl::malloc_shared>(1, Q); + cplx_out[0] = sycl::ext::cplx::complex(input.re, input.im); + + Q.submit([&](sycl::handler &CGH) { + sycl::stream Out(512, 20, CGH); + CGH.parallel_for<>(sycl::range<1>(1), [=](sycl::id<1> idx) { + Out << cplx_out[idx] << sycl::endl; + }); + }); - std::ostringstream os; - os << c; + sycl::free(cplx_out, Q); +} - std::ostringstream ref_oss; - ref_oss << std::complex(init_re, init_im); +// Host only tests for std::basic_ostream and std::basic_istream +TEMPLATE_TEST_CASE("Test complex std ostream", "[ostream]", double, float, + sycl::half) { + using T = TestType; - if (ref_oss.str() == os.str()) - return true; - return false; - } -}; + sycl::queue Q; -template struct test_istream_operator { - bool operator()(T init_re, T init_im) { - sycl::ext::cplx::complex c(init_re, init_im); + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); - std::ostringstream ref_oss; - ref_oss << "(" << init_re << "," << init_im << ")"; + sycl::ext::cplx::complex c(input.re, input.im); - std::istringstream iss(ref_oss.str()); + std::ostringstream os; + os << c; - iss >> c; + std::ostringstream ref_oss; + ref_oss << std::complex(input.re, input.im); + + CHECK(ref_oss.str() == os.str()); +} - return check_results(c, std::complex(init_re, init_im), - /*is_device*/ false); - } -}; +TEMPLATE_TEST_CASE("Test complex std istream", "[istream]", double, float, + sycl::half) { + using T = TestType; -int main() { sycl::queue Q; - bool test_passes = true; + cmplx input = GENERATE( + cmplx{4.42, 2.02}, cmplx{inf_val, 2.02}, + cmplx{4.42, inf_val}, cmplx{inf_val, inf_val}, + cmplx{nan_val, 2.02}, cmplx{4.42, nan_val}, + cmplx{nan_val, nan_val}, cmplx{nan_val, inf_val}, + cmplx{inf_val, nan_val}); - test_passes &= test_valid_types(Q, 0.42, -1.1); - test_passes &= - test_valid_types(Q, INFINITY, INFINITY); - test_passes &= test_valid_types(Q, NAN, NAN); + sycl::ext::cplx::complex c(input.re, input.im); - test_passes &= test_valid_types(0.42, -1.1); - test_passes &= test_valid_types(INFINITY, INFINITY); - test_passes &= test_valid_types(NAN, NAN); + std::ostringstream ref_oss; + ref_oss << "(" << input.re << "," << input.im << ")"; - test_passes &= test_valid_types(0.42, -1.1); - test_passes &= test_valid_types(INFINITY, INFINITY); - test_passes &= test_valid_types(NAN, NAN); + std::istringstream iss(ref_oss.str()); - if (!test_passes) - std::cerr << "Stream operator with complex test fails\n"; + iss >> c; - return 0; + check_results(c, std::complex(input.re, input.im)); } diff --git a/vendor/Catch2 b/vendor/Catch2 new file mode 160000 index 0000000..dc001fa --- /dev/null +++ b/vendor/Catch2 @@ -0,0 +1 @@ +Subproject commit dc001fa935d71b4b77f263fce405c9dbdfcbfe28