Skip to content

Commit 296e2f7

Browse files
committed
[FFI] Variant specialize for all ObjectRef (apache#17943)
1 parent 16e9f0a commit 296e2f7

File tree

7 files changed

+127
-24
lines changed

7 files changed

+127
-24
lines changed

include/tvm/ffi/any.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,15 @@ struct AnyUnsafe : public ObjectUnsafe {
425425
}
426426
}
427427

428+
template <typename T>
429+
static TVM_FFI_INLINE T MoveFromAnyStorageAfterCheck(Any&& ref) {
430+
if constexpr (!std::is_same_v<T, Any>) {
431+
return TypeTraits<T>::MoveFromAnyStorageAfterCheck(&(ref.data_));
432+
} else {
433+
return std::move(ref);
434+
}
435+
}
436+
428437
static TVM_FFI_INLINE Object* ObjectPtrFromAnyAfterCheck(const Any& ref) {
429438
return reinterpret_cast<Object*>(ref.data_.v_obj);
430439
}

include/tvm/ffi/base_details.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@
123123
* This macro is used to clear the padding parts for hash and equality check
124124
* in 32bit platform.
125125
*/
126-
#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \
127-
if constexpr (sizeof(result->v_obj) != sizeof(result->v_int64)) { \
128-
result->v_int64 = 0; \
126+
#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \
127+
if constexpr (sizeof((result)->v_obj) != sizeof((result)->v_int64)) { \
128+
(result)->v_int64 = 0; \
129129
}
130130

131131
namespace tvm {

include/tvm/ffi/container/container_details.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,14 @@ inline constexpr bool storage_enabled_v = std::is_same_v<T, Any> || TypeTraits<T
284284
template <typename... T>
285285
inline constexpr bool all_storage_enabled_v = (storage_enabled_v<T> && ...);
286286

287+
/*!
288+
* \brief Check if all T are compatible with Any.
289+
*
290+
* \tparam T The type to check.
291+
* \return True if T is compatible with Any, false otherwise.
292+
*/
293+
template <typename... T>
294+
inline constexpr bool all_object_ref_v = (std::is_base_of_v<ObjectRef, T> && ...);
287295
/**
288296
* \brief Check if Any storage of Derived can always be directly used as Base.
289297
*

include/tvm/ffi/container/variant.h

Lines changed: 78 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,73 @@
3434

3535
namespace tvm {
3636
namespace ffi {
37+
namespace details {
38+
/*!
39+
* \brief Base class for Variant.
40+
*
41+
* \tparam all_storage_object Whether all types are derived from ObjectRef.
42+
*/
43+
template <bool all_storage_object = false>
44+
class VariantBase {
45+
public:
46+
TVM_FFI_INLINE bool same_as(const VariantBase<all_storage_object>& other) const {
47+
return data_.same_as(other.data_);
48+
}
49+
50+
protected:
51+
template <typename T>
52+
explicit VariantBase(T other) : data_(std::move(other)) {}
53+
54+
TVM_FFI_INLINE void SetData(Any other_data) { data_ = std::move(other_data); }
55+
56+
TVM_FFI_INLINE Any MoveToAny() && { return std::move(data_); }
57+
58+
TVM_FFI_INLINE AnyView ToAnyView() const { return data_.operator AnyView(); }
59+
60+
Any data_;
61+
};
62+
63+
// Specialization for all object ref case, backed by ObjectRef.
64+
template <>
65+
class VariantBase<true> : public ObjectRef {
66+
protected:
67+
template <typename T>
68+
explicit VariantBase(const T& other) : ObjectRef(other) {}
69+
template <typename T>
70+
explicit VariantBase(T&& other) : ObjectRef(std::move(other)) {}
71+
explicit VariantBase(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
72+
explicit VariantBase(Any other)
73+
: ObjectRef(details::AnyUnsafe::MoveFromAnyStorageAfterCheck<ObjectRef>(std::move(other))) {}
74+
75+
TVM_FFI_INLINE void SetData(ObjectPtr<Object> other) { data_ = std::move(other); }
76+
77+
TVM_FFI_INLINE Any MoveToAny() && { return Any(ObjectRef(std::move(data_))); }
78+
79+
TVM_FFI_INLINE AnyView ToAnyView() const {
80+
TVMFFIAny any_data;
81+
if (data_ == nullptr) {
82+
any_data.type_index = TypeIndex::kTVMFFINone;
83+
any_data.v_int64 = 0;
84+
} else {
85+
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
86+
any_data.type_index = data_->type_index();
87+
any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr<Object>(data_);
88+
}
89+
return AnyView::CopyFromTVMFFIAny(any_data);
90+
}
91+
};
92+
} // namespace details
3793

3894
/*!
3995
* \brief A typed variant container.
4096
*
41-
* A Variant is backed by Any container, with strong checks during construction.
97+
* When all values are ObjectRef, Variant is backed by ObjectRef,
98+
* otherwise it is backed by Any.
4299
*/
43100
template <typename... V>
44-
class Variant {
101+
class Variant : public details::VariantBase<details::all_object_ref_v<V...>> {
45102
public:
103+
using TParent = details::VariantBase<details::all_object_ref_v<V...>>;
46104
static_assert(details::all_storage_enabled_v<V...>,
47105
"All types used in Variant<...> must be compatible with Any");
48106
/*
@@ -54,31 +112,30 @@ class Variant {
54112
template <typename T>
55113
using enable_if_variant_contains_t = std::enable_if_t<variant_contains_v<T>>;
56114

57-
Variant(const Variant<V...>& other) : data_(other.data_) {}
58-
Variant(Variant<V...>&& other) : data_(std::move(other.data_)) {}
115+
Variant(const Variant<V...>& other) : TParent(other.data_) {}
116+
Variant(Variant<V...>&& other) : TParent(std::move(other.data_)) {}
59117

60118
TVM_FFI_INLINE Variant& operator=(const Variant<V...>& other) {
61-
data_ = other.data_;
119+
this->SetData(other.data_);
62120
return *this;
63121
}
64122

65123
TVM_FFI_INLINE Variant& operator=(Variant<V...>&& other) {
66-
data_ = std::move(other.data_);
124+
this->SetData(std::move(other.data_));
67125
return *this;
68126
}
69127

70128
template <typename T, typename = enable_if_variant_contains_t<T>>
71-
Variant(T other) : data_(std::move(other)) {} // NOLINT(*)
129+
Variant(T other) : TParent(std::move(other)) {} // NOLINT(*)
72130

73131
template <typename T, typename = enable_if_variant_contains_t<T>>
74132
TVM_FFI_INLINE Variant& operator=(T other) {
75-
data_ = std::move(other);
76-
return *this;
133+
return operator=(Variant(std::move(other)));
77134
}
78135

79136
template <typename T, typename = enable_if_variant_contains_t<T>>
80137
TVM_FFI_INLINE std::optional<T> as() const {
81-
return data_.as<T>();
138+
return this->TParent::ToAnyView().template as<T>();
82139
}
83140

84141
/*
@@ -89,29 +146,27 @@ class Variant {
89146
*/
90147
template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object, T>>>
91148
TVM_FFI_INLINE const T* as() const {
92-
return data_.as<const T*>().value_or(nullptr);
149+
return this->TParent::ToAnyView().template as<const T*>().value_or(nullptr);
93150
}
94151

95152
template <typename T, typename = enable_if_variant_contains_t<T>>
96153
TVM_FFI_INLINE T get() const& {
97-
return data_.template cast<T>();
154+
return this->TParent::ToAnyView().template cast<T>();
98155
}
99156

100157
template <typename T, typename = enable_if_variant_contains_t<T>>
101158
TVM_FFI_INLINE T get() && {
102-
return std::move(data_).template cast<T>();
159+
return std::move(*this).TParent::MoveToAny().template cast<T>();
103160
}
104161

105-
TVM_FFI_INLINE std::string GetTypeKey() const { return data_.GetTypeKey(); }
162+
TVM_FFI_INLINE std::string GetTypeKey() const { return this->TParent::ToAnyView().GetTypeKey(); }
106163

107164
private:
108165
friend struct TypeTraits<Variant<V...>>;
109166
friend struct ObjectPtrHash;
110167
friend struct ObjectPtrEqual;
111168
// constructor from any
112-
explicit Variant(Any data) : data_(std::move(data)) {}
113-
// internal data is backed by Any
114-
Any data_;
169+
explicit Variant(Any data) : TParent(std::move(data)) {}
115170
/*!
116171
* \brief Get the object pointer from the variant
117172
* \note This function is only available if all types used in Variant<...> are derived from
@@ -122,8 +177,11 @@ class Variant {
122177
static_assert(all_object_v,
123178
"All types used in Variant<...> must be derived from ObjectRef "
124179
"to enable ObjectPtrHash/ObjectPtrEqual");
125-
return details::AnyUnsafe::ObjectPtrFromAnyAfterCheck(data_);
180+
return this->data_.get();
126181
}
182+
// rexpose to friend class
183+
using TParent::MoveToAny;
184+
using TParent::ToAnyView;
127185
};
128186

129187
template <typename... V>
@@ -132,11 +190,11 @@ inline constexpr bool use_default_type_traits_v<Variant<V...>> = false;
132190
template <typename... V>
133191
struct TypeTraits<Variant<V...>> : public TypeTraitsBase {
134192
static TVM_FFI_INLINE void CopyToAnyView(const Variant<V...>& src, TVMFFIAny* result) {
135-
*result = AnyView(src.data_).CopyToTVMFFIAny();
193+
*result = src.ToAnyView().CopyToTVMFFIAny();
136194
}
137195

138196
static TVM_FFI_INLINE void MoveToAny(Variant<V...> src, TVMFFIAny* result) {
139-
*result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_));
197+
*result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src).MoveToAny());
140198
}
141199

142200
static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) {

tests/cpp/test_any.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ TEST(Any, ObjectMove) {
337337
auto v0 = std::move(any1).cast<TPrimExpr>();
338338
EXPECT_EQ(v0->value, 3.14);
339339
EXPECT_EQ(v0.use_count(), 1);
340+
EXPECT_TRUE(any1 == nullptr);
340341
}
341342

342343
} // namespace

tests/cpp/test_map.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ TEST(Map, AnyConvertCheck) {
243243
::tvm::ffi::Error);
244244
}
245245

246-
TEST(Map, ffi::FunctionGetItem) {
246+
TEST(Map, FunctionGetItem) {
247247
Function f = Function::FromTyped([](const MapObj* n, const Any& k) -> Any { return n->at(k); },
248248
"map_get_item");
249249
Map<String, int64_t> map{{"x", 1}, {"y", 2}};

tests/cpp/test_variant.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,31 @@ TEST(Variant, Upcast) {
134134
EXPECT_EQ(a1[0].get<int>(), 1);
135135
}
136136

137+
TEST(Variant, AllObjectRef) {
138+
Variant<TInt, Array<TInt>> v0 = TInt(1);
139+
EXPECT_EQ(v0.get<TInt>()->value, 1);
140+
static_assert(std::is_base_of_v<ObjectRef, decltype(v0)>);
141+
Any any0 = v0;
142+
EXPECT_EQ(any0.cast<TInt>()->value, 1);
143+
auto v2 = any0.cast<Variant<TInt, Array<TInt>>>();
144+
EXPECT_TRUE(v0.same_as(v2));
145+
// assignment operator
146+
v0 = Array<TInt>({TInt(2), TInt(3)});
147+
EXPECT_EQ(v0.get<Array<TInt>>().size(), 2);
148+
EXPECT_EQ(v0.get<Array<TInt>>()[0]->value, 2);
149+
EXPECT_EQ(v0.get<Array<TInt>>()[1]->value, 3);
150+
EXPECT_EQ(sizeof(v0), sizeof(ObjectRef));
151+
}
152+
153+
TEST(Variant, PODSameAs) {
154+
Variant<String, int> v0 = 1;
155+
Variant<String, int> v1 = 1;
156+
EXPECT_TRUE(v0.same_as(v1));
157+
String s = String("hello");
158+
v0 = s;
159+
v1 = s;
160+
EXPECT_TRUE(v0.same_as(v1));
161+
v1 = String("hello");
162+
EXPECT_TRUE(!v0.same_as(v1));
163+
}
137164
} // namespace

0 commit comments

Comments
 (0)