3434
3535namespace tvm {
3636namespace ffi {
37+ namespace details {
38+ /* !
39+ * \brief Base class for Variant.
40+ *
41+ * \tparam all_storage_object Whether all types are derived from ObjectRef.
42+ */
43+ template <bool all_storage_object = false >
44+ class VariantBase {
45+ public:
46+ TVM_FFI_INLINE bool same_as (const VariantBase<all_storage_object>& other) const {
47+ return data_.same_as (other.data_ );
48+ }
49+
50+ protected:
51+ template <typename T>
52+ explicit VariantBase (T other) : data_(std::move(other)) {}
53+
54+ TVM_FFI_INLINE void SetData (Any other_data) { data_ = std::move (other_data); }
55+
56+ TVM_FFI_INLINE Any MoveToAny () && { return std::move (data_); }
57+
58+ TVM_FFI_INLINE AnyView ToAnyView () const { return data_.operator AnyView (); }
59+
60+ Any data_;
61+ };
62+
63+ // Specialization for all object ref case, backed by ObjectRef.
64+ template <>
65+ class VariantBase <true > : public ObjectRef {
66+ protected:
67+ template <typename T>
68+ explicit VariantBase (const T& other) : ObjectRef(other) {}
69+ template <typename T>
70+ explicit VariantBase (T&& other) : ObjectRef(std::move(other)) {}
71+ explicit VariantBase (ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
72+ explicit VariantBase (Any other)
73+ : ObjectRef(details::AnyUnsafe::MoveFromAnyStorageAfterCheck<ObjectRef>(std::move(other))) {}
74+
75+ TVM_FFI_INLINE void SetData (ObjectPtr<Object> other) { data_ = std::move (other); }
76+
77+ TVM_FFI_INLINE Any MoveToAny () && { return Any (ObjectRef (std::move (data_))); }
78+
79+ TVM_FFI_INLINE AnyView ToAnyView () const {
80+ TVMFFIAny any_data;
81+ if (data_ == nullptr ) {
82+ any_data.type_index = TypeIndex::kTVMFFINone ;
83+ any_data.v_int64 = 0 ;
84+ } else {
85+ TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY (&any_data);
86+ any_data.type_index = data_->type_index ();
87+ any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr<Object>(data_);
88+ }
89+ return AnyView::CopyFromTVMFFIAny (any_data);
90+ }
91+ };
92+ } // namespace details
3793
3894/* !
3995 * \brief A typed variant container.
4096 *
41- * A Variant is backed by Any container, with strong checks during construction.
97+ * When all values are ObjectRef, Variant is backed by ObjectRef,
98+ * otherwise it is backed by Any.
4299 */
43100template <typename ... V>
44- class Variant {
101+ class Variant : public details ::VariantBase<details::all_object_ref_v<V...>> {
45102 public:
103+ using TParent = details::VariantBase<details::all_object_ref_v<V...>>;
46104 static_assert (details::all_storage_enabled_v<V...>,
47105 " All types used in Variant<...> must be compatible with Any" );
48106 /*
@@ -54,31 +112,30 @@ class Variant {
54112 template <typename T>
55113 using enable_if_variant_contains_t = std::enable_if_t <variant_contains_v<T>>;
56114
57- Variant (const Variant<V...>& other) : data_ (other.data_) {}
58- Variant (Variant<V...>&& other) : data_ (std::move(other.data_)) {}
115+ Variant (const Variant<V...>& other) : TParent (other.data_) {}
116+ Variant (Variant<V...>&& other) : TParent (std::move(other.data_)) {}
59117
60118 TVM_FFI_INLINE Variant& operator =(const Variant<V...>& other) {
61- data_ = other.data_ ;
119+ this -> SetData ( other.data_ ) ;
62120 return *this ;
63121 }
64122
65123 TVM_FFI_INLINE Variant& operator =(Variant<V...>&& other) {
66- data_ = std::move (other.data_ );
124+ this -> SetData ( std::move (other.data_ ) );
67125 return *this ;
68126 }
69127
70128 template <typename T, typename = enable_if_variant_contains_t <T>>
71- Variant (T other) : data_ (std::move(other)) {} // NOLINT(*)
129+ Variant (T other) : TParent (std::move(other)) {} // NOLINT(*)
72130
73131 template <typename T, typename = enable_if_variant_contains_t <T>>
74132 TVM_FFI_INLINE Variant& operator =(T other) {
75- data_ = std::move (other);
76- return *this ;
133+ return operator =(Variant (std::move (other)));
77134 }
78135
79136 template <typename T, typename = enable_if_variant_contains_t <T>>
80137 TVM_FFI_INLINE std::optional<T> as () const {
81- return data_. as <T>();
138+ return this -> TParent ::ToAnyView (). template as <T>();
82139 }
83140
84141 /*
@@ -89,29 +146,27 @@ class Variant {
89146 */
90147 template <typename T, typename = std::enable_if_t <std::is_base_of_v<Object, T>>>
91148 TVM_FFI_INLINE const T* as () const {
92- return data_. as <const T*>().value_or (nullptr );
149+ return this -> TParent ::ToAnyView (). template as <const T*>().value_or (nullptr );
93150 }
94151
95152 template <typename T, typename = enable_if_variant_contains_t <T>>
96153 TVM_FFI_INLINE T get () const & {
97- return data_ .template cast <T>();
154+ return this -> TParent ::ToAnyView () .template cast <T>();
98155 }
99156
100157 template <typename T, typename = enable_if_variant_contains_t <T>>
101158 TVM_FFI_INLINE T get () && {
102- return std::move (data_ ).template cast <T>();
159+ return std::move (* this ). TParent ::MoveToAny ( ).template cast <T>();
103160 }
104161
105- TVM_FFI_INLINE std::string GetTypeKey () const { return data_ .GetTypeKey (); }
162+ TVM_FFI_INLINE std::string GetTypeKey () const { return this -> TParent ::ToAnyView () .GetTypeKey (); }
106163
107164 private:
108165 friend struct TypeTraits <Variant<V...>>;
109166 friend struct ObjectPtrHash ;
110167 friend struct ObjectPtrEqual ;
111168 // constructor from any
112- explicit Variant (Any data) : data_(std::move(data)) {}
113- // internal data is backed by Any
114- Any data_;
169+ explicit Variant (Any data) : TParent(std::move(data)) {}
115170 /* !
116171 * \brief Get the object pointer from the variant
117172 * \note This function is only available if all types used in Variant<...> are derived from
@@ -122,8 +177,11 @@ class Variant {
122177 static_assert (all_object_v,
123178 " All types used in Variant<...> must be derived from ObjectRef "
124179 " to enable ObjectPtrHash/ObjectPtrEqual" );
125- return details::AnyUnsafe::ObjectPtrFromAnyAfterCheck ( data_);
180+ return this -> data_ . get ( );
126181 }
182+ // rexpose to friend class
183+ using TParent::MoveToAny;
184+ using TParent::ToAnyView;
127185};
128186
129187template <typename ... V>
@@ -132,11 +190,11 @@ inline constexpr bool use_default_type_traits_v<Variant<V...>> = false;
132190template <typename ... V>
133191struct TypeTraits <Variant<V...>> : public TypeTraitsBase {
134192 static TVM_FFI_INLINE void CopyToAnyView (const Variant<V...>& src, TVMFFIAny* result) {
135- *result = AnyView ( src.data_ ).CopyToTVMFFIAny ();
193+ *result = src.ToAnyView ( ).CopyToTVMFFIAny ();
136194 }
137195
138196 static TVM_FFI_INLINE void MoveToAny (Variant<V...> src, TVMFFIAny* result) {
139- *result = details::AnyUnsafe::MoveAnyToTVMFFIAny (std::move (src. data_ ));
197+ *result = details::AnyUnsafe::MoveAnyToTVMFFIAny (std::move (src). MoveToAny ( ));
140198 }
141199
142200 static TVM_FFI_INLINE std::string GetMismatchTypeInfo (const TVMFFIAny* src) {
0 commit comments