Skip to content

Commit 3218fac

Browse files
committed
some clean up
1 parent 7a235b6 commit 3218fac

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

python/tvm/tir/tensor_intrin/cuda.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,17 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
7979
shared_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) + 8 * (
8080
tx // HALF_WARP_expr
8181
)
82-
83-
elif k_dim == 32:
84-
assert dtype == "int8"
82+
else:
83+
assert (
84+
k_dim == 32 and dtype == "int8"
85+
), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now"
8586

8687
if ldmatrix_col_major:
8788
index_map = shared_32x16_to_ldmatrix_32x16_layout
88-
shared_offset = (
89-
lambda _, stride: stride
90-
) # dummy offset, ldmatrix cannot be used for int8 + trans case
89+
# A dummy offset: ldmatrix cannot be used for the int8 + trans case.
90+
# We still use the ldmatrix intrinsic, but lower it to a manual loop in the codegen.
91+
# Only the stride information is required.
92+
shared_offset = lambda _, stride: stride
9193
elif is_b and transposed:
9294
index_map = shared_16x32_to_ldmatrix_32x16_layout
9395
shared_offset = (
@@ -99,9 +101,6 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
99101
index_map = shared_16x32_to_ldmatrix_32x16_layout
100102
shared_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx // 16)
101103

102-
else:
103-
assert False, "Unsupported k dim"
104-
105104
assert index_map and shared_offset
106105

107106
if is_b and not transposed:

src/target/source/codegen_cuda.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
821821
std::string smem_ptr = this->PrintExpr(op->args[5]);
822822

823823
if (trans && op->dtype.bits() == 8) {
824+
// Since ldmatrix assumes that a matrix element is 16 bit, it cannot properly transpose an
825+
// int8 matrix.
824826
std::string smem_stride = this->PrintExpr(op->args[6]);
825827
ICHECK(num == 4);
826828
os << "for (int i = 0; i < 16; ++i) {\n";

0 commit comments

Comments
 (0)