Skip to content

Commit 02f4882

Browse files
authored
[FFI] Re-introduce the boxed primitive values (#17257)
* Revert "Revert "[FFI][RUNTIME] Introduce runtime boxed types for int/float/bool" (#17252)" This reverts commit 11be832. * [FFI] Re-introduce the boxed primitive values Initially introduced in #16183, these changes were reverted in #17252 due to performance degredation in some Relax models. This could occur when a model contained a large number of calls to `"vm.builtin.tuple_getitem"`, which may occur when model weights are provided as a tuple. This PR re-applies the changes from #16183, but with the performance degredation resolved. The root cause was unnecessary type-checking when converting from an untyped `tvm::ArrayNode*` to the typed `tvm::Array<T>`, in the case where `T` is `ObjectRef`. * Correct typo from T to U
1 parent b3d01c2 commit 02f4882

File tree

184 files changed

+3278
-1225
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

184 files changed

+3278
-1225
lines changed

include/tvm/ir/attrs.h

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,16 @@ class DictAttrs : public Attrs {
265265

266266
auto it = node->dict.find(attr_key);
267267
if (it != node->dict.end()) {
268-
return Downcast<Optional<TObjectRef>>((*it).second);
268+
// For backwards compatibility, return through TVMRetValue.
269+
// This triggers any automatic conversions registered with
270+
// PackedFuncValueConverter. Importantly, this allows use of
271+
// `GetAttr<Integer>` and `GetAttr<Bool>` for properties that
272+
// are stored internally as `runtime::Box<int64_t>` and
273+
// `runtime::Box<bool>`.
274+
TVMRetValue ret;
275+
ret = (*it).second;
276+
Optional<TObjectRef> obj = ret;
277+
return obj;
269278
} else {
270279
return default_value;
271280
}
@@ -315,6 +324,46 @@ inline TAttrs AttrsWithDefaultValues() {
315324
return TAttrs(n);
316325
}
317326

327+
/*!
328+
* \brief Copy the DictAttrs, but overrides attributes with the
329+
* entries from \p attrs.
330+
*
331+
* \param attrs The DictAttrs to update
332+
*
333+
* \param new_attrs Key/values attributes to add to \p attrs.
334+
*
335+
* \returns The new DictAttrs with updated attributes.
336+
*/
337+
DictAttrs WithAttrs(DictAttrs attrs, Map<String, ObjectRef> new_attrs);
338+
339+
/*!
340+
* \brief Copy the DictAttrs, but overrides a single attribute.
341+
*
342+
* \param attrs The DictAttrs to update
343+
*
344+
* \param key The update to insert or update.
345+
*
346+
* \param value The new value of the attribute
347+
*
348+
* \returns The new DictAttrs with updated attributes.
349+
*/
350+
DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value);
351+
352+
inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, ObjectRef value) {
353+
return WithAttr(std::move(attrs), String(key), std::move(value));
354+
}
355+
356+
/*!
357+
* \brief Copy the DictAttrs, but without a specific attribute.
358+
*
359+
* \param attrs The DictAttrs to update
360+
*
361+
* \param key The key to remove
362+
*
363+
* \returns The new DictAttrs with updated attributes.
364+
*/
365+
DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key);
366+
318367
/*!
319368
* \brief Copy the function or module, but overrides
320369
* the attribute value key with the value.
@@ -347,12 +396,8 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v
347396
using TNode = typename TFunc::ContainerType;
348397
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
349398
TNode* node = input.CopyOnWrite();
350-
if (node->attrs.defined()) {
351-
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
352-
} else {
353-
Map<String, ObjectRef> dict = {{attr_key, attr_value}};
354-
node->attrs = DictAttrs(dict);
355-
}
399+
node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value);
400+
356401
return input;
357402
}
358403

@@ -371,13 +416,9 @@ inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
371416
using TNode = typename TFunc::ContainerType;
372417
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
373418
TNode* node = input.CopyOnWrite();
374-
if (node->attrs.defined()) {
375-
for (const auto& pair : attrs) {
376-
node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
377-
}
378-
} else {
379-
node->attrs = DictAttrs(std::move(attrs));
380-
}
419+
420+
node->attrs = WithAttrs(std::move(node->attrs), attrs);
421+
381422
return input;
382423
}
383424

@@ -412,10 +453,9 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
412453
using TNode = typename TFunc::ContainerType;
413454
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
414455

415-
if (input->attrs.defined()) {
416-
TNode* node = input.CopyOnWrite();
417-
node->attrs.CopyOnWrite()->dict.erase(attr_key);
418-
}
456+
TNode* node = input.CopyOnWrite();
457+
node->attrs = WithoutAttr(std::move(node->attrs), attr_key);
458+
419459
return input;
420460
}
421461

include/tvm/ir/expr.h

Lines changed: 99 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -770,53 +770,121 @@ inline const TTypeNode* RelayExprNode::type_as() const {
770770

771771
namespace tvm {
772772
namespace runtime {
773-
// common rule for RetValue and ArgValue
773+
774+
// Automatic conversion into IntImm, Integer, and Bool, when called
775+
// through the FFI. Automatic conversions into PrimExpr are
776+
// registered in "tvm/tir/expr.h", as it includes conversions to the
777+
// TIR-only StringImm.
778+
//
779+
// While the FFI only requires the From() method, these
780+
// implementations also define a TryFrom() method to avoid duplicate
781+
// logic in the PrimExpr conversion.
782+
774783
template <>
775-
struct PackedFuncValueConverter<PrimExpr> {
776-
static PrimExpr From(const TVMPODValue_& val) {
777-
if (val.type_code() == kTVMNullptr) {
778-
return PrimExpr(ObjectPtr<Object>(nullptr));
779-
}
780-
if (val.type_code() == kDLInt) {
781-
int64_t value = val.operator int64_t();
782-
if (value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min()) {
783-
return IntImm(runtime::DataType::Int(64), value);
784-
}
785-
return IntImm(runtime::DataType::Int(32), val.operator int());
786-
}
787-
if (val.type_code() == kDLFloat) {
788-
return FloatImm(runtime::DataType::Float(32), val.operator double());
784+
struct PackedFuncValueConverter<tvm::IntImm> {
785+
template <typename PODSubclass>
786+
static Optional<tvm::IntImm> TryFrom(const PODSubclass& val) {
787+
if (auto opt = val.TryAsInt()) {
788+
int64_t value = opt.value();
789+
auto dtype =
790+
(value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min())
791+
? DataType::Int(64)
792+
: DataType::Int(32);
793+
return IntImm(dtype, value);
794+
} else if (auto opt = val.TryAsBool()) {
795+
return IntImm(DataType::Int(32), opt.value());
796+
} else {
797+
return NullOpt;
789798
}
799+
}
790800

791-
return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
801+
template <typename PODSubclass>
802+
static tvm::IntImm From(const PODSubclass& val) {
803+
if (auto opt = TryFrom(val)) {
804+
return opt.value();
805+
} else {
806+
return val.template AsObjectRef<tvm::IntImm>();
807+
}
792808
}
793809
};
794810

795811
template <>
796812
struct PackedFuncValueConverter<tvm::Integer> {
797-
static tvm::Integer From(const TVMPODValue_& val) {
798-
if (val.type_code() == kTVMNullptr) {
799-
return Integer(ObjectPtr<Object>(nullptr));
813+
template <typename PODSubclass>
814+
static tvm::Integer From(const PODSubclass& val) {
815+
if (auto opt = PackedFuncValueConverter<tvm::IntImm>::TryFrom(val)) {
816+
return Integer(opt.value());
817+
} else {
818+
return val.template AsObjectRef<tvm::Integer>();
800819
}
801-
if (val.type_code() == kTVMArgInt) {
802-
return Integer(val.operator int());
803-
}
804-
return val.AsObjectRef<tvm::Integer>();
805820
}
806821
};
807822

808823
template <>
809824
struct PackedFuncValueConverter<tvm::Bool> {
810-
static tvm::Bool From(const TVMPODValue_& val) {
811-
if (val.type_code() == kTVMNullptr) {
812-
return Bool(ObjectPtr<Object>(nullptr));
825+
template <typename PODSubclass>
826+
static Optional<tvm::Bool> TryFrom(const PODSubclass& val) {
827+
if (auto opt = val.TryAsBool()) {
828+
return tvm::Bool(opt.value());
829+
} else if (auto opt = val.TryAsInt()) {
830+
int value = opt.value();
831+
ICHECK(value == 0 || value == 1)
832+
<< "ValueError: boolean value can only be 0 or 1, but get " << value;
833+
return tvm::Bool(static_cast<bool>(value));
834+
} else {
835+
return NullOpt;
836+
}
837+
}
838+
839+
template <typename PODSubclass>
840+
static tvm::Bool From(const PODSubclass& val) {
841+
if (auto opt = TryFrom(val)) {
842+
return opt.value();
843+
} else {
844+
return val.template AsObjectRef<tvm::Bool>();
813845
}
814-
if (val.type_code() == kTVMArgInt) {
815-
int v = val.operator int();
816-
ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
817-
return Bool(static_cast<bool>(v));
846+
}
847+
};
848+
849+
template <>
850+
struct PackedFuncValueConverter<tvm::FloatImm> {
851+
static Optional<tvm::FloatImm> TryFrom(const TVMPODValue_& val) {
852+
if (auto opt = val.TryAsFloat()) {
853+
return FloatImm(runtime::DataType::Float(32), opt.value());
854+
} else {
855+
return NullOpt;
856+
}
857+
}
858+
859+
template <typename PODSubclass>
860+
static tvm::FloatImm From(const PODSubclass& val) {
861+
if (auto opt = TryFrom(val)) {
862+
return opt.value();
863+
} else {
864+
return val.template AsObjectRef<tvm::FloatImm>();
865+
}
866+
}
867+
};
868+
869+
/* \brief Backwards compatibility wrapper for IntImm arguments
870+
*
871+
* In previous versions of TVM, IntImm was the default FFI type for
872+
* integer arguments, instead of runtime::Int. For backwards
873+
* compatibility where the callee has been updated to expected a
874+
* runtime::Int, the caller has not been updated to provide a
875+
* runtime::Int (e.g. relay script parsing), and the auto-unboxing of
876+
* runtime::Int does not apply (e.g. making an `Array<runtime::Int>`),
877+
* allow the IntImm to be generated.
878+
*/
879+
template <>
880+
struct PackedFuncValueConverter<runtime::Int> {
881+
template <typename PODSubclass>
882+
static runtime::Int From(const PODSubclass& val) {
883+
if (val.template IsObjectRef<tvm::IntImm>()) {
884+
return runtime::Int(val.template AsObjectRef<tvm::IntImm>()->value);
885+
} else {
886+
return val.template AsObjectRef<runtime::Int>();
818887
}
819-
return val.AsObjectRef<tvm::Bool>();
820888
}
821889
};
822890

