Skip to content

Commit ff0b07b

Browse files
authored
[TIR] Add is_vector Method to DataType class and update usages across Codebase (#17443)
* Refactor data_type.h and c_runtime_api.h This commit refactors the `data_type.h` and `c_runtime_api.h` files. It introduces a new function `is_vector()` in the `DataType` class to check if a type is a vector type. Additionally, it adds a new constant `kTVMGridConstant` in the `TVMTypeCode` enum in `c_runtime_api.h`. These changes improve the code organization and provide better support for vector types. * revert kTVMGridConstant * lint fix
1 parent accd582 commit ff0b07b

File tree

5 files changed

+12
-10
lines changed

5 files changed

+12
-10
lines changed

include/tvm/runtime/data_type.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ class DataType {
148148
bool is_fixed_length_vector() const { return static_cast<int16_t>(data_.lanes) > 1; }
149149
/*! \return Whether the type is a scalable vector. */
150150
bool is_scalable_vector() const { return static_cast<int16_t>(data_.lanes) < -1; }
151+
/*! \return whether type is a vector type. */
152+
bool is_vector() const { return lanes() > 1; }
151153
/*! \return whether type is a bool vector type. */
152154
bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; }
153155
/*! \return whether type is a Void type. */

include/tvm/topi/elemwise.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast",
287287
if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) {
288288
if (expr.dtype().lanes() == type.lanes()) {
289289
return expr;
290-
} else if (expr.dtype().lanes() == 1 && type.lanes() > 1) {
290+
} else if (expr.dtype().lanes() == 1 && type.is_vector()) {
291291
return tvm::tir::Broadcast(expr, type.lanes());
292292
}
293293
}

src/target/llvm/codegen_llvm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1737,7 +1737,7 @@ void CodeGenLLVM::BufferAccessHelper(
17371737
if (const RampNode* ramp = last_index.as<RampNode>()) {
17381738
PrimExpr offset = ramp->base + (ramp->stride * i);
17391739
last_index_value = MakeValue(offset);
1740-
} else if (last_index.dtype().lanes() > 1) {
1740+
} else if (last_index.dtype().is_vector()) {
17411741
if (i == 0) {
17421742
cached_vector_index = MakeValue(last_index);
17431743
}

src/target/llvm/intrin_rule_hexagon.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) {
6666

6767
// Enable QHL library for FP16 data type
6868
const PrimExpr& x = call->args[0];
69-
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
69+
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
7070
return TVMExternCall(call, tvm_wrapper);
7171
}
7272
#endif
@@ -116,7 +116,7 @@ TVM_REGISTER_OP("tir.tanh")
116116
}
117117

118118
// Enable QHL library for FP16 data type
119-
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
119+
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
120120
std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf");
121121
return TVMExternCall(call, tvm_wrapper);
122122
}
@@ -152,7 +152,7 @@ TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>(
152152
}
153153

154154
// Enable QHL library for FP16 data type
155-
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
155+
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
156156
std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf");
157157
return TVMExternCall(call, tvm_wrapper);
158158
}
@@ -191,7 +191,7 @@ TVM_REGISTER_OP("tir.sigmoid")
191191
const tir::Call new_call = tir::Call(call->dtype, call->op, new_args);
192192

193193
// Enable QHL library for FP16 data type
194-
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
194+
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
195195
std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf");
196196
return TVMExternCall(new_call.get(), tvm_wrapper);
197197
}

src/tir/analysis/verify_gpu_code.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
7171
size_t size = static_cast<size_t>(op->ConstantAllocationSize());
7272
shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
7373
}
74-
if (op->dtype.lanes() > 1) {
74+
if (op->dtype.is_vector()) {
7575
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
7676
std::stringstream s;
7777
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
@@ -202,7 +202,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
202202
}
203203

204204
void VisitExpr_(const CastNode* op) {
205-
if (op->dtype.lanes() > 1) {
205+
if (op->dtype.is_vector()) {
206206
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
207207
std::stringstream s;
208208
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
@@ -215,7 +215,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
215215
}
216216

217217
void VisitExpr_(const BufferLoadNode* op) {
218-
if (op->dtype.lanes() > 1) {
218+
if (op->dtype.is_vector()) {
219219
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
220220
std::stringstream s;
221221
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
@@ -229,7 +229,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
229229
}
230230

231231
void VisitStmt_(const BufferStoreNode* op) {
232-
if (op->value->dtype.lanes() > 1) {
232+
if (op->value->dtype.is_vector()) {
233233
if (static_cast<size_t>(op->value->dtype.lanes() * op->value->dtype.bytes()) >
234234
max_vector_bytes_) {
235235
std::stringstream s;

0 commit comments

Comments
 (0)