Skip to content

Commit f0e62eb

Browse files
authored
[REFACTOR] Transition VisitAttrs to new reflection mechanism in tir/ir_builder/meta_schedule (#18096)
This PR transitions VisitAttrs to new reflection mechansim in the following components: tir, ir_builder, meta_schedule
1 parent c5c733c commit f0e62eb

File tree

92 files changed

+1472
-720
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

92 files changed

+1472
-720
lines changed

ffi/include/tvm/ffi/reflection/reflection.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ class ObjectDef : public ReflectionDefBase {
162162
*
163163
* \return The reflection definition.
164164
*/
165-
template <typename T, typename... Extra>
166-
TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... extra) {
165+
template <typename T, typename BaseClass, typename... Extra>
166+
TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::*field_ptr, Extra&&... extra) {
167167
RegisterField(name, field_ptr, false, std::forward<Extra>(extra)...);
168168
return *this;
169169
}
@@ -181,8 +181,8 @@ class ObjectDef : public ReflectionDefBase {
181181
*
182182
* \return The reflection definition.
183183
*/
184-
template <typename T, typename... Extra>
185-
TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... extra) {
184+
template <typename T, typename BaseClass, typename... Extra>
185+
TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::*field_ptr, Extra&&... extra) {
186186
static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields");
187187
RegisterField(name, field_ptr, true, std::forward<Extra>(extra)...);
188188
return *this;
@@ -239,9 +239,10 @@ class ObjectDef : public ReflectionDefBase {
239239
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterExtraInfo(type_index_, &info));
240240
}
241241

242-
template <typename T, typename... ExtraArgs>
243-
void RegisterField(const char* name, T Class::*field_ptr, bool writable,
242+
template <typename T, typename BaseClass, typename... ExtraArgs>
243+
void RegisterField(const char* name, T BaseClass::*field_ptr, bool writable,
244244
ExtraArgs&&... extra_args) {
245+
static_assert(std::is_base_of_v<BaseClass, Class>, "BaseClass must be a base class of Class");
245246
TVMFFIFieldInfo info;
246247
info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
247248
info.field_static_type_index = TypeToFieldStaticTypeIndex<T>::value;

ffi/include/tvm/ffi/type_traits.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,46 @@ struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> : public TypeT
274274
static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; }
275275
};
276276

277+
// Enum Integer POD values
278+
template <typename IntEnum>
279+
struct TypeTraits<IntEnum, std::enable_if_t<std::is_enum_v<IntEnum> &&
280+
std::is_integral_v<std::underlying_type_t<IntEnum>>>>
281+
: public TypeTraitsBase {
282+
static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt;
283+
284+
static TVM_FFI_INLINE void CopyToAnyView(const IntEnum& src, TVMFFIAny* result) {
285+
result->type_index = TypeIndex::kTVMFFIInt;
286+
result->v_int64 = static_cast<int64_t>(src);
287+
}
288+
289+
static TVM_FFI_INLINE void MoveToAny(IntEnum src, TVMFFIAny* result) {
290+
CopyToAnyView(src, result);
291+
}
292+
293+
static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
294+
// NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny
295+
return src->type_index == TypeIndex::kTVMFFIInt;
296+
}
297+
298+
static TVM_FFI_INLINE IntEnum CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
299+
return static_cast<IntEnum>(src->v_int64);
300+
}
301+
302+
static TVM_FFI_INLINE IntEnum MoveFromAnyAfterCheck(TVMFFIAny* src) {
303+
// POD type, we can just copy the value
304+
return CopyFromAnyViewAfterCheck(src);
305+
}
306+
307+
static TVM_FFI_INLINE std::optional<IntEnum> TryCastFromAnyView(const TVMFFIAny* src) {
308+
if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) {
309+
return static_cast<IntEnum>(src->v_int64);
310+
}
311+
return std::nullopt;
312+
}
313+
314+
static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; }
315+
};
316+
277317
// Float POD values
278318
template <typename Float>
279319
struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>>