include/tvm/ir/transform.h

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,36 @@ class PassContext : public ObjectRef {
271271
using ValueNodeType = typename ValueType::ContainerType;
272272
// NOTE: we could further update the function later.
273273
uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex();
274-
RegisterConfigOption(key, tindex);
274+
auto type_key = runtime::Object::TypeIndex2Key(tindex);
275+
276+
auto* reflection = ReflectionVTable::Global();
277+
278+
auto legalization = [=](ObjectRef obj) -> ObjectRef {
279+
if (obj->IsInstance<Map<String, ObjectRef>::ContainerType>()) {
280+
return reflection->CreateObject(type_key, Downcast<Map<String, ObjectRef>>(obj));
281+
} else {
282+
// Backwards compatibility for config options defined prior to
283+
// https://github.com/apache/tvm/pull/16183. This commit
284+
// changed the default FFI conversion of python integers from
285+
// `tvm::IntImm` to `runtime::Int`.
286+
//
287+
// This backwards compatibility fix can be removed when all
288+
// options registered with TVM_REGISTER_PASS_CONFIG_OPTION are
289+
// updated to use `runtime::Int` and `runtime::Bool`.
290+
TVMRetValue ret;
291+
ret = obj;
292+
try {
293+
ValueType legalized = ret;
294+
return legalized;
295+
} catch (Error& err) {
296+
LOG(FATAL) << "AttributeError: expect config " << key << " to have type " << type_key
297+
<< ", but received error when converting to this type.\n"
298+
<< err.what();
299+
}
300+
}
301+
};
302+
303+
RegisterConfigOption(key, tindex, legalization);
275304
return tindex;
276305
}
277306

