Skip to content

Commit ed4c92c

Browse files
authored
[FFI] Introduce GlobalDef for function registration (#18111)
This PR introduces reflection::GlobalDef for function registration, which makes the global function registration API more closely aligned with the new reflection style. We will send followup PRs to transition some of the existing mechanisms to the new one.
1 parent 6620fe2 commit ed4c92c

File tree

8 files changed

+239
-124
lines changed

8 files changed

+239
-124
lines changed

ffi/include/tvm/ffi/reflection/reflection.h

Lines changed: 125 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,20 +118,47 @@ class ReflectionDefBase {
118118
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
119119
}
120120
}
121+
121122
template <typename Class, typename R, typename... Args>
122123
static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...)) {
123-
auto fwrap = [func](const Class* target, Args... params) -> R {
124-
return (const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
125-
};
126-
return ffi::Function::FromTyped(fwrap, name);
124+
static_assert(std::is_base_of_v<ObjectRef, Class> || std::is_base_of_v<Object, Class>,
125+
"Class must be derived from ObjectRef or Object");
126+
if constexpr (std::is_base_of_v<ObjectRef, Class>) {
127+
auto fwrap = [func](Class target, Args... params) -> R {
128+
// call method pointer
129+
return (target.*func)(std::forward<Args>(params)...);
130+
};
131+
return ffi::Function::FromTyped(fwrap, name);
132+
}
133+
134+
if constexpr (std::is_base_of_v<Object, Class>) {
135+
auto fwrap = [func](const Class* target, Args... params) -> R {
136+
// call method pointer
137+
return (const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
138+
};
139+
return ffi::Function::FromTyped(fwrap, name);
140+
}
127141
}
128142

129143
template <typename Class, typename R, typename... Args>
130144
static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...) const) {
131-
auto fwrap = [func](const Class* target, Args... params) -> R {
132-
return (target->*func)(std::forward<Args>(params)...);
133-
};
134-
return ffi::Function::FromTyped(fwrap, name);
145+
static_assert(std::is_base_of_v<ObjectRef, Class> || std::is_base_of_v<Object, Class>,
146+
"Class must be derived from ObjectRef or Object");
147+
if constexpr (std::is_base_of_v<ObjectRef, Class>) {
148+
auto fwrap = [func](const Class target, Args... params) -> R {
149+
// call method pointer
150+
return (target.*func)(std::forward<Args>(params)...);
151+
};
152+
return ffi::Function::FromTyped(fwrap, name);
153+
}
154+
155+
if constexpr (std::is_base_of_v<Object, Class>) {
156+
auto fwrap = [func](const Class* target, Args... params) -> R {
157+
// call method pointer
158+
return (target->*func)(std::forward<Args>(params)...);
159+
};
160+
return ffi::Function::FromTyped(fwrap, name);
161+
}
135162
}
136163

137164
template <typename Class, typename Func>
@@ -140,6 +167,96 @@ class ReflectionDefBase {
140167
}
141168
};
142169

