Skip to content

Commit 0f037a6

Browse files
authored
[FFI][Runtime] Use TVMValue::v_int64 to represent boolean values (#17240)
* [FFI][Runtime] Use TVMValue::v_int64 to represent boolean values This is a follow-up to #16183, which added handling of boolean values in the TVM FFI. The initial implementation added both a new type code (`kTVMArgBool`) and a new `TVMValue::v_bool` variant. This commit removes the `TVMValue::v_bool` variant, since the `kTVMArgBool` type code is sufficient to handle boolean arguments. Removing the `TVMValue::v_bool` variant also makes all `TVMValue` variants be 64-bit (assuming a 64-bit CPU). This can simplify debugging in some cases, since it prevents partial values from inactive variants from being present in memory. * Update MakePackedAPI, less special handling required for boolean
1 parent 20289e8 commit 0f037a6

File tree

11 files changed

+39
-41
lines changed

11 files changed

+39
-41
lines changed

include/tvm/runtime/c_runtime_api.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,6 @@ typedef DLTensor* TVMArrayHandle;
209209
*/
210210
typedef union {
211211
int64_t v_int64;
212-
bool v_bool;
213212
double v_float64;
214213
void* v_handle;
215214
const char* v_str;

include/tvm/runtime/packed_func.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ class TVMPODValue_ {
669669
// conversions. This is publicly exposed, as it can be useful in
670670
// specializations of PackedFuncValueConverter.
671671
if (type_code_ == kTVMArgBool) {
672-
return value_.v_bool;
672+
return static_cast<bool>(value_.v_int64);
673673
} else {
674674
return std::nullopt;
675675
}
@@ -1041,7 +1041,7 @@ class TVMRetValue : public TVMPODValue_CRTP_<TVMRetValue> {
10411041
TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); }
10421042
TVMRetValue& operator=(bool value) {
10431043
this->SwitchToPOD(kTVMArgBool);
1044-
value_.v_bool = value;
1044+
value_.v_int64 = value;
10451045
return *this;
10461046
}
10471047
TVMRetValue& operator=(std::string value) {
@@ -1831,7 +1831,7 @@ class TVMArgsSetter {
18311831
type_codes_[i] = kDLInt;
18321832
}
18331833
TVM_ALWAYS_INLINE void operator()(size_t i, bool value) const {
1834-
values_[i].v_bool = value;
1834+
values_[i].v_int64 = value;
18351835
type_codes_[i] = kTVMArgBool;
18361836
}
18371837
TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const {
@@ -2142,7 +2142,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
21422142
std::is_base_of_v<ContainerType, Bool::ContainerType>) {
21432143
if (std::is_base_of_v<Bool::ContainerType, ContainerType> ||
21442144
ptr->IsInstance<Bool::ContainerType>()) {
2145-
values_[i].v_bool = static_cast<Bool::ContainerType*>(ptr)->value;
2145+
values_[i].v_int64 = static_cast<Bool::ContainerType*>(ptr)->value;
21462146
type_codes_[i] = kTVMArgBool;
21472147
return;
21482148
}
@@ -2327,7 +2327,7 @@ inline TObjectRef TVMPODValue_CRTP_<Derived>::AsObjectRef() const {
23272327

23282328
if constexpr (std::is_base_of_v<TObjectRef, Bool>) {
23292329
if (type_code_ == kTVMArgBool) {
2330-
return Bool(value_.v_bool);
2330+
return Bool(value_.v_int64);
23312331
}
23322332
}
23332333

python/tvm/_ffi/_cython/packed_func.pxi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ cdef inline int make_arg(object arg,
121121
elif isinstance(arg, bool):
122122
# A python `bool` is a subclass of `int`, so this check
123123
# must occur before `Integral`.
124-
value[0].v_bool = arg
124+
value[0].v_int64 = arg
125125
tcode[0] = kTVMArgBool
126126
elif isinstance(arg, Integral):
127127
value[0].v_int64 = arg
@@ -215,7 +215,7 @@ cdef inline object make_ret(TVMValue value, int tcode):
215215
elif tcode == kTVMNullptr:
216216
return None
217217
elif tcode == kTVMArgBool:
218-
return value.v_bool
218+
return bool(value.v_int64)
219219
elif tcode == kInt:
220220
return value.v_int64
221221
elif tcode == kFloat:

rust/tvm-sys/src/packed_func.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ macro_rules! TVMPODValue {
9696
DLDataTypeCode_kDLInt => Int($value.v_int64),
9797
DLDataTypeCode_kDLUInt => UInt($value.v_int64),
9898
DLDataTypeCode_kDLFloat => Float($value.v_float64),
99-
TVMArgTypeCode_kTVMArgBool => Bool($value.v_bool),
99+
TVMArgTypeCode_kTVMArgBool => Bool($value.v_int64 != 0),
100100
TVMArgTypeCode_kTVMNullptr => Null,
101101
TVMArgTypeCode_kTVMDataType => DataType($value.v_type),
102102
TVMArgTypeCode_kDLDevice => Device($value.v_device),
@@ -119,7 +119,7 @@ macro_rules! TVMPODValue {
119119
Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
120120
UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt),
121121
Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
122-
Bool(val) => (TVMValue { v_bool: *val }, TVMArgTypeCode_kTVMArgBool),
122+
Bool(val) => (TVMValue { v_int64: *val as i64 }, TVMArgTypeCode_kTVMArgBool),
123123
Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr),
124124
DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType),
125125
Device(val) => (TVMValue { v_device: val.clone() }, TVMArgTypeCode_kDLDevice),

src/runtime/crt/common/crt_runtime_api.c

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,10 +362,8 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r
362362
return kTvmErrorFunctionCallWrongArgType;
363363
}
364364

365-
if (type_codes[2] == kDLInt) {
365+
if (type_codes[2] == kDLInt || type_codes[2] == kTVMArgBool) {
366366
query_imports = args[2].v_int64 != 0;
367-
} else if (type_codes[2] == kTVMArgBool) {
368-
query_imports = args[2].v_bool;
369367
} else {
370368
TVMAPISetLastError("ModuleGetFunction expects third argument to be an integer");
371369
return kTvmErrorFunctionCallWrongArgType;

src/runtime/minrpc/rpc_reference.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ struct RPCReference {
326326
break;
327327
}
328328
case kTVMArgBool: {
329-
channel->template Write<bool>(value.v_bool);
329+
channel->template Write<int64_t>(value.v_int64);
330330
break;
331331
}
332332
case kTVMDataType: {
@@ -437,7 +437,7 @@ struct RPCReference {
437437
break;
438438
}
439439
case kTVMArgBool: {
440-
channel->template Read<bool>(&(value.v_bool));
440+
channel->template Read<int64_t>(&(value.v_int64));
441441
break;
442442
}
443443
case kTVMDataType: {

src/target/llvm/codegen_cpu.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1379,7 +1379,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
13791379
llvm::Value* struct_value = builder_->CreateLoad(ref.type, ref.addr);
13801380

13811381
if (op->dtype == DataType::Bool()) {
1382-
struct_value = CreateCast(DataType::Int(8), op->dtype, struct_value);
1382+
struct_value = CreateCast(DataType::Int(64), op->dtype, struct_value);
13831383
}
13841384

13851385
return struct_value;

src/tir/transforms/ir_utils.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,7 @@ inline DataType APIType(DataType t) {
155155
ICHECK(!t.is_void()) << "Cannot pass void type through packed API.";
156156
if (t.is_handle()) return t;
157157
ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API.";
158-
if (t.is_bool()) return DataType::Bool();
159-
if (t.is_uint() || t.is_int()) return DataType::Int(64);
158+
if (t.is_bool() || t.is_uint() || t.is_int()) return DataType::Int(64);
160159
ICHECK(t.is_float());
161160
return DataType::Float(64);
162161
}

src/tir/transforms/make_packed_api.cc

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,11 @@ class ReturnRewriter : public StmtMutator {
8181

8282
// convert val's data type to FFI data type, return type code
8383
DataType dtype = val.dtype();
84-
if (dtype.is_int() || dtype.is_uint()) {
84+
if (dtype.is_bool()) {
85+
info.tcode = kTVMArgBool;
86+
info.expr = Cast(DataType::Int(64), val);
87+
88+
} else if (dtype.is_int() || dtype.is_uint()) {
8589
info.tcode = kTVMArgInt;
8690
info.expr = Cast(DataType::Int(64), val);
8791
} else if (dtype.is_float()) {
@@ -340,25 +344,15 @@ PrimFunc MakePackedAPI(PrimFunc func) {
340344
seq_init.emplace_back(
341345
AssertStmt(tcode == kTVMArgBool || tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
342346

343-
arg_value = Call(t, builtin::if_then_else(),
344-
{
345-
tcode == kTVMArgBool,
346-
f_arg_value(DataType::Bool(), i),
347-
cast(DataType::Bool(), f_arg_value(DataType::Int(64), i)),
348-
});
347+
arg_value = cast(DataType::Bool(), f_arg_value(DataType::Int(64), i));
349348

350349
} else if (t.is_int() || t.is_uint()) {
351350
std::ostringstream msg;
352351
msg << name_hint << ": Expect arg[" << i << "] to be int";
353352
seq_init.emplace_back(
354353
AssertStmt(tcode == kDLInt || tcode == kTVMArgBool, tvm::tir::StringImm(msg.str()), nop));
355354

356-
arg_value = Call(t, builtin::if_then_else(),
357-
{
358-
tcode == kTVMArgInt,
359-
f_arg_value(t, i),
360-
cast(t, f_arg_value(DataType::Bool(), i)),
361-
});
355+
arg_value = f_arg_value(t, i);
362356
} else {
363357
ICHECK(t.is_float());
364358
std::ostringstream msg;

tests/python/codegen/test_target_codegen_llvm.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,5 +1179,21 @@ def func(arg: T.bool) -> T.int32:
11791179
assert output == 20
11801180

11811181

1182+
def test_bool_return_value():
1183+
"""Booleans may be returned from a PrimFunc"""
1184+
1185+
@T.prim_func
1186+
def func(value: T.int32) -> T.bool:
1187+
T.func_attr({"target": T.target("llvm")})
1188+
return value < 10
1189+
1190+
built = tvm.build(func)
1191+
assert isinstance(built(0), bool)
1192+
assert built(0)
1193+
1194+
assert isinstance(built(15), bool)
1195+
assert not built(15)
1196+
1197+
11821198
if __name__ == "__main__":
11831199
tvm.testing.main()

0 commit comments

Comments
 (0)