Skip to content

Commit 5a4038f

Browse files
PragmaTwicetqchen
authored andcommitted
[FFI] Replace Arg2Str with a more powerful for_each (apache#18117)
[FFI] Replace Arg2Str with a more powerful for_each
1 parent 84cfbeb commit 5a4038f

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

include/tvm/ffi/base_details.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <tvm/ffi/endian.h>
3030

3131
#include <cstddef>
32+
#include <type_traits>
3233
#include <utility>
3334

3435
#if defined(_MSC_VER)
@@ -135,14 +136,32 @@ namespace tvm {
135136
namespace ffi {
136137
namespace details {
137138

139+
// a dependent-name version of false, for static_assert
140+
template <typename>
141+
inline constexpr bool always_false = false;
142+
138143
// for each iterator
139144
struct for_each_dispatcher {
140145
template <typename F, typename... Args, size_t... I>
141146
static void run(std::index_sequence<I...>, const F& f, Args&&... args) { // NOLINT(*)
142-
(f(I, std::forward<Args>(args)), ...);
147+
if constexpr (std::conjunction_v<
148+
std::is_invocable<F, std::integral_constant<size_t, I>, Args>...>) {
149+
(f(std::integral_constant<size_t, I>{}, std::forward<Args>(args)), ...);
150+
} else if constexpr (std::conjunction_v<std::is_invocable<F, size_t, Args>...>) {
151+
(f(I, std::forward<Args>(args)), ...);
152+
} else if constexpr (std::conjunction_v<std::is_invocable<F, Args>...>) {
153+
(f(std::forward<Args>(args)), ...);
154+
} else {
155+
static_assert(always_false<F>, "The function is not invocable with the provided arguments");
156+
}
143157
}
144158
};
145159

160+
// Three kinds of function F are acceptable in `for_each`:
161+
// 1. F(size_t, Arg): argument with its index
162+
// 2. F(Arg): just the argument
163+
// 3. F(std::integral_constant<size_t, I>, Arg): argument with its constexpr index
164+
// The third one can make the index available in template arguments and `if constexpr`.
146165
template <typename F, typename... Args>
147166
void for_each(const F& f, Args&&... args) { // NOLINT(*)
148167
for_each_dispatcher::run(std::index_sequence_for<Args...>{}, f, std::forward<Args>(args)...);

include/tvm/ffi/function_details.h

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,6 @@ namespace tvm {
3636
namespace ffi {
3737
namespace details {
3838

39-
template <typename ArgType>
40-
struct Arg2Str {
41-
template <size_t i>
42-
TVM_FFI_INLINE static void Apply(std::ostream& os) {
43-
using Arg = std::tuple_element_t<i, ArgType>;
44-
if constexpr (i != 0) {
45-
os << ", ";
46-
}
47-
os << i << ": " << Type2Str<Arg>::v();
48-
}
49-
template <size_t... I>
50-
TVM_FFI_INLINE static void Run(std::ostream& os, std::index_sequence<I...>) {
51-
using TExpander = int[];
52-
(void)TExpander{0, (Apply<I>(os), 0)...};
53-
}
54-
};
55-
5639
template <typename T>
5740
static constexpr bool ArgSupported =
5841
(std::is_same_v<std::remove_const_t<std::remove_reference_t<T>>, Any> ||
@@ -78,10 +61,16 @@ struct FuncFunctorImpl {
7861
#endif
7962

8063
TVM_FFI_INLINE static std::string Sig() {
81-
using IdxSeq = std::make_index_sequence<sizeof...(Args)>;
8264
std::ostringstream ss;
8365
ss << "(";
84-
Arg2Str<std::tuple<Args...>>::Run(ss, IdxSeq{});
66+
for_each(
67+
[&ss](auto i, const auto& v) {
68+
if constexpr (i() != 0) {
69+
ss << ", ";
70+
}
71+
ss << i() << ": " << v;
72+
},
73+
Type2Str<Args>::v()...);
8574
ss << ") -> " << Type2Str<R>::v();
8675
return ss.str();
8776
}

0 commit comments

Comments
 (0)