Skip to content

Commit

Permalink
Implement math functions for thrust::complex (#1178) (#1191)
Browse files Browse the repository at this point in the history
* Implement math functions for `thrust::complex`

We are having issues that the `cuda::std` math functions that take a `cuda::std::complex` return a `cuda::std::complex`. This can lead to issues as we require a conversion sequence from `cuda::std::complex` to `thrust::complex` which e.g is broken by an constructor being explicit.

Addresses nvbug4397241
  • Loading branch information
miscco authored Dec 8, 2023
1 parent 265d985 commit 665b376
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 1 deletion.
18 changes: 18 additions & 0 deletions thrust/testing/complex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -679,3 +679,21 @@ struct TestComplexStdComplexDeviceInterop
SimpleUnitTest<TestComplexStdComplexDeviceInterop, FloatingPointTypes>
TestComplexStdComplexDeviceInteropInstance;
#endif

template <typename T>
struct TestComplexExplicitConstruction
{
struct user_complex {
__host__ __device__ user_complex(T, T) {}
__host__ __device__ user_complex(const thrust::complex<T>&) {}
};

void operator()()
{
const thrust::complex<T> input(42.0, 1337.0);
const user_complex result = thrust::exp(input);
(void)result;
}
};
SimpleUnitTest<TestComplexExplicitConstruction, FloatingPointTypes>
TestComplexExplicitConstructionInstance;
89 changes: 88 additions & 1 deletion thrust/thrust/complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,14 @@ operator/(const T0 &x, const complex<T1> &y)

// The using declarations allows imports all necessary functions for thurst::complex.
// However, they also lead to thrust::abs(1.0F) being valid code after include <thurst/complex.h>.
// We are importing those for the plain value taking overloads and specialize for those taking
// or returning a `thrust::complex` below
using ::cuda::std::abs;
using ::cuda::std::arg;
using ::cuda::std::conj;
using ::cuda::std::norm;
using ::cuda::std::polar;
// polar only takes a T but returns a complex<T> so we cannot pull that one in.
// using ::cuda::std::polar;
using ::cuda::std::proj;

using ::cuda::std::exp;
Expand All @@ -487,6 +490,90 @@ using ::cuda::std::sinh;
using ::cuda::std::tan;
using ::cuda::std::tanh;

// Those functions return `cuda::std::complex<T>` so we must provide an explicit overload that returns `thrust::complex<T>`
template<class T>
__host__ __device__ complex<T> conj(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::conj(c));
}
template<class T>
__host__ __device__ complex<T> polar(const T& rho, const T& theta = T{}) {
return static_cast<complex<T>>(::cuda::std::polar(rho, theta));
}
template<class T>
__host__ __device__ complex<T> proj(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::proj(c));
}

template<class T>
__host__ __device__ complex<T> exp(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::exp(c));
}
template<class T>
__host__ __device__ complex<T> log(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::log(c));
}
template<class T>
__host__ __device__ complex<T> log10(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::log10(c));
}
template<class T>
__host__ __device__ complex<T> pow(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::pow(c));
}
template<class T>
__host__ __device__ complex<T> sqrt(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::sqrt(c));
}

template<class T>
__host__ __device__ complex<T> acos(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::acos(c));
}
template<class T>
__host__ __device__ complex<T> acosh(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::acosh(c));
}
template<class T>
__host__ __device__ complex<T> asin(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::asin(c));
}
template<class T>
__host__ __device__ complex<T> asinh(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::asinh(c));
}
template<class T>
__host__ __device__ complex<T> atan(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::atan(c));
}
template<class T>
__host__ __device__ complex<T> atanh(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::atanh(c));
}
template<class T>
__host__ __device__ complex<T> cos(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::cos(c));
}
template<class T>
__host__ __device__ complex<T> cosh(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::cosh(c));
}
template<class T>
__host__ __device__ complex<T> sin(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::sin(c));
}
template<class T>
__host__ __device__ complex<T> sinh(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::sinh(c));
}
template<class T>
__host__ __device__ complex<T> tan(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::tan(c));
}
template<class T>
__host__ __device__ complex<T> tanh(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::tanh(c));
}

template <typename T>
struct proclaim_trivially_relocatable<complex<T>> : thrust::true_type
{};
Expand Down

0 comments on commit 665b376

Please sign in to comment.