ffi/tests/cpp/test_any.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,24 @@ TEST(Any, Int) {
6060
EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 2);
6161
}
6262

63+
TEST(Any, Enum) {
64+
enum class ENum : int {
65+
A = 1,
66+
B = 2,
67+
};
68+
69+
AnyView view0;
70+
Optional<ENum> opt_v0 = view0.as<ENum>();
71+
EXPECT_TRUE(!opt_v0.has_value());
72+
73+
AnyView view1 = ENum::A;
74+
EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt);
75+
EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1);
76+
77+
ENum v1 = view1.cast<ENum>();
78+
EXPECT_EQ(v1, ENum::A);
79+
}
80+
6381
TEST(Any, bool) {
6482
AnyView view0;
6583
Optional<bool> opt_v0 = view0.as<bool>();

include/tvm/arith/analyzer.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#define TVM_ARITH_ANALYZER_H_
2626

2727
#include <tvm/arith/int_set.h>
28+
#include <tvm/ffi/reflection/reflection.h>
2829
#include <tvm/ir/expr.h>
2930
#include <tvm/support/with.h>
3031

@@ -86,11 +87,15 @@ class ConstIntBoundNode : public Object {
8687
int64_t min_value;
8788
int64_t max_value;
8889

89-
void VisitAttrs(tvm::AttrVisitor* v) {
90-
v->Visit("min_value", &min_value);
91-
v->Visit("max_value", &max_value);
90+
static void RegisterReflection() {
91+
namespace refl = tvm::ffi::reflection;
92+
refl::ObjectDef<ConstIntBoundNode>()
93+
.def_ro("min_value", &ConstIntBoundNode::min_value)
94+
.def_ro("max_value", &ConstIntBoundNode::max_value);
9295
}
9396

97+
static constexpr bool _type_has_method_visit_attrs = false;
98+
9499
bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
95100
return equal(min_value, other->min_value) && equal(max_value, other->max_value);
96101
}
@@ -208,11 +213,15 @@ class ModularSetNode : public Object {
208213
/*! \brief The base */
209214
int64_t base;
210215

211-
void VisitAttrs(tvm::AttrVisitor* v) {
212-
v->Visit("coeff", &coeff);
213-
v->Visit("base", &base);
216+
static void RegisterReflection() {
217+
namespace refl = tvm::ffi::reflection;
218+
refl::ObjectDef<ModularSetNode>()
219+
.def_ro("coeff", &ModularSetNode::coeff)
220+
.def_ro("base", &ModularSetNode::base);
214221
}
215222

223+
static constexpr bool _type_has_method_visit_attrs = false;
224+
216225
bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
217226
return equal(coeff, other->coeff) && equal(base, other->base);
218227
}

include/tvm/arith/int_solver.h

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,13 @@ class IntGroupBoundsNode : public Object {
6262
Array<PrimExpr> equal;
6363
Array<PrimExpr> upper;
6464

65-
void VisitAttrs(tvm::AttrVisitor* v) {
66-
v->Visit("coef", &coef);
67-
v->Visit("lower", &lower);
68-
v->Visit("equal", &equal);
69-
v->Visit("upper", &upper);
65+
static void RegisterReflection() {
66+
namespace refl = tvm::ffi::reflection;
67+
refl::ObjectDef<IntGroupBoundsNode>()
68+
.def_ro("coef", &IntGroupBoundsNode::coef)
69+
.def_ro("lower", &IntGroupBoundsNode::lower)
70+
.def_ro("equal", &IntGroupBoundsNode::equal)
71+
.def_ro("upper", &IntGroupBoundsNode::upper);
7072
}
7173

7274
bool SEqualReduce(const IntGroupBoundsNode* other, SEqualReducer eq) const {
@@ -81,6 +83,7 @@ class IntGroupBoundsNode : public Object {
8183
hash_reduce(upper);
8284
}
8385

86+
static constexpr const bool _type_has_method_visit_attrs = false;
8487
static constexpr const bool _type_has_method_sequal_reduce = true;
8588
static constexpr const char* _type_key = "arith.IntGroupBounds";
8689
TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object);
@@ -152,10 +155,12 @@ class IntConstraintsNode : public Object {
152155
// e.g., A \alpha = \beta or A \alpha <= \beta
153156
Array<PrimExpr> relations;
154157

155-
void VisitAttrs(tvm::AttrVisitor* v) {
156-
v->Visit("variables", &variables);
157-
v->Visit("ranges", &ranges);
158-
v->Visit("relations", &relations);
158+
static void RegisterReflection() {
159+
namespace refl = tvm::ffi::reflection;
160+
refl::ObjectDef<IntConstraintsNode>()
161+
.def_ro("variables", &IntConstraintsNode::variables)
162+
.def_ro("ranges", &IntConstraintsNode::ranges)
163+
.def_ro("relations", &IntConstraintsNode::relations);
159164
}
160165

161166
bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const {
@@ -169,6 +174,7 @@ class IntConstraintsNode : public Object {
169174
hash_reduce(relations);
170175
}
171176

177+
static constexpr const bool _type_has_method_visit_attrs = false;
172178
static constexpr const bool _type_has_method_sequal_reduce = true;
173179
static constexpr const char* _type_key = "arith.IntConstraints";
174180
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object);
@@ -213,11 +219,13 @@ class IntConstraintsTransformNode : public Object {
213219
Map<Var, PrimExpr> src_to_dst;
214220
Map<Var, PrimExpr> dst_to_src;
215221

216-
void VisitAttrs(tvm::AttrVisitor* v) {
217-
v->Visit("src", &src);
218-
v->Visit("dst", &dst);
219-
v->Visit("src_to_dst", &src_to_dst);
220-
v->Visit("dst_to_src", &dst_to_src);
222+
static void RegisterReflection() {
223+
namespace refl = tvm::ffi::reflection;
224+
refl::ObjectDef<IntConstraintsTransformNode>()
225+
.def_ro("src", &IntConstraintsTransformNode::src)
226+
.def_ro("dst", &IntConstraintsTransformNode::dst)
227+
.def_ro("src_to_dst", &IntConstraintsTransformNode::src_to_dst)
228+
.def_ro("dst_to_src", &IntConstraintsTransformNode::dst_to_src);
221229
}
222230

223231
bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const {
@@ -232,6 +240,7 @@ class IntConstraintsTransformNode : public Object {
232240
hash_reduce(dst_to_src);
233241
}
234242

243+
static constexpr const bool _type_has_method_visit_attrs = false;
235244
static constexpr const bool _type_has_method_sequal_reduce = true;
236245
static constexpr const char* _type_key = "arith.IntConstraintsTransform";
237246
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object);

include/tvm/arith/iter_affine_map.h

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#define TVM_ARITH_ITER_AFFINE_MAP_H_
5050

5151
#include <tvm/arith/analyzer.h>
52+
#include <tvm/ffi/reflection/reflection.h>
5253
#include <tvm/ir/diagnostic.h>
5354
#include <tvm/ir/expr.h>
5455
#include <tvm/tir/var.h>
@@ -65,9 +66,7 @@ namespace arith {
6566
*/
6667
class IterMapExprNode : public PrimExprNode {
6768
public:
68-
// overrides
69-
void VisitAttrs(tvm::AttrVisitor* v) {}
70-
69+
static constexpr bool _type_has_method_visit_attrs = false;
7170
static constexpr const char* _type_key = "arith.IterMapExpr";
7271
static constexpr const uint32_t _type_child_slots = 2;
7372
TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode);
@@ -100,12 +99,15 @@ class IterMarkNode : public Object {
10099
*/
101100
PrimExpr extent;
102101

103-
// overrides
104-
void VisitAttrs(tvm::AttrVisitor* v) {
105-
v->Visit("source", &source);
106-
v->Visit("extent", &extent);
102+
static void RegisterReflection() {
103+
namespace refl = tvm::ffi::reflection;
104+
refl::ObjectDef<IterMarkNode>()
105+
.def_ro("source", &IterMarkNode::source)
106+
.def_ro("extent", &IterMarkNode::extent);
107107
}
108108

109+
static constexpr bool _type_has_method_visit_attrs = false;
110+
109111
bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const {
110112
equal->MarkGraphNode();
111113
return equal(source, other->source) && equal(extent, other->extent);
@@ -156,14 +158,17 @@ class IterSplitExprNode : public IterMapExprNode {
156158
/*! \brief Additional scale. */
157159
PrimExpr scale;
158160

159-
// overrides
160-
void VisitAttrs(tvm::AttrVisitor* v) {
161-
v->Visit("source", &source);
162-
v->Visit("lower_factor", &lower_factor);
163-
v->Visit("extent", &extent);
164-
v->Visit("scale", &scale);
161+
static void RegisterReflection() {
162+
namespace refl = tvm::ffi::reflection;
163+
refl::ObjectDef<IterSplitExprNode>()
164+
.def_ro("source", &IterSplitExprNode::source)
165+
.def_ro("lower_factor", &IterSplitExprNode::lower_factor)
166+
.def_ro("extent", &IterSplitExprNode::extent)
167+
.def_ro("scale", &IterSplitExprNode::scale);
165168
}
166169

170+
static constexpr bool _type_has_method_visit_attrs = false;
171+
167172
bool SEqualReduce(const IterSplitExprNode* other, SEqualReducer equal) const {
168173
return equal(source, other->source) && equal(lower_factor, other->lower_factor) &&
169174
equal(extent, other->extent) && equal(scale, other->scale);
@@ -223,12 +228,15 @@ class IterSumExprNode : public IterMapExprNode {
223228
/*! \brief The base offset. */
224229
PrimExpr base;
225230

226-
// overrides
227-
void VisitAttrs(tvm::AttrVisitor* v) {
228-
v->Visit("args", &args);
229-
v->Visit("base", &base);
231+
static void RegisterReflection() {
232+
namespace refl = tvm::ffi::reflection;
233+
refl::ObjectDef<IterSumExprNode>()
234+
.def_ro("args", &IterSumExprNode::args)
235+
.def_ro("base", &IterSumExprNode::base);
230236
}
231237

238+
static constexpr bool _type_has_method_visit_attrs = false;
239+
232240
bool SEqualReduce(const IterSumExprNode* other, SEqualReducer equal) const {
233241
return equal(args, other->args) && equal(base, other->base);
234242
}
@@ -291,13 +299,16 @@ class IterMapResultNode : public Object {
291299
*/
292300
PrimExpr padding_predicate;
293301

294-
// overrides
295-
void VisitAttrs(tvm::AttrVisitor* v) {
296-
v->Visit("errors", &errors);
297-
v->Visit("indices", &indices);
298-
v->Visit("padding_predicate", &padding_predicate);
302+
static void RegisterReflection() {
303+
namespace refl = tvm::ffi::reflection;
304+
refl::ObjectDef<IterMapResultNode>()
305+
.def_ro("indices", &IterMapResultNode::indices)
306+
.def_ro("errors", &IterMapResultNode::errors)
307+
.def_ro("padding_predicate", &IterMapResultNode::padding_predicate);
299308
}
300309

310+
static constexpr bool _type_has_method_visit_attrs = false;
311+
301312
static constexpr const char* _type_key = "arith.IterMapResult";
302313
TVM_DECLARE_FINAL_OBJECT_INFO(IterMapResultNode, Object);
303314
};

0 commit comments

Comments
 (0)