Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 66 additions & 19 deletions rocprim/include/rocprim/detail/match_result_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,38 +24,85 @@
#include <type_traits>

#include "../config.hpp"
#include "../types/tuple.hpp"

BEGIN_ROCPRIM_NAMESPACE
namespace detail
{

// Primary template of tuple_contains_type. It yields false for every (T, Tuple)
// pair that no partial specialization matches: either Tuple is not a
// rocprim::tuple<>, or it is a rocprim::tuple<> with no element of type T.
template<class T, class Tuple>
struct tuple_contains_type : std::integral_constant<bool, false>
{
};
// invoke_result is based on https://en.cppreference.com/w/cpp/types/result_of
// The main difference is the use of ROCPRIM_HOST_DEVICE, which allows
// invoke_result to be used with device-only lambdas/functors in host-only
// functions on HIP-clang.

// Type trait detecting std::reference_wrapper: value is true only when the
// argument is exactly std::reference_wrapper<U> for some U; every other type
// (including cv- or reference-qualified wrappers) falls back to false.
template<class Candidate>
struct is_reference_wrapper : std::integral_constant<bool, false>
{
};

template<class Wrapped>
struct is_reference_wrapper<std::reference_wrapper<Wrapped>> : std::integral_constant<bool, true>
{
};

// An empty rocprim::tuple<> has no elements, so it never contains T.
template<class T>
struct tuple_contains_type<T, ::rocprim::tuple<>> : std::false_type {};

// Primary invoke_impl template. It is used for every decayed callable type T that
// is not a pointer to member (those are handled by the invoke_impl<MT B::*>
// specialization): the callable is applied directly to the forwarded arguments.
//
// call() is declared but never defined; it exists only so that its trailing
// return type can be inspected in unevaluated contexts (decltype).
// ROCPRIM_HOST_DEVICE is applied so the declaration is usable with device-only
// callables as well as host ones.
template<class T>
struct invoke_impl
{
    template<class F, class... Args>
    ROCPRIM_HOST_DEVICE
    static auto call(F&& f, Args&&... args)
        -> decltype(std::forward<F>(f)(std::forward<Args>(args)...));
};

// Recursive step of tuple_contains_type: the first tuple element U did not match
// T via a more specialized partial specialization, so the answer is whatever the
// remaining elements Ts... give.
template<class T, class U, class... Ts>
struct tuple_contains_type<T, ::rocprim::tuple<U, Ts...>> : tuple_contains_type<T, ::rocprim::tuple<Ts...>> {};

// invoke_impl specialization for pointers to members of class B (member
// functions and member data). All members are declaration-only and are meant to
// be used inside decltype to compute the type of the invocation.
template<class B, class MT>
struct invoke_impl<MT B::*>
{
    // get(): produces the object expression the member pointer is applied to.
    // Exactly one of the three overloads is enabled for a given T.

    // Case 1: the decayed argument is B or derives from B -> use the object itself.
    template<class T, class Td = typename std::decay<T>::type,
             class = typename std::enable_if<std::is_base_of<B, Td>::value>::type
    >
    ROCPRIM_HOST_DEVICE
    static auto get(T&& t) -> T&&;

    // Case 2: the decayed argument is a std::reference_wrapper -> unwrap it via get().
    template<class T, class Td = typename std::decay<T>::type,
             class = typename std::enable_if<is_reference_wrapper<Td>::value>::type
    >
    ROCPRIM_HOST_DEVICE
    static auto get(T&& t) -> decltype(t.get());

    // Case 3: anything else (neither B-derived nor a reference_wrapper) is
    // dereferenced with unary operator*.
    template<class T, class Td = typename std::decay<T>::type,
             class = typename std::enable_if<!std::is_base_of<B, Td>::value>::type,
             class = typename std::enable_if<!is_reference_wrapper<Td>::value>::type
    >
    ROCPRIM_HOST_DEVICE
    static auto get(T&& t) -> decltype(*std::forward<T>(t));