@@ -285,7 +314,8 @@ class PassContext : public ObjectRef {
285314
// The exit of a pass context scope.
286315
TVM_DLL void ExitWithScope();
287316
// Register configuration key value type.
288-
TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index);
317+
TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index,
318+
std::function<ObjectRef(ObjectRef)> legalization);
289319

290320
// Classes to get the Python `with` like syntax.
291321
friend class Internal;

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ class ScheduleRule : public runtime::ObjectRef {
241241
* \param thread_extents Candidates of thread axis extent (values are required to be positive).
242242
* \return The schedule rule created
243243
*/
244-
TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
244+
TVM_DLL static ScheduleRule CrossThreadReduction(Array<runtime::Int> thread_extents);
245245
/*!
246246
* \brief A rule that randomly select a compute-at location for a free block
247247
* \return The schedule rule created
@@ -260,9 +260,9 @@ class ScheduleRule : public runtime::ObjectRef {
260260
* \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma.
261261
* \return The schedule rule created
262262
*/
263-
TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
264-
int max_vectorize_extent, //
265-
Array<Integer> unroll_max_steps, //
263+
TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
264+
int max_vectorize_extent, //
265+
Array<runtime::Int> unroll_max_steps, //
266266
bool unroll_explicit);
267267
/*!
268268
* \brief Auto bind loops around the block to BlockIdx and ThreadIdx

include/tvm/relay/attrs/transform.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
325325
}; // struct SqueezeAttrs
326326

327327
struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
328-
ObjectRef indices_or_sections;
328+
Variant<runtime::Int, Array<runtime::Int>> indices_or_sections;
329329
int axis;
330330

331331
TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {

include/tvm/runtime/c_runtime_api.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
#ifdef __cplusplus
8282
extern "C" {
8383
#endif
84+
#include <stdbool.h>
8485
#include <stddef.h>
8586
#include <stdint.h>
8687

@@ -186,11 +187,12 @@ typedef enum {
186187
kTVMBytes = 12U,
187188
kTVMNDArrayHandle = 13U,
188189
kTVMObjectRValueRefArg = 14U,
190+
kTVMArgBool = 15U,
189191
// Extension codes for other frameworks to integrate TVM PackedFunc.
190192
// To make sure each framework's id do not conflict, use first and
191193
// last sections to mark ranges.
192194
// Open an issue at the repo if you need a section of code.
193-
kTVMExtBegin = 15U,
195+
kTVMExtBegin = 16U,
194196
kTVMNNVMFirst = 16U,
195197
kTVMNNVMLast = 20U,
196198
// The following section of code is used for non-reserved types.
@@ -207,6 +209,7 @@ typedef DLTensor* TVMArrayHandle;
207209
*/
208210
typedef union {
209211
int64_t v_int64;
212+
bool v_bool;
210213
double v_float64;
211214
void* v_handle;
212215
const char* v_str;

0 commit comments

Comments
 (0)