diff --git a/thrust/testing/complex.cu b/thrust/testing/complex.cu index 3d5b26cbb0e..8b3c020f76c 100644 --- a/thrust/testing/complex.cu +++ b/thrust/testing/complex.cu @@ -679,3 +679,21 @@ struct TestComplexStdComplexDeviceInterop SimpleUnitTest TestComplexStdComplexDeviceInteropInstance; #endif + +template +struct TestComplexExplicitConstruction +{ + struct user_complex { + __host__ __device__ user_complex(T, T) {} + __host__ __device__ user_complex(const thrust::complex&) {} + }; + + void operator()() + { + const thrust::complex input(42.0, 1337.0); + const user_complex result = thrust::exp(input); + (void)result; + } +}; +SimpleUnitTest + TestComplexExplicitConstructionInstance; diff --git a/thrust/thrust/complex.h b/thrust/thrust/complex.h index cb296637158..3b9a1dc6e87 100644 --- a/thrust/thrust/complex.h +++ b/thrust/thrust/complex.h @@ -461,11 +461,14 @@ operator/(const T0 &x, const complex &y) // The using declarations allows imports all necessary functions for thurst::complex. // However, they also lead to thrust::abs(1.0F) being valid code after include . +// We are importing those for the plain value taking overloads and specialize for those taking +// or returning a `thrust::complex` below using ::cuda::std::abs; using ::cuda::std::arg; using ::cuda::std::conj; using ::cuda::std::norm; -using ::cuda::std::polar; +// polar only takes a T but returns a complex so we cannot pull that one in. +// using ::cuda::std::polar; using ::cuda::std::proj; using ::cuda::std::exp; @@ -487,6 +490,90 @@ using ::cuda::std::sinh; using ::cuda::std::tan; using ::cuda::std::tanh; +// Those functions return `cuda::std::complex` so we must provide an explicit overload that returns `thrust::complex` +template +__host__ __device__ complex conj(const complex& c) { + return static_cast>(::cuda::std::conj(c)); +} +template +__host__ __device__ complex polar(const T& rho, const T& theta = T{}) { + return static_cast>(::cuda::std::polar(rho, theta)); +} +template +__host__ __device__ complex proj(const complex& c) { + return static_cast>(::cuda::std::proj(c)); +} + +template +__host__ __device__ complex exp(const complex& c) { + return static_cast>(::cuda::std::exp(c)); +} +template +__host__ __device__ complex log(const complex& c) { + return static_cast>(::cuda::std::log(c)); +} +template +__host__ __device__ complex log10(const complex& c) { + return static_cast>(::cuda::std::log10(c)); +} +template +__host__ __device__ complex pow(const complex& c) { + return static_cast>(::cuda::std::pow(c)); +} +template +__host__ __device__ complex sqrt(const complex& c) { + return static_cast>(::cuda::std::sqrt(c)); +} + +template +__host__ __device__ complex acos(const complex& c) { + return static_cast>(::cuda::std::acos(c)); +} +template +__host__ __device__ complex acosh(const complex& c) { + return static_cast>(::cuda::std::acosh(c)); +} +template +__host__ __device__ complex asin(const complex& c) { + return static_cast>(::cuda::std::asin(c)); +} +template +__host__ __device__ complex asinh(const complex& c) { + return static_cast>(::cuda::std::asinh(c)); +} +template +__host__ __device__ complex atan(const complex& c) { + return static_cast>(::cuda::std::atan(c)); +} +template +__host__ __device__ complex atanh(const complex& c) { + return static_cast>(::cuda::std::atanh(c)); +} +template +__host__ __device__ complex cos(const complex& c) { + return static_cast>(::cuda::std::cos(c)); +} +template +__host__ __device__ complex cosh(const complex& c) { + return static_cast>(::cuda::std::cosh(c)); +} +template +__host__ __device__ complex sin(const complex& c) { + return static_cast>(::cuda::std::sin(c)); +} +template +__host__ __device__ complex sinh(const complex& c) { + return static_cast>(::cuda::std::sinh(c)); +} +template +__host__ __device__ complex tan(const complex& c) { + return static_cast>(::cuda::std::tan(c)); +} +template +__host__ __device__ complex tanh(const complex& c) { + return static_cast>(::cuda::std::tanh(c)); +} + template struct proclaim_trivially_relocatable> : thrust::true_type {};