Skip to content

Commit 1f6e4bc

Browse files
Blonckmiscco
andauthored
Refactor thrust::complex as a struct derived from cuda::std::complex (NVIDIA#454)
* Replace `thrust::complex` with `std::complex` There are some notable differences though. thrust::complex has been a bit more lenient when determining the type of arithmetic operations. That said, I believe being more strict is actually a feature not a bug * Refactor thrust::complex as a struct derived from cuda::std::complex This commit refactors the thrust::complex type to be a struct derived from cuda::std::complex, enabling reuse of existing implementation logic. However, to maintain backward compatibility, certain operators are reintroduced to allow type promotion between 'float' and 'double' for the underlying type. * Make `thrust::complex` compile * Remove obsolete test * Fix complex build for gcc-12 * Use template evaluation short circuit for `complex` assignment operator. * [skip-tests] Update the license after complete reimplementation of complex --------- Co-authored-by: Michael Schellenberger Costa <[email protected]>
1 parent 69af06d commit 1f6e4bc

File tree

29 files changed

+552
-5549
lines changed

29 files changed

+552
-5549
lines changed

libcudacxx/.upstream-tests/test/std/numerics/complex.number/complex/types.fail.cpp

-41
This file was deleted.

libcudacxx/include/cuda/std/detail/libcxx/include/complex

+65-15
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,6 @@ template<class T> complex<T> tanh (const complex<T>&);
235235
#ifndef __cuda_std__
236236
#include <__config>
237237
#include <stdexcept>
238-
#if !defined(_LIBCUDACXX_HAS_NO_LOCALIZATION)
239-
# include <sstream> // for _CUDA_VSTD::basic_ostringstream
240-
#endif
241238
#endif // __cuda_std__
242239

243240
#include "__assert" // all public C++ headers provide the assertion handler
@@ -253,6 +250,11 @@ template<class T> complex<T> tanh (const complex<T>&);
253250
#include "type_traits"
254251
#include "version"
255252

253+
#if !defined(_LIBCUDACXX_HAS_NO_LOCALIZATION) \
254+
&& !defined(_LIBCUDACXX_COMPILER_NVRTC)
255+
#include <sstream> // for std::basic_ostringstream
256+
#endif // !_LIBCUDACXX_HAS_NO_LOCALIZATION && !_LIBCUDACXX_COMPILER_NVRTC
257+
256258
// Compatability helpers for thrust to convert between `std::complex` and `cuda::std::complex`
257259
#if defined(__cuda_std__) && !defined(_LIBCUDACXX_COMPILER_NVRTC) && !defined(_LIBCUDACXX_COMPILER_MSVC)
258260
#include <complex>
@@ -407,8 +409,10 @@ public:
407409
: __re_(__re), __im_(__im) {}
408410
_LIBCUDACXX_INLINE_VISIBILITY
409411
explicit constexpr complex(const complex<double>& __c);
412+
#ifdef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
410413
_LIBCUDACXX_INLINE_VISIBILITY
411414
explicit constexpr complex(const complex<long double>& __c);
415+
#endif // _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
412416

413417
#if defined(__cuda_std__) && !defined(_LIBCUDACXX_COMPILER_NVRTC) && !defined(_LIBCUDACXX_COMPILER_MSVC)
414418
template <class _Up>
@@ -502,8 +506,11 @@ public:
502506
: __re_(__re), __im_(__im) {}
503507
_LIBCUDACXX_INLINE_VISIBILITY
504508
constexpr complex(const complex<float>& __c);
509+
510+
#ifdef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
505511
_LIBCUDACXX_INLINE_VISIBILITY
506512
explicit constexpr complex(const complex<long double>& __c);
513+
#endif //_LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
507514

508515
#if defined(__cuda_std__) && !defined(_LIBCUDACXX_COMPILER_NVRTC) && !defined(_LIBCUDACXX_COMPILER_MSVC)
509516
template <class _Up>
@@ -585,20 +592,10 @@ public:
585592
}
586593
};
587594

595+
#ifdef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
588596
template<>
589597
class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(2*sizeof(long double)) complex<long double>
590598
{
591-
#ifndef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
592-
public:
593-
template <typename _Dummy = void>
594-
_LIBCUDACXX_INLINE_VISIBILITY constexpr complex(long double __re = 0.0, long double __im = 0.0)
595-
{static_assert(is_same<_Dummy, void>::value, "complex<long double> is not supported");}
596-
597-
template <typename _Tp, typename _Dummy = void>
598-
_LIBCUDACXX_INLINE_VISIBILITY constexpr complex(const complex<_Tp> &__c)
599-
{static_assert(is_same<_Dummy, void>::value, "complex<long double> is not supported");}
600-
601-
#else
602599
long double __re_;
603600
long double __im_;
604601
public:
@@ -689,8 +686,8 @@ public:
689686
*this = *this / complex(__c.real(), __c.imag());
690687
return *this;
691688
}
692-
#endif // _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
693689
};
690+
#endif // _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
694691

695692
#if defined(_LIBCUDACXX_USE_PRAGMA_MSVC_WARNING)
696693
// MSVC complains about narrowing conversions on these copy constructors regardless if they are used
@@ -1191,6 +1188,7 @@ arg(const complex<_Tp>& __c)
11911188
return _CUDA_VSTD::atan2(__c.imag(), __c.real());
11921189
}
11931190