    // Pointer to member function: enabled only when the pointee MT1 is a function
    // type. The result type is that of calling the member on get(t) with args.
    template<class T, class... Args, class MT1,
             class = typename std::enable_if<std::is_function<MT1>::value>::type
    >
    ROCPRIM_HOST_DEVICE
    static auto call(MT1 B::*pmf, T&& t, Args&&... args)
        -> decltype((invoke_impl::get(std::forward<T>(t)).*pmf)(std::forward<Args>(args)...));

    // Pointer to data member: takes the object only; the result type is that of
    // accessing the member on get(t).
    template<class T>
    ROCPRIM_HOST_DEVICE
    static auto call(MT B::*pmd, T&& t)
        -> decltype(invoke_impl::get(std::forward<T>(t)).*pmd);
};

// Declaration-only dispatcher used to compute the type of an invocation.
// The callable's decayed type selects the invoke_impl variant, and the return
// type is whatever that variant's call() yields for the forwarded arguments.
// When no viable call() exists, substitution fails and this overload drops out.
template<class Callable,
         class... CallArgs,
         class DecayedCallable = typename std::decay<Callable>::type>
ROCPRIM_HOST_DEVICE
auto INVOKE(Callable&& f, CallArgs&&... args)
    -> decltype(invoke_impl<DecayedCallable>::call(std::forward<Callable>(f),
                                                   std::forward<CallArgs>(args)...));

// Conforming C++14 implementation (is also a valid C++11 implementation):
//
// Fallback: chosen when the INVOKE expression below is ill-formed. It has no
// nested `type`, so asking for it is a substitution failure, not a hard error.
template<typename AlwaysVoid, typename, typename...>
struct invoke_result_impl
{
};

// Chosen when INVOKE(declval<Callable>(), declval<CallArgs>()...) is well-formed:
// the first argument then evaluates to void, matching the `void` supplied by
// invoke_result, and `type` is the type of that INVOKE expression.
template<typename Callable, typename... CallArgs>
struct invoke_result_impl<
    decltype(void(INVOKE(std::declval<Callable>(), std::declval<CallArgs>()...))),
    Callable,
    CallArgs...>
{
    using type = decltype(INVOKE(std::declval<Callable>(), std::declval<CallArgs>()...));
};

// Match case of tuple_contains_type: the first element of the rocprim::tuple<>
// is T itself, so the tuple contains T regardless of the remaining elements.
template<class T, class... Rest>
struct tuple_contains_type<T, ::rocprim::tuple<T, Rest...>> : std::integral_constant<bool, true>
{
};
// Public entry point of the trait: invoke_result<F, Args...>::type is the type
// of invoking F with Args.... The nested `type` is inherited from
// invoke_result_impl and is absent when the invocation is ill-formed.
template<class Callable, class... CallArgs>
struct invoke_result : invoke_result_impl<void, Callable, CallArgs...>
{
};

// match_result_type<InputType, BinaryFunction>::type is the type produced by
// invoking BinaryFunction with two arguments of type InputType.
//
// The type is computed with detail::invoke_result rather than
// std::invoke_result / std::result_of: the detail version declares its helpers
// ROCPRIM_HOST_DEVICE, so it also works for device-only lambdas/functors when
// evaluated from host code on HIP-clang. `type` is declared exactly once.
template<class InputType, class BinaryFunction>
struct match_result_type
{
    using type = typename invoke_result<BinaryFunction, InputType, InputType>::type;
};

} // end namespace detail
Expand Down
7 changes: 2 additions & 5 deletions rocprim/include/rocprim/device/detail/device_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../detail/match_result_type.hpp"

