Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKPORT]: Add missing overloads for thrust::pow #1223

Merged
merged 2 commits into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions thrust/testing/complex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -449,17 +449,18 @@ struct TestComplexBasicArithmetic
// Test the basic arithmetic functions against std

ASSERT_ALMOST_EQUAL(thrust::abs(a), std::abs(b));

ASSERT_ALMOST_EQUAL(thrust::arg(a), std::arg(b));

ASSERT_ALMOST_EQUAL(thrust::norm(a), std::norm(b));

ASSERT_EQUAL(thrust::conj(a), std::conj(b));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::conj(a))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::polar(data[0], data[1]), std::polar(data[0], data[1]));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::polar(data[0], data[1]))>::value, "");

// random_samples does not seem to produce infinities so proj(z) == z
ASSERT_EQUAL(thrust::proj(a), a);
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::proj(a))>::value, "");
}
};
SimpleUnitTest<TestComplexBasicArithmetic, FloatingPointTypes> TestComplexBasicArithmeticInstance;
Expand Down Expand Up @@ -556,6 +557,9 @@ struct TestComplexExponentialFunctions
ASSERT_ALMOST_EQUAL(thrust::exp(a), std::exp(b));
ASSERT_ALMOST_EQUAL(thrust::log(a), std::log(b));
ASSERT_ALMOST_EQUAL(thrust::log10(a), std::log10(b));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::exp(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::log(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::log10(a))>::value, "");
}
};
SimpleUnitTest<TestComplexExponentialFunctions, FloatingPointTypes>
Expand All @@ -575,16 +579,24 @@ struct TestComplexPowerFunctions
const std::complex<T> b_std(b_thrust);

ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust), std::pow(a_std, b_std));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust, b_thrust))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust.real()), std::pow(a_std, b_std.real()));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust, b_thrust.real()))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust.real(), b_thrust), std::pow(a_std.real(), b_std));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust.real(), b_thrust))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, 4), std::pow(a_std, 4));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust, 4))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::sqrt(a_thrust), std::sqrt(a_std));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::sqrt(a_thrust))>::value, "");
}

// Test power functions with promoted types.
{
using T0 = T;
using T1 = other_floating_point_type_t<T0>;
using promoted = typename thrust::detail::promoted_numerical_type<T0, T1>::type;

thrust::host_vector<T0> data = unittest::random_samples<T0>(4);

Expand All @@ -594,11 +606,17 @@ struct TestComplexPowerFunctions
const std::complex<T0> b_std(data[2], data[3]);

ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust), std::pow(a_std, b_std));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(a_thrust, b_thrust))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust, a_thrust), std::pow(b_std, a_std));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(b_thrust, a_thrust))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust.real()), std::pow(a_std, b_std.real()));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(a_thrust, b_thrust.real()))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust, a_thrust.real()), std::pow(b_std, a_std.real()));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(b_thrust, a_thrust.real()))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust.real(), b_thrust), std::pow(a_std.real(), b_std));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(a_thrust.real(), b_thrust))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust.real(), a_thrust), std::pow(b_std.real(), a_std));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(b_thrust.real(), a_thrust))>::value, "");
}
}
};
Expand All @@ -617,20 +635,32 @@ struct TestComplexTrigonometricFunctions
ASSERT_ALMOST_EQUAL(thrust::cos(a), std::cos(c));
ASSERT_ALMOST_EQUAL(thrust::sin(a), std::sin(c));
ASSERT_ALMOST_EQUAL(thrust::tan(a), std::tan(c));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::cos(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::sin(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::tan(a))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::cosh(a), std::cosh(c));
ASSERT_ALMOST_EQUAL(thrust::sinh(a), std::sinh(c));
ASSERT_ALMOST_EQUAL(thrust::tanh(a), std::tanh(c));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::cosh(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::sinh(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::tanh(a))>::value, "");

#if THRUST_CPP_DIALECT >= 2011

ASSERT_ALMOST_EQUAL(thrust::acos(a), std::acos(c));
ASSERT_ALMOST_EQUAL(thrust::asin(a), std::asin(c));
ASSERT_ALMOST_EQUAL(thrust::atan(a), std::atan(c));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::acos(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::asin(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::atan(a))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::acosh(a), std::acosh(c));
ASSERT_ALMOST_EQUAL(thrust::asinh(a), std::asinh(c));
ASSERT_ALMOST_EQUAL(thrust::atanh(a), std::atanh(c));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::acosh(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::asinh(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::atanh(a))>::value, "");

#endif
}
Expand Down
21 changes: 16 additions & 5 deletions thrust/thrust/complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,8 @@ using ::cuda::std::proj;
using ::cuda::std::exp;
using ::cuda::std::log;
using ::cuda::std::log10;
using ::cuda::std::pow;
// pow always returns a complex.
// using ::cuda::std::pow;
using ::cuda::std::sqrt;

using ::cuda::std::acos;
Expand Down Expand Up @@ -516,15 +517,25 @@ template<class T>
__host__ __device__ complex<T> log10(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::log10(c));
}
template<class T>
__host__ __device__ complex<T> pow(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::pow(c));
template<class T0, class T1>
__host__ __device__ complex<typename detail::promoted_numerical_type<T0, T1>::type>
pow(const complex<T0>& x, const complex<T1>& y) {
return static_cast<complex<typename detail::promoted_numerical_type<T0, T1>::type>>(::cuda::std::pow(x, y));
}
template<class T0, class T1, ::cuda::std::__enable_if_t<::cuda::std::is_arithmetic<T1>::value, int> = 0>
__host__ __device__ complex<typename detail::promoted_numerical_type<T0, T1>::type>
pow(const complex<T0>& x, const T1& y) {
return static_cast<complex<typename detail::promoted_numerical_type<T0, T1>::type>>(::cuda::std::pow(x, y));
}
template<class T0, class T1, ::cuda::std::__enable_if_t<::cuda::std::is_arithmetic<T0>::value, int> = 0>
__host__ __device__ complex<typename detail::promoted_numerical_type<T0, T1>::type>
pow(const T0& x, const complex<T1>& y) {
return static_cast<complex<typename detail::promoted_numerical_type<T0, T1>::type>>(::cuda::std::pow(x, y));
}
template<class T>
__host__ __device__ complex<T> sqrt(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::sqrt(c));
}

template<class T>
__host__ __device__ complex<T> acos(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::acos(c));
Expand Down
Loading