170+
class GlobalDef : public ReflectionDefBase {
171+
public:
172+
/*
173+
* \brief Define a global function.
174+
*
175+
* \tparam Func The function type.
176+
* \tparam Extra The extra arguments.
177+
*
178+
* \param name The name of the function.
179+
* \param func The function to be registered.
180+
* \param extra The extra arguments that can be docstring.
181+
*
182+
* \return The reflection definition.
183+
*/
184+
template <typename Func, typename... Extra>
185+
GlobalDef& def(const char* name, Func&& func, Extra&&... extra) {
186+
RegisterFunc(name, ffi::Function::FromTyped(std::forward<Func>(func), std::string(name)),
187+
std::forward<Extra>(extra)...);
188+
return *this;
189+
}
190+
191+
/*
192+
* \brief Define a global function in ffi::PackedArgs format.
193+
*
194+
* \tparam Func The function type.
195+
* \tparam Extra The extra arguments.
196+
*
197+
* \param name The name of the function.
198+
* \param func The function to be registered.
199+
* \param extra The extra arguments that can be docstring.
200+
*
201+
* \return The reflection definition.
202+
*/
203+
template <typename Func, typename... Extra>
204+
GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) {
205+
RegisterFunc(name, ffi::Function::FromPacked(func), std::forward<Extra>(extra)...);
206+
return *this;
207+
}
208+
209+
/*
210+
* \brief Expose a class method as a global function.
211+
*
212+
* An argument will be added to the first position if the function is not static.
213+
*
214+
* \tparam Class The class type.
215+
* \tparam Func The function type.
216+
*
217+
* \param name The name of the method.
218+
* \param func The function to be registered.
219+
*
220+
* \return The reflection definition.
221+
*/
222+
template <typename Func, typename... Extra>
223+
GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) {
224+
RegisterFunc(name, GetMethod_(std::string(name), std::forward<Func>(func)),
225+
std::forward<Extra>(extra)...);
226+
return *this;
227+
}
228+
229+
private:
230+
template <typename Func>
231+
static TVM_FFI_INLINE Function GetMethod_(std::string name, Func&& func) {
232+
return ffi::Function::FromTyped(std::forward<Func>(func), name);
233+
}
234+
235+
template <typename Class, typename R, typename... Args>
236+
static TVM_FFI_INLINE Function GetMethod_(std::string name, R (Class::*func)(Args...) const) {
237+
return GetMethod<Class>(std::string(name), func);
238+
}
239+
240+
template <typename Class, typename R, typename... Args>
241+
static TVM_FFI_INLINE Function GetMethod_(std::string name, R (Class::*func)(Args...)) {
242+
return GetMethod<Class>(std::string(name), func);
243+
}
244+
245+
template <typename... Extra>
246+
void RegisterFunc(const char* name, ffi::Function func, Extra&&... extra) {
247+
TVMFFIMethodInfo info;
248+
info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
249+
info.doc = TVMFFIByteArray{nullptr, 0};
250+
info.type_schema = TVMFFIByteArray{nullptr, 0};
251+
info.flags = 0;
252+
// obtain the method function
253+
info.method = AnyView(func).CopyToTVMFFIAny();
254+
// apply method info traits
255+
((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
256+
TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionSetGlobalFromMethodInfo(&info, 0));
257+
}
258+
};
259+
143260
template <typename Class>
144261
class ObjectDef : public ReflectionDefBase {
145262
public:

ffi/src/ffi/container.cc

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -25,40 +25,11 @@
2525
#include <tvm/ffi/container/map.h>
2626
#include <tvm/ffi/container/shape.h>
2727
#include <tvm/ffi/function.h>
28+
#include <tvm/ffi/reflection/reflection.h>
2829

2930
namespace tvm {
3031
namespace ffi {
3132

32-
TVM_FFI_REGISTER_GLOBAL("ffi.Array").set_body_packed([](ffi::PackedArgs args, Any* ret) {
33-
*ret = Array<Any>(args.data(), args.data() + args.size());
34-
});
35-
36-
TVM_FFI_REGISTER_GLOBAL("ffi.ArrayGetItem")
37-
.set_body_typed([](const ffi::ArrayObj* n, int64_t i) -> Any { return n->at(i); });
38-
39-
TVM_FFI_REGISTER_GLOBAL("ffi.ArraySize").set_body_typed([](const ffi::ArrayObj* n) -> int64_t {
40-
return static_cast<int64_t>(n->size());
41-
});
42-
// Map
43-
TVM_FFI_REGISTER_GLOBAL("ffi.Map").set_body_packed([](ffi::PackedArgs args, Any* ret) {
44-
TVM_FFI_ICHECK_EQ(args.size() % 2, 0);
45-
Map<Any, Any> data;
46-
for (int i = 0; i < args.size(); i += 2) {
47-
data.Set(args[i], args[i + 1]);
48-
}
49-
*ret = data;
50-
});
51-
52-
TVM_FFI_REGISTER_GLOBAL("ffi.MapSize").set_body_typed([](const ffi::MapObj* n) -> int64_t {
53-
return static_cast<int64_t>(n->size());
54-
});
55-
56-
TVM_FFI_REGISTER_GLOBAL("ffi.MapGetItem")
57-
.set_body_typed([](const ffi::MapObj* n, const Any& k) -> Any { return n->at(k); });
58-
59-
TVM_FFI_REGISTER_GLOBAL("ffi.MapCount")
60-
.set_body_typed([](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); });
61-
6233
// Favor struct outside function scope as MSVC may have bug for in fn scope struct.
6334
class MapForwardIterFunctor {
6435
public:
@@ -86,10 +57,33 @@ class MapForwardIterFunctor {
8657
ffi::MapObj::iterator end_;
8758
};
8859

89-
TVM_FFI_REGISTER_GLOBAL("ffi.MapForwardIterFunctor")
90-
.set_body_typed([](const ffi::MapObj* n) -> ffi::Function {
91-
return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), n->end()));
92-
});
93-
60+
TVM_FFI_STATIC_INIT_BLOCK({
61+
namespace refl = tvm::ffi::reflection;
62+
refl::GlobalDef()
63+
.def_packed("ffi.Array",
64+
[](ffi::PackedArgs args, Any* ret) {
65+
*ret = Array<Any>(args.data(), args.data() + args.size());
66+
})
67+
.def("ffi.ArrayGetItem", [](const ffi::ArrayObj* n, int64_t i) -> Any { return n->at(i); })
68+
.def("ffi.ArraySize",
69+
[](const ffi::ArrayObj* n) -> int64_t { return static_cast<int64_t>(n->size()); })
70+
.def_packed("ffi.Map",
71+
[](ffi::PackedArgs args, Any* ret) {
72+
TVM_FFI_ICHECK_EQ(args.size() % 2, 0);
73+
Map<Any, Any> data;
74+
for (int i = 0; i < args.size(); i += 2) {
75+
data.Set(args[i], args[i + 1]);
76+
}
77+
*ret = data;
78+
})
79+
.def("ffi.MapSize",
80+
[](const ffi::MapObj* n) -> int64_t { return static_cast<int64_t>(n->size()); })
81+
.def("ffi.MapGetItem", [](const ffi::MapObj* n, const Any& k) -> Any { return n->at(k); })
82+
.def("ffi.MapCount",
83+
[](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); })
84+
.def("ffi.MapForwardIterFunctor", [](const ffi::MapObj* n) -> ffi::Function {
85+
return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), n->end()));
86+
});
87+
});
9488
} // namespace ffi
9589
} // namespace tvm

