Skip to content

Commit 9df2290

Browse files
committed
[FFI] Provide Field Visit bridge so we can do gradual transition
This PR provides functions that adapts old VisitAttrs reflection utilities to use new reflection mechanism when available. These adapter would allow us to gradually transition the object def from old VisitAttrs based mechanism to new mechanism. - For all objects - Replace VisitAttrs with static void RegisterReflection() that registers the fields - Call T::ReflectionDef() in TVM_STATIC_INIT_BLOCK in cc file - For subclass of AttrsNode<T>: subclass AttrsNodeReflAdapter<T> instead - Do the same steps as above and replace TVM_ATTRS - Provide explicit declaration of _type_key and TVM_FFI_DEFINE_FINAL_OBJECT_INFO We will send followup PRs to do the gradual transition. Once all transition is completed, we will remove AttrsVisitor and only go through the new mechanism.
1 parent 437d00a commit 9df2290

File tree

14 files changed

+521
-67
lines changed

14 files changed

+521
-67
lines changed

ffi/include/tvm/ffi/reflection/reflection.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,8 @@ inline Function GetMethod(std::string_view type_key, const char* method_name) {
404404
*/
405405
template <typename Callback>
406406
inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) {
407+
using ResultType = decltype(callback(type_info->fields));
408+
static_assert(std::is_same_v<ResultType, void>, "Callback must return void");
407409
// iterate through acenstors in parent to child order
408410
// skip the first one since it is always the root object
409411
for (int i = 1; i < type_info->type_depth; ++i) {
@@ -417,6 +419,34 @@ inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) {
417419
}
418420
}
419421

422+
/*!
423+
* \brief Visit each field info of the type info and run callback which returns bool for early stop.
424+
*
425+
* \tparam Callback The callback function type, which returns bool for early stop.
426+
*
427+
* \param type_info The type info.
428+
* \param callback_with_early_stop The callback function.
429+
* \return true if any of early stop is triggered.
430+
*
431+
* \note This function calls both the child and parent type info and can be used for searching.
432+
*/
433+
template <typename Callback>
434+
inline bool ForEachFieldInfoWithEarlyStop(const TypeInfo* type_info,
435+
Callback callback_with_early_stop) {
436+
// iterate through acenstors in parent to child order
437+
// skip the first one since it is always the root object
438+
for (int i = 1; i < type_info->type_depth; ++i) {
439+
const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i];
440+
for (int j = 0; j < parent_info->num_fields; ++j) {
441+
if (callback_with_early_stop(parent_info->fields + j)) return true;
442+
}
443+
}
444+
for (int i = 0; i < type_info->num_fields; ++i) {
445+
if (callback_with_early_stop(type_info->fields + i)) return true;
446+
}
447+
return false;
448+
}
449+
420450
} // namespace reflection
421451
} // namespace ffi
422452
} // namespace tvm

include/tvm/ir/attrs.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include <dmlc/common.h>
4848
#include <tvm/ffi/container/map.h>
4949
#include <tvm/ffi/function.h>
50+
#include <tvm/ffi/reflection/reflection.h>
5051
#include <tvm/ir/expr.h>
5152
#include <tvm/node/structural_equal.h>
5253
#include <tvm/node/structural_hash.h>
@@ -970,5 +971,65 @@ inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*
970971
}
971972
}
972973

974+
/*!
975+
* \brief Adapter for AttrsNode with the new reflection API.
976+
*
977+
* We will phaseout the old AttrsNode in future in favor of the new reflection API.
978+
* This adapter allows us to gradually migrate to the new reflection API.
979+
*
980+
* \tparam DerivedType The final attribute type.
981+
*/
982+
template <typename DerivedType>
983+
class AttrsNodeReflAdapter : public BaseAttrsNode {
984+
public:
985+
void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final {
986+
LOG(FATAL) << "`" << DerivedType::_type_key << "` uses new reflection mechanism for init";
987+
}
988+
void VisitNonDefaultAttrs(AttrVisitor* v) final {
989+
LOG(FATAL) << "`" << DerivedType::_type_key
990+
<< "` uses new reflection mechanism for visit non default attrs";
991+
}
992+
void VisitAttrs(AttrVisitor* v) final {
993+
LOG(FATAL) << "`" << DerivedType::_type_key
994+
<< "` uses new reflection mechanism for visit attrs";
995+
}
996+
997+
bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
998+
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(DerivedType::RuntimeTypeIndex());
999+
bool success = true;
1000+
ffi::reflection::ForEachFieldInfoWithEarlyStop(
1001+
type_info, [&](const TVMFFIFieldInfo* field_info) {
1002+
ffi::reflection::FieldGetter field_getter(field_info);
1003+
ffi::Any field_value = field_getter(self());
1004+
ffi::Any other_field_value = field_getter(other);
1005+
if (!equal.AnyEqual(field_value, other_field_value)) {
1006+
success = false;
1007+
return true;
1008+
}
1009+
return false;
1010+
});
1011+
return success;
1012+
}
1013+
1014+
void SHashReduce(SHashReducer hash_reducer) const {
1015+
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(DerivedType::RuntimeTypeIndex());
1016+
ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) {
1017+
ffi::reflection::FieldGetter field_getter(field_info);
1018+
ffi::Any field_value = field_getter(self());
1019+
hash_reducer(field_value);
1020+
});
1021+
}
1022+
1023+
Array<AttrFieldInfo> ListFieldInfo() const final {
1024+
// use the new reflection to list field info
1025+
return Array<AttrFieldInfo>();
1026+
}
1027+
1028+
private:
1029+
DerivedType* self() const {
1030+
return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
1031+
}
1032+
};
1033+
9731034
} // namespace tvm
9741035
#endif // TVM_IR_ATTRS_H_