#include "../../intrinsics.hpp"
#include "../../functional.hpp"
Expand All @@ -44,11 +45,7 @@ namespace detail
template<class T1, class T2, class BinaryFunction>
struct unpack_binary_op
{
#ifdef __cpp_lib_is_invocable
using result_type = typename std::invoke_result<BinaryFunction, T1, T2>::type;
#else
using result_type = typename std::result_of<BinaryFunction(T1, T2)>::type;
#endif
using result_type = typename ::rocprim::detail::invoke_result<BinaryFunction, T1, T2>::type;

ROCPRIM_HOST_DEVICE inline
unpack_binary_op() = default;
Expand Down
5 changes: 0 additions & 5 deletions rocprim/include/rocprim/device/device_binary_search.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,7 @@ hipError_t binary_search(void * temporary_storage,
needles, output,
needles_size,
[haystack, haystack_size, search_op, compare_op]
#ifdef __HIP__
// Workaround: hip-clang does not support std::result_of of device-only functions
ROCPRIM_HOST_DEVICE
#else
ROCPRIM_DEVICE
#endif
(const value_type& value)
{
return search_op(haystack, haystack_size, value, compare_op);
Expand Down
5 changes: 0 additions & 5 deletions rocprim/include/rocprim/device/device_scan_by_key.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,7 @@ hipError_t inclusive_scan_by_key(void * temporary_storage,
rocprim::make_transform_iterator(
rocprim::make_counting_iterator<size_t>(0),
[values_input, keys_input, key_compare_op]
#ifdef __HIP__
// Workaround: hip-clang does not support std::result_of of device-only functions
ROCPRIM_HOST_DEVICE
#else
ROCPRIM_DEVICE
#endif
(const size_t i)
{
flag_type flag(true);
Expand Down
5 changes: 0 additions & 5 deletions rocprim/include/rocprim/device/device_segmented_scan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,12 +614,7 @@ hipError_t segmented_exclusive_scan(void * temporary_storage,
rocprim::make_transform_iterator(
rocprim::make_counting_iterator<size_t>(0),
[input, head_flags, initial_value_converted, size]
#ifdef __HIP__
// Workaround: hip-clang does not support std::result_of of device-only functions
ROCPRIM_HOST_DEVICE
#else
ROCPRIM_DEVICE
#endif
(const size_t i)
{
flag_type flag(false);
Expand Down
7 changes: 2 additions & 5 deletions rocprim/include/rocprim/device/device_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include "../config.hpp"
#include "../detail/various.hpp"
#include "../detail/match_result_type.hpp"
#include "../types/tuple.hpp"
#include "../iterator/zip_iterator.hpp"

Expand Down Expand Up @@ -145,11 +146,7 @@ hipError_t transform(InputIterator input,
bool debug_synchronous = false)
{
using input_type = typename std::iterator_traits<InputIterator>::value_type;
#ifdef __cpp_lib_is_invocable
using result_type = typename std::invoke_result<UnaryFunction, input_type>::type;
#else
using result_type = typename std::result_of<UnaryFunction(input_type)>::type;
#endif
using result_type = typename ::rocprim::detail::invoke_result<UnaryFunction, input_type>::type;

// Get default config if Config is default_config
using config = detail::default_or_custom_config<
Expand Down
10 changes: 2 additions & 8 deletions rocprim/include/rocprim/iterator/transform_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <type_traits>

#include "../config.hpp"
#include "../detail/match_result_type.hpp"

/// \addtogroup iteratormodule
/// @{
Expand All @@ -49,17 +50,10 @@ BEGIN_ROCPRIM_NAMESPACE
template<
class InputIterator,
class UnaryFunction,
#if defined(__cpp_lib_is_invocable) && !defined(DOXYGEN_SHOULD_SKIP_THIS) // C++17
class ValueType =
typename std::invoke_result<
typename ::rocprim::detail::invoke_result<
UnaryFunction, typename std::iterator_traits<InputIterator>::value_type
>::type
#else
class ValueType =
typename std::result_of<
UnaryFunction(typename std::iterator_traits<InputIterator>::value_type)
>::type
#endif
>
class transform_iterator
{
Expand Down