ffi/src/ffi/function.cc

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/ffi/error.h>
2929
#include <tvm/ffi/function.h>
3030
#include <tvm/ffi/memory.h>
31+
#include <tvm/ffi/reflection/reflection.h>
3132
#include <tvm/ffi/string.h>
3233

3334
namespace tvm {
@@ -307,31 +308,30 @@ int TVMFFIEnvRegisterCAPI(const TVMFFIByteArray* name, void* symbol) {
307308
TVM_FFI_SAFE_CALL_END();
308309
}
309310

310-
TVM_FFI_REGISTER_GLOBAL("ffi.FunctionRemoveGlobal")
311-
.set_body_typed([](const tvm::ffi::String& name) -> bool {
312-
return tvm::ffi::GlobalFunctionTable::Global()->Remove(name);
313-
});
314-
315-
TVM_FFI_REGISTER_GLOBAL("ffi.FunctionListGlobalNamesFunctor").set_body_typed([]() {
316-
// NOTE: we return functor instead of array
317-
// so list global function names do not need to depend on array
318-
// this is because list global function names usually is a core api that happens
319-
// before array ffi functions are available.
320-
tvm::ffi::Array<tvm::ffi::String> names = tvm::ffi::GlobalFunctionTable::Global()->ListNames();
321-
auto return_functor = [names](int64_t i) -> tvm::ffi::Any {
322-
if (i < 0) {
323-
return names.size();
324-
} else {
325-
return names[i];
326-
}
327-
};
328-
return tvm::ffi::Function::FromTyped(return_functor);
329-
});
330-
331-
TVM_FFI_REGISTER_GLOBAL("ffi.String").set_body_typed([](tvm::ffi::String val) -> tvm::ffi::String {
332-
return val;
333-
});
334-
335-
TVM_FFI_REGISTER_GLOBAL("ffi.Bytes").set_body_typed([](tvm::ffi::Bytes val) -> tvm::ffi::Bytes {
336-
return val;
311+
TVM_FFI_STATIC_INIT_BLOCK({
312+
namespace refl = tvm::ffi::reflection;
313+
refl::GlobalDef()
314+
.def("ffi.FunctionRemoveGlobal",
315+
[](const tvm::ffi::String& name) -> bool {
316+
return tvm::ffi::GlobalFunctionTable::Global()->Remove(name);
317+
})
318+
.def("ffi.FunctionListGlobalNamesFunctor",
319+
[]() {
320+
// NOTE: we return functor instead of array
321+
// so list global function names do not need to depend on array
322+
// this is because list global function names usually is a core api that happens
323+
// before array ffi functions are available.
324+
tvm::ffi::Array<tvm::ffi::String> names =
325+
tvm::ffi::GlobalFunctionTable::Global()->ListNames();
326+
auto return_functor = [names](int64_t i) -> tvm::ffi::Any {
327+
if (i < 0) {
328+
return names.size();
329+
} else {
330+
return names[i];
331+
}
332+
};
333+
return tvm::ffi::Function::FromTyped(return_functor);
334+
})
335+
.def("ffi.String", [](tvm::ffi::String val) -> tvm::ffi::String { return val; })
336+
.def("ffi.Bytes", [](tvm::ffi::Bytes val) -> tvm::ffi::Bytes { return val; });
337337
});

ffi/src/ffi/ndarray.cc

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,27 @@
2323
#include <tvm/ffi/c_api.h>
2424
#include <tvm/ffi/container/ndarray.h>
2525
#include <tvm/ffi/function.h>
26+
#include <tvm/ffi/reflection/reflection.h>
2627

2728
namespace tvm {
2829
namespace ffi {
2930

30-
// Shape
31-
TVM_FFI_REGISTER_GLOBAL("ffi.Shape").set_body_packed([](ffi::PackedArgs args, Any* ret) {
32-
int64_t* mutable_data;
33-
ObjectPtr<ShapeObj> shape = details::MakeEmptyShape(args.size(), &mutable_data);
34-
for (int i = 0; i < args.size(); ++i) {
35-
if (auto opt_int = args[i].try_cast<int64_t>()) {
36-
mutable_data[i] = *opt_int;
37-
} else {
38-
TVM_FFI_THROW(ValueError) << "Expect shape to take list of int arguments";
31+
TVM_FFI_STATIC_INIT_BLOCK({
32+
namespace refl = tvm::ffi::reflection;
33+
refl::GlobalDef().def_packed("ffi.Shape", [](ffi::PackedArgs args, Any* ret) {
34+
int64_t* mutable_data;
35+
ObjectPtr<ShapeObj> shape = details::MakeEmptyShape(args.size(), &mutable_data);
36+
for (int i = 0; i < args.size(); ++i) {
37+
if (auto opt_int = args[i].try_cast<int64_t>()) {
38+
mutable_data[i] = *opt_int;
39+
} else {
40+
TVM_FFI_THROW(ValueError) << "Expect shape to take list of int arguments";
41+
}
3942
}
40-
}
41-
*ret = Shape(shape);
43+
*ret = Shape(shape);
44+
});
4245
});
46+
4347
} // namespace ffi
4448
} // namespace tvm
4549

ffi/src/ffi/object.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <tvm/ffi/container/map.h>
2525
#include <tvm/ffi/error.h>
2626
#include <tvm/ffi/function.h>
27+
#include <tvm/ffi/reflection/reflection.h>
2728
#include <tvm/ffi/string.h>
2829

2930
#include <memory>
@@ -404,7 +405,10 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
404405
*ret = ObjectRef(ptr);
405406
}
406407

407-
TVM_FFI_REGISTER_GLOBAL("ffi.MakeObjectFromPackedArgs").set_body_packed(MakeObjectFromPackedArgs);
408+
TVM_FFI_STATIC_INIT_BLOCK({
409+
namespace refl = tvm::ffi::reflection;
410+
refl::GlobalDef().def_packed("ffi.MakeObjectFromPackedArgs", MakeObjectFromPackedArgs);
411+
});
408412

409413
} // namespace ffi
410414
} // namespace tvm

0 commit comments

Comments
 (0)