Skip to content

Commit

Permalink
Properly handle supported types in std::uniform_int_distribution.
Browse files Browse the repository at this point in the history
Fixes #209.
  • Loading branch information
ndryden committed Aug 18, 2023
1 parent 3c08739 commit ae62848
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion test/test_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,22 @@ inline __half operator-(__half const& lhs, __half const& rhs) {

#endif // defined(AL_HAS_ROCM) && defined(AL_HAS_NCCL)

// It turns out that, per the C++ standard, std::uniform_int_distribution
// is only defined for a subset of integer types, and using an unsupported
// type is undefined behavior. In libstdc++, everything seems to just
// work, but libc++ throws a compile error. This works around that.
namespace internal {
template <typename T> struct IsRngIntType : std::false_type {};
template <> struct IsRngIntType<short> : std::true_type {};
template <> struct IsRngIntType<unsigned short> : std::true_type {};
template <> struct IsRngIntType<int> : std::true_type {};
template <> struct IsRngIntType<unsigned int> : std::true_type {};
template <> struct IsRngIntType<long> : std::true_type {};
template <> struct IsRngIntType<unsigned long> : std::true_type {};
template <> struct IsRngIntType<long long> : std::true_type {};
template <> struct IsRngIntType<unsigned long long> : std::true_type {};
}

/** Helper for generating random data. */
template <typename T, typename Generator,
std::enable_if_t<std::is_floating_point<T>::value, bool> = true>
Expand All @@ -133,11 +149,30 @@ T gen_random_val(Generator& g) {
return rng(g);
}
template <typename T, typename Generator,
std::enable_if_t<std::is_integral<T>::value, bool> = true>
std::enable_if_t<std::is_integral<T>::value
&& internal::IsRngIntType<T>::value, bool> = true>
T gen_random_val(Generator& g) {
std::uniform_int_distribution<T> rng;
return rng(g);
}
template <typename T, typename Generator,
std::enable_if_t<std::is_integral<T>::value
&& !internal::IsRngIntType<T>::value
&& std::is_signed<T>::value, bool> = true>
T gen_random_val(Generator& g) {
static_assert(sizeof(T) <= sizeof(short), "Type too large");
std::uniform_int_distribution<short> rng;
return static_cast<T>(rng(g));
}
template <typename T, typename Generator,
std::enable_if_t<std::is_integral<T>::value
&& !internal::IsRngIntType<T>::value
&& std::is_unsigned<T>::value, bool> = true>
T gen_random_val(Generator& g) {
static_assert(sizeof(T) <= sizeof(short), "Type too large");
std::uniform_int_distribution<short> rng;
return static_cast<T>(rng(g));
}

/** Helper for generating random vectors. */
template <typename T>
Expand Down

0 comments on commit ae62848

Please sign in to comment.