include/tvm/relax/attrs/ccl.h

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,54 +24,70 @@
2424
#ifndef TVM_RELAX_ATTRS_CCL_H_
2525
#define TVM_RELAX_ATTRS_CCL_H_
2626

27+
#include <tvm/ffi/reflection/reflection.h>
2728
#include <tvm/relax/expr.h>
2829

2930
namespace tvm {
3031
namespace relax {
3132

3233
/*! \brief Attributes used in allreduce operators */
33-
struct AllReduceAttrs : public tvm::AttrsNode<AllReduceAttrs> {
34+
struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter<AllReduceAttrs> {
3435
String op_type;
3536
bool in_group;
3637

37-
TVM_DECLARE_ATTRS(AllReduceAttrs, "relax.attrs.AllReduceAttrs") {
38-
TVM_ATTR_FIELD(op_type).describe(
39-
"The type of reduction operation to be applied to the input data. Now only sum is "
40-
"supported.");
41-
TVM_ATTR_FIELD(in_group).describe(
42-
"Whether the reduction operation performs in group or globally or in group as default.");
38+
static void RegisterReflection() {
39+
namespace refl = tvm::ffi::reflection;
40+
refl::ObjectDef<AllReduceAttrs>()
41+
.def_ro("op_type", &AllReduceAttrs::op_type,
42+
"The type of reduction operation to be applied to the input data. Now only sum is "
43+
"supported.")
44+
.def_ro("in_group", &AllReduceAttrs::in_group,
45+
"Whether the reduction operation performs in group or globally or in group as "
46+
"default.");
4347
}
48+
49+
static constexpr const char* _type_key = "relax.attrs.AllReduceAttrs";
50+
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AllReduceAttrs, BaseAttrsNode);
4451
}; // struct AllReduceAttrs
4552

4653
/*! \brief Attributes used in allgather operators */
47-
struct AllGatherAttrs : public tvm::AttrsNode<AllGatherAttrs> {
54+
struct AllGatherAttrs : public tvm::AttrsNodeReflAdapter<AllGatherAttrs> {
4855
int num_workers;
4956
bool in_group;
5057

51-
TVM_DECLARE_ATTRS(AllGatherAttrs, "relax.attrs.AllGatherAttrs") {
52-
TVM_ATTR_FIELD(num_workers)
53-
.describe(
54-
"The number of workers, also the number of parts the given buffer should be chunked "
55-
"into.");
56-
TVM_ATTR_FIELD(in_group).describe(
57-
"Whether the allgather operation performs in group or globally or in group as default.");
58+
static void RegisterReflection() {
59+
namespace refl = tvm::ffi::reflection;
60+
refl::ObjectDef<AllGatherAttrs>()
61+
.def_ro("num_workers", &AllGatherAttrs::num_workers,
62+
"The number of workers, also the number of parts the given buffer should be "
63+
"chunked into.")
64+
.def_ro("in_group", &AllGatherAttrs::in_group,
65+
"Whether the allgather operation performs in group or globally or in group as "
66+
"default.");
5867
}
68+
69+
static constexpr const char* _type_key = "relax.attrs.AllGatherAttrs";
70+
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AllGatherAttrs, BaseAttrsNode);
5971
}; // struct AllGatherAttrs
6072

6173
/*! \brief Attributes used in scatter operators */
6274
struct ScatterCollectiveAttrs : public tvm::AttrsNode<ScatterCollectiveAttrs> {
6375
int num_workers;
6476
int axis;
6577

66-
TVM_DECLARE_ATTRS(ScatterCollectiveAttrs, "relax.attrs.ScatterCollectiveAttrs") {
67-
TVM_ATTR_FIELD(num_workers)
68-
.describe(
69-
"The number of workers, also the number of parts the given buffer should be chunked "
70-
"into.");
71-
TVM_ATTR_FIELD(axis).describe(
72-
"The axis of the tensor to be scattered. The tensor will be chunked along "
73-
"this axis.");
78+
static void RegisterReflection() {
79+
namespace refl = tvm::ffi::reflection;
80+
refl::ObjectDef<ScatterCollectiveAttrs>()
81+
.def_ro("num_workers", &ScatterCollectiveAttrs::num_workers,
82+
"The number of workers, also the number of parts the given buffer should be "
83+
"chunked into.")
84+
.def_ro("axis", &ScatterCollectiveAttrs::axis,
85+
"The axis of the tensor to be scattered. The tensor will be chunked along "
86+
"this axis.");
7487
}
88+
89+
static constexpr const char* _type_key = "relax.attrs.ScatterCollectiveAttrs";
90+
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ScatterCollectiveAttrs, BaseAttrsNode);
7591
}; // struct ScatterCollectiveAttrs
7692

7793
} // namespace relax

src/contrib/msc/core/ir/graph_builder.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ void FuncAttrGetter::VisitExpr_(const CallNode* op) {
5050
if (op->attrs.defined()) {
5151
Map<String, String> attrs;
5252
AttrGetter getter(&attrs);
53-
const_cast<BaseAttrsNode*>(op->attrs.get())->VisitAttrs(&getter);
53+
getter(op->attrs);
5454
for (const auto& pair : attrs) {
5555
if (attrs_.count(pair.first)) {
5656
int cnt = 1;
@@ -350,7 +350,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional<Expr>& bin
350350
attrs = FuncAttrGetter().GetAttrs(call_node->op);
351351
} else if (call_node->attrs.defined()) {
352352
AttrGetter getter(&attrs);
353-
const_cast<BaseAttrsNode*>(call_node->attrs.get())->VisitAttrs(&getter);
353+
getter(call_node->attrs);
354354
}
355355
} else if (const auto* const_node = expr.as<ConstantNode>()) {
356356
if (const_node->is_scalar()) {

src/contrib/msc/core/ir/graph_builder.h

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#define TVM_CONTRIB_MSC_CORE_IR_GRAPH_BUILDER_H_
2626

2727
#include <dmlc/json.h>
28+
#include <tvm/ffi/reflection/reflection.h>
2829
#include <tvm/relax/expr.h>
2930
#include <tvm/relax/expr_functor.h>
3031
#include <tvm/runtime/ndarray.h>
@@ -106,14 +107,65 @@ struct MSCRBuildConfig {
106107
}
107108
};
108109

109-
class AttrGetter : public AttrVisitor {
110+
class AttrGetter : private AttrVisitor {
110111
public:
111112
/*!
112113
* \brief Get the attributes as Map<String, String>
113114
* \param attrs the attributes.
114115
*/
115116
explicit AttrGetter(Map<String, String>* attrs) : attrs_(attrs) {}
116117

118+
void operator()(const Attrs& attrs) {
119+
// dispatch between new reflection and old reflection
120+
const TVMFFITypeInfo* attrs_tinfo = TVMFFIGetTypeInfo(attrs->type_index());
121+
if (attrs_tinfo->extra_info != nullptr) {
122+
tvm::ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) {
123+
Any field_value = tvm::ffi::reflection::FieldGetter(field_info)(attrs);
124+
this->VisitAny(String(field_info->name), field_value);
125+
});
126+
} else {
127+
// TODO(tvm-team): remove this once all objects are transitioned to the new reflection
128+
const_cast<BaseAttrsNode*>(attrs.get())->VisitAttrs(this);
129+
}
130+
}
131+
132+
private:
133+
void VisitAny(String key, Any value) {
134+
switch (value.type_index()) {
135+
case kTVMFFINone: {
136+
attrs_->Set(key, "");
137+
break;
138+
}
139+
case kTVMFFIBool: {
140+
attrs_->Set(key, std::to_string(value.cast<bool>()));
141+
break;
142+
}
143+
case kTVMFFIInt: {
144+
attrs_->Set(key, std::to_string(value.cast<int64_t>()));
145+
break;
146+
}
147+
case kTVMFFIFloat: {
148+
attrs_->Set(key, std::to_string(value.cast<double>()));
149+
break;
150+
}
151+
case kTVMFFIDataType: {
152+
attrs_->Set(key, runtime::DLDataTypeToString(value.cast<DLDataType>()));
153+
}
154+
case kTVMFFIStr: {
155+
attrs_->Set(key, value.cast<String>());
156+
break;
157+
}
158+
default: {
159+
if (value.type_index() >= kTVMFFIStaticObjectBegin) {
160+
attrs_->Set(key, StringUtils::ToString(value.cast<ObjectRef>()));
161+
} else {
162+
LOG(FATAL) << "Unsupported type: " << value.type_index();
163+
}
164+
break;
165+
}
166+
}
167+
}
168+
117169
void Visit(const char* key, double* value) final { attrs_->Set(key, std::to_string(*value)); }
118170

119171
void Visit(const char* key, int64_t* value) final { attrs_->Set(key, std::to_string(*value)); }

0 commit comments

Comments
 (0)