Skip to content

Commit 11be832

Browse files
authored
Revert "[FFI][RUNTIME] Introduce runtime boxed types for int/float/bool" (#17252)
Revert "[FFI][RUNTIME] Introduce runtime boxed types for int/float/bool (#16183)" This reverts commit 5f22be4.
1 parent 05e2bc3 commit 11be832

File tree

184 files changed

+1221
-3215
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

184 files changed

+1221
-3215
lines changed

include/tvm/ir/attrs.h

Lines changed: 18 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -265,16 +265,7 @@ class DictAttrs : public Attrs {
265265

266266
auto it = node->dict.find(attr_key);
267267
if (it != node->dict.end()) {
268-
// For backwards compatibility, return through TVMRetValue.
269-
// This triggers any automatic conversions registered with
270-
// PackedFuncValueConverter. Importantly, this allows use of
271-
// `GetAttr<Integer>` and `GetAttr<Bool>` for properties that
272-
// are stored internally as `runtime::Box<int64_t>` and
273-
// `runtime::Box<bool>`.
274-
TVMRetValue ret;
275-
ret = (*it).second;
276-
Optional<TObjectRef> obj = ret;
277-
return obj;
268+
return Downcast<Optional<TObjectRef>>((*it).second);
278269
} else {
279270
return default_value;
280271
}
@@ -324,46 +315,6 @@ inline TAttrs AttrsWithDefaultValues() {
324315
return TAttrs(n);
325316
}
326317

327-
/*!
328-
* \brief Copy the DictAttrs, but overrides attributes with the
329-
* entries from \p attrs.
330-
*
331-
* \param attrs The DictAttrs to update
332-
*
333-
* \param new_attrs Key/values attributes to add to \p attrs.
334-
*
335-
* \returns The new DictAttrs with updated attributes.
336-
*/
337-
DictAttrs WithAttrs(DictAttrs attrs, Map<String, ObjectRef> new_attrs);
338-
339-
/*!
340-
* \brief Copy the DictAttrs, but overrides a single attribute.
341-
*
342-
* \param attrs The DictAttrs to update
343-
*
344-
* \param key The update to insert or update.
345-
*
346-
* \param value The new value of the attribute
347-
*
348-
* \returns The new DictAttrs with updated attributes.
349-
*/
350-
DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value);
351-
352-
inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, ObjectRef value) {
353-
return WithAttr(std::move(attrs), String(key), std::move(value));
354-
}
355-
356-
/*!
357-
* \brief Copy the DictAttrs, but without a specific attribute.
358-
*
359-
* \param attrs The DictAttrs to update
360-
*
361-
* \param key The key to remove
362-
*
363-
* \returns The new DictAttrs with updated attributes.
364-
*/
365-
DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key);
366-
367318
/*!
368319
* \brief Copy the function or module, but overrides
369320
* the attribute value key with the value.
@@ -396,8 +347,12 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v
396347
using TNode = typename TFunc::ContainerType;
397348
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
398349
TNode* node = input.CopyOnWrite();
399-
node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value);
400-
350+
if (node->attrs.defined()) {
351+
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
352+
} else {
353+
Map<String, ObjectRef> dict = {{attr_key, attr_value}};
354+
node->attrs = DictAttrs(dict);
355+
}
401356
return input;
402357
}
403358

@@ -416,9 +371,13 @@ inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
416371
using TNode = typename TFunc::ContainerType;
417372
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
418373
TNode* node = input.CopyOnWrite();
419-
420-
node->attrs = WithAttrs(std::move(node->attrs), attrs);
421-
374+
if (node->attrs.defined()) {
375+
for (const auto& pair : attrs) {
376+
node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
377+
}
378+
} else {
379+
node->attrs = DictAttrs(std::move(attrs));
380+
}
422381
return input;
423382
}
424383

@@ -453,9 +412,10 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
453412
using TNode = typename TFunc::ContainerType;
454413
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
455414

456-
TNode* node = input.CopyOnWrite();
457-
node->attrs = WithoutAttr(std::move(node->attrs), attr_key);
458-
415+
if (input->attrs.defined()) {
416+
TNode* node = input.CopyOnWrite();
417+
node->attrs.CopyOnWrite()->dict.erase(attr_key);
418+
}
459419
return input;
460420
}
461421

include/tvm/ir/expr.h

Lines changed: 31 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -770,121 +770,53 @@ inline const TTypeNode* RelayExprNode::type_as() const {
770770

771771
namespace tvm {
772772
namespace runtime {
773-
774-
// Automatic conversion into IntImm, Integer, and Bool, when called
775-
// through the FFI. Automatic conversions into PrimExpr are
776-
// registered in "tvm/tir/expr.h", as it includes conversions to the
777-
// TIR-only StringImm.
778-
//
779-
// While the FFI only requires the From() method, these
780-
// implementations also define a TryFrom() method to avoid duplicate
781-
// logic in the PrimExpr conversion.
782-
773+
// common rule for RetValue and ArgValue
783774
template <>
784-
struct PackedFuncValueConverter<tvm::IntImm> {
785-
template <typename PODSubclass>
786-
static Optional<tvm::IntImm> TryFrom(const PODSubclass& val) {
787-
if (auto opt = val.TryAsInt()) {
788-
int64_t value = opt.value();
789-
auto dtype =
790-
(value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min())
791-
? DataType::Int(64)
792-
: DataType::Int(32);
793-
return IntImm(dtype, value);
794-
} else if (auto opt = val.TryAsBool()) {
795-
return IntImm(DataType::Int(32), opt.value());
796-
} else {
797-
return NullOpt;
775+
struct PackedFuncValueConverter<PrimExpr> {
776+
static PrimExpr From(const TVMPODValue_& val) {
777+
if (val.type_code() == kTVMNullptr) {
778+
return PrimExpr(ObjectPtr<Object>(nullptr));
798779
}
799-
}
800-
801-
template <typename PODSubclass>
802-
static tvm::IntImm From(const PODSubclass& val) {
803-
if (auto opt = TryFrom(val)) {
804-
return opt.value();
805-
} else {
806-
return val.template AsObjectRef<tvm::IntImm>();
780+
if (val.type_code() == kDLInt) {
781+
int64_t value = val.operator int64_t();
782+
if (value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min()) {
783+
return IntImm(runtime::DataType::Int(64), value);
784+
}
785+
return IntImm(runtime::DataType::Int(32), val.operator int());
807786
}
808-
}
809-
};
810-
811-
template <>
812-
struct PackedFuncValueConverter<tvm::Integer> {
813-
template <typename PODSubclass>
814-
static tvm::Integer From(const PODSubclass& val) {
815-
if (auto opt = PackedFuncValueConverter<tvm::IntImm>::TryFrom(val)) {
816-
return Integer(opt.value());
817-
} else {
818-
return val.template AsObjectRef<tvm::Integer>();
787+
if (val.type_code() == kDLFloat) {
788+
return FloatImm(runtime::DataType::Float(32), val.operator double());
819789
}
820-
}
821-
};
822790

823-
template <>
824-
struct PackedFuncValueConverter<tvm::Bool> {
825-
template <typename PODSubclass>
826-
static Optional<tvm::Bool> TryFrom(const PODSubclass& val) {
827-
if (auto opt = val.TryAsBool()) {
828-
return tvm::Bool(opt.value());
829-
} else if (auto opt = val.TryAsInt()) {
830-
int value = opt.value();
831-
ICHECK(value == 0 || value == 1)
832-
<< "ValueError: boolean value can only be 0 or 1, but get " << value;
833-
return tvm::Bool(static_cast<bool>(value));
834-
} else {
835-
return NullOpt;
836-
}
837-
}
838-
839-
template <typename PODSubclass>
840-
static tvm::Bool From(const PODSubclass& val) {
841-
if (auto opt = TryFrom(val)) {
842-
return opt.value();
843-
} else {
844-
return val.template AsObjectRef<tvm::Bool>();
845-
}
791+
return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
846792
}
847793
};
848794

849795
template <>
850-
struct PackedFuncValueConverter<tvm::FloatImm> {
851-
static Optional<tvm::FloatImm> TryFrom(const TVMPODValue_& val) {
852-
if (auto opt = val.TryAsFloat()) {
853-
return FloatImm(runtime::DataType::Float(32), opt.value());
854-
} else {
855-
return NullOpt;
796+
struct PackedFuncValueConverter<tvm::Integer> {
797+
static tvm::Integer From(const TVMPODValue_& val) {
798+
if (val.type_code() == kTVMNullptr) {
799+
return Integer(ObjectPtr<Object>(nullptr));
856800
}
857-
}
858-
859-
template <typename PODSubclass>
860-
static tvm::FloatImm From(const PODSubclass& val) {
861-
if (auto opt = TryFrom(val)) {
862-
return opt.value();
863-
} else {
864-
return val.template AsObjectRef<tvm::FloatImm>();
801+
if (val.type_code() == kTVMArgInt) {
802+
return Integer(val.operator int());
865803
}
804+
return val.AsObjectRef<tvm::Integer>();
866805
}
867806
};
868807

869-
/* \brief Backwards compatibility wrapper for IntImm arguments
870-
*
871-
* In previous versions of TVM, IntImm was the default FFI type for
872-
* integer arguments, instead of runtime::Int. For backwards
873-
* compatibility where the callee has been updated to expected a
874-
* runtime::Int, the caller has not been updated to provide a
875-
* runtime::Int (e.g. relay script parsing), and the auto-unboxing of
876-
* runtime::Int does not apply (e.g. making an `Array<runtime::Int>`),
877-
* allow the IntImm to be generated.
878-
*/
879808
template <>
880-
struct PackedFuncValueConverter<runtime::Int> {
881-
template <typename PODSubclass>
882-
static runtime::Int From(const PODSubclass& val) {
883-
if (val.template IsObjectRef<tvm::IntImm>()) {
884-
return runtime::Int(val.template AsObjectRef<tvm::IntImm>()->value);
885-
} else {
886-
return val.template AsObjectRef<runtime::Int>();
809+
struct PackedFuncValueConverter<tvm::Bool> {
810+
static tvm::Bool From(const TVMPODValue_& val) {
811+
if (val.type_code() == kTVMNullptr) {
812+
return Bool(ObjectPtr<Object>(nullptr));
813+
}
814+
if (val.type_code() == kTVMArgInt) {
815+
int v = val.operator int();
816+
ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
817+
return Bool(static_cast<bool>(v));
887818
}
819+
return val.AsObjectRef<tvm::Bool>();
888820
}
889821
};
890822

include/tvm/ir/transform.h

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -271,36 +271,7 @@ class PassContext : public ObjectRef {
271271
using ValueNodeType = typename ValueType::ContainerType;
272272
// NOTE: we could further update the function later.
273273
uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex();
274-
auto type_key = runtime::Object::TypeIndex2Key(tindex);
275-
276-
auto* reflection = ReflectionVTable::Global();
277-
278-
auto legalization = [=](ObjectRef obj) -> ObjectRef {
279-
if (obj->IsInstance<Map<String, ObjectRef>::ContainerType>()) {
280-
return reflection->CreateObject(type_key, Downcast<Map<String, ObjectRef>>(obj));
281-
} else {
282-
// Backwards compatibility for config options defined prior to
283-
// https://github.com/apache/tvm/pull/16183. This commit
284-
// changed the default FFI conversion of python integers from
285-
// `tvm::IntImm` to `runtime::Int`.
286-
//
287-
// This backwards compatibility fix can be removed when all
288-
// options registered with TVM_REGISTER_PASS_CONFIG_OPTION are
289-
// updated to use `runtime::Int` and `runtime::Bool`.
290-
TVMRetValue ret;
291-
ret = obj;
292-
try {
293-
ValueType legalized = ret;
294-
return legalized;
295-
} catch (Error& err) {
296-
LOG(FATAL) << "AttributeError: expect config " << key << " to have type " << type_key
297-
<< ", but received error when converting to this type.\n"
298-
<< err.what();
299-
}
300-
}
301-
};
302-
303-
RegisterConfigOption(key, tindex, legalization);
274+
RegisterConfigOption(key, tindex);
304275
return tindex;
305276
}
306277

@@ -314,8 +285,7 @@ class PassContext : public ObjectRef {
314285
// The exit of a pass context scope.
315286
TVM_DLL void ExitWithScope();
316287
// Register configuration key value type.
317-
TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index,
318-
std::function<ObjectRef(ObjectRef)> legalization);
288+
TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index);
319289

320290
// Classes to get the Python `with` like syntax.
321291
friend class Internal;

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ class ScheduleRule : public runtime::ObjectRef {
241241
* \param thread_extents Candidates of thread axis extent (values are required to be positive).
242242
* \return The schedule rule created
243243
*/
244-
TVM_DLL static ScheduleRule CrossThreadReduction(Array<runtime::Int> thread_extents);
244+
TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
245245
/*!
246246
* \brief A rule that randomly select a compute-at location for a free block
247247
* \return The schedule rule created
@@ -260,9 +260,9 @@ class ScheduleRule : public runtime::ObjectRef {
260260
* \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma.
261261
* \return The schedule rule created
262262
*/
263-
TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
264-
int max_vectorize_extent, //
265-
Array<runtime::Int> unroll_max_steps, //
263+
TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
264+
int max_vectorize_extent, //
265+
Array<Integer> unroll_max_steps, //
266266
bool unroll_explicit);
267267
/*!
268268
* \brief Auto bind loops around the block to BlockIdx and ThreadIdx

include/tvm/relay/attrs/transform.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
325325
}; // struct SqueezeAttrs
326326

327327
struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
328-
Variant<runtime::Int, Array<runtime::Int>> indices_or_sections;
328+
ObjectRef indices_or_sections;
329329
int axis;
330330

331331
TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {

include/tvm/runtime/c_runtime_api.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@
8181
#ifdef __cplusplus
8282
extern "C" {
8383
#endif
84-
#include <stdbool.h>
8584
#include <stddef.h>
8685
#include <stdint.h>
8786

@@ -187,12 +186,11 @@ typedef enum {
187186
kTVMBytes = 12U,
188187
kTVMNDArrayHandle = 13U,
189188
kTVMObjectRValueRefArg = 14U,
190-
kTVMArgBool = 15U,
191189
// Extension codes for other frameworks to integrate TVM PackedFunc.
192190
// To make sure each framework's id do not conflict, use first and
193191
// last sections to mark ranges.
194192
// Open an issue at the repo if you need a section of code.
195-
kTVMExtBegin = 16U,
193+
kTVMExtBegin = 15U,
196194
kTVMNNVMFirst = 16U,
197195
kTVMNNVMLast = 20U,
198196
// The following section of code is used for non-reserved types.
@@ -209,7 +207,6 @@ typedef DLTensor* TVMArrayHandle;
209207
*/
210208
typedef union {
211209
int64_t v_int64;
212-
bool v_bool;
213210
double v_float64;
214211
void* v_handle;
215212
const char* v_str;

0 commit comments

Comments
 (0)