1191+
#ifdef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
11941192
template <class _Tp>
11951193
inline _LIBCUDACXX_INLINE_VISIBILITY
11961194
__enable_if_t<
@@ -1201,6 +1199,7 @@ arg(_Tp __re)
12011199
{
12021200
return _CUDA_VSTD::atan2l(0.L, __re);
12031201
}
1202+
#endif // _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
12041203

12051204
template<class _Tp>
12061205
inline _LIBCUDACXX_INLINE_VISIBILITY
@@ -1775,6 +1774,57 @@ operator<<(basic_ostream<_CharT, _Traits>& __os, const complex<_Tp>& __x)
17751774
return __os << __s.str();
17761775
}
17771776
#endif // !_LIBCUDACXX_HAS_NO_LOCALIZATION
1777+
#else // ^^^ !__cuda_std__ ^^^ / vvv __cuda_std__
1778+
#ifndef _LIBCUDACXX_COMPILER_NVRTC
1779+
template<typename ValueType,class charT, class traits>
1780+
::std::basic_ostream<charT, traits>& operator<<(::std::basic_ostream<charT, traits>& os, const complex<ValueType>& z)
1781+
{
1782+
os << '(' << z.real() << ',' << z.imag() << ')';
1783+
return os;
1784+
}
1785+
1786+
template<typename ValueType, typename charT, class traits>
1787+
::std::basic_istream<charT, traits>&
1788+
operator>>(::std::basic_istream<charT, traits>& is, complex<ValueType>& z)
1789+
{
1790+
ValueType re, im;
1791+
1792+
charT ch;
1793+
is >> ch;
1794+
1795+
if(ch == '(')
1796+
{
1797+
is >> re >> ch;
1798+
if (ch == ',')
1799+
{
1800+
is >> im >> ch;
1801+
if (ch == ')')
1802+
{
1803+
z = complex<ValueType>(re, im);
1804+
}
1805+
else
1806+
{
1807+
is.setstate(::std::ios_base::failbit);
1808+
}
1809+
}
1810+
else if (ch == ')')
1811+
{
1812+
z = re;
1813+
}
1814+
else
1815+
{
1816+
is.setstate(::std::ios_base::failbit);
1817+
}
1818+
}
1819+
else
1820+
{
1821+
is.putback(ch);
1822+
is >> re;
1823+
z = re;
1824+
}
1825+
return is;
1826+
}
1827+
#endif // _LIBCUDACXX_COMPILER_NVRTC
17781828
#endif // __cuda_std__
17791829

17801830
#if _LIBCUDACXX_STD_VER > 11 && defined(_LIBCUDACXX_HAS_STL_LITERALS)

thrust/testing/complex.cu

+12
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ struct TestComplexSizeAndAlignment
5454
};
5555
SimpleUnitTest<TestComplexSizeAndAlignment, FloatingPointTypes> TestComplexSizeAndAlignmentInstance;
5656

57+
template <typename T>
58+
struct TestComplexTypeCheck
59+
{
60+
void operator()()
61+
{
62+
THRUST_STATIC_ASSERT(thrust::is_complex<thrust::complex<T>>::value);
63+
THRUST_STATIC_ASSERT(thrust::is_complex<std::complex<T>>::value);
64+
THRUST_STATIC_ASSERT(thrust::is_complex<cuda::std::complex<T>>::value);
65+
}
66+
};
67+
SimpleUnitTest<TestComplexTypeCheck, FloatingPointTypes> TestComplexTypeCheckInstance;
68+
5769
template <typename T>
5870
struct TestComplexConstructionAndAssignment
5971
{

thrust/testing/unittest/assertions.h

+6-23
Original file line numberDiff line numberDiff line change
@@ -304,31 +304,14 @@ bool almost_equal(double a, double b, double a_tol, double r_tol)
304304
return true;
305305
}
306306

307-
namespace
308-
{ // anonymous namespace
309-
310-
template <typename>
311-
struct is_complex : public THRUST_NS_QUALIFIER::false_type
312-
{};
313-
314-
template <typename T>
315-
struct is_complex<THRUST_NS_QUALIFIER::complex<T>> : public THRUST_NS_QUALIFIER::true_type
316-
{};
317-
318-
template <typename T>
319-
struct is_complex<std::complex<T>> : public THRUST_NS_QUALIFIER::true_type
320-
{};
321-
322-
} // namespace
323-
324307
template <typename T1, typename T2>
325-
inline
326-
typename THRUST_NS_QUALIFIER::detail::enable_if<is_complex<T1>::value && is_complex<T2>::value,
327-
bool>::type
328-
almost_equal(const T1 &a, const T2 &b, double a_tol, double r_tol)
308+
typename THRUST_NS_QUALIFIER::detail::enable_if<THRUST_NS_QUALIFIER::is_complex<T1>::value &&
309+
THRUST_NS_QUALIFIER::is_complex<T2>::value,
310+
bool>::type
311+
almost_equal(const T1 &a, const T2 &b, double a_tol, double r_tol)
329312
{
330-
return almost_equal(a.real(), b.real(), a_tol, r_tol) &&
331-
almost_equal(a.imag(), b.imag(), a_tol, r_tol);
313+
return almost_equal(a.real(), b.real(), a_tol, r_tol) &&
314+
almost_equal(a.imag(), b.imag(), a_tol, r_tol);
332315
}
333316

334317
template <typename T1, typename T2>

0 commit comments

Comments
 (0)