@@ -79,15 +79,17 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
7979 shared_offset = lambda tx , stride : stride * (tx % HALF_WARP_expr ) + 8 * (
8080 tx // HALF_WARP_expr
8181 )
82-
83- elif k_dim == 32 :
84- assert dtype == "int8"
82+ else :
83+ assert (
84+ k_dim == 32 and dtype == "int8"
85+ ), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now"
8586
8687 if ldmatrix_col_major :
8788 index_map = shared_32x16_to_ldmatrix_32x16_layout
88- shared_offset = (
89- lambda _ , stride : stride
90- ) # dummy offset, ldmatrix cannot be used for int8 + trans case
89+ # A dummy offset: ldmatrix cannot be used for the int8 + transposed case.
90+ # We still use the ldmatrix intrinsic, but lower it to a manual loop in the codegen.
91+ # Only the stride information is required.
92+ shared_offset = lambda _ , stride : stride
9193 elif is_b and transposed :
9294 index_map = shared_16x32_to_ldmatrix_32x16_layout
9395 shared_offset = (
@@ -99,9 +101,6 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
99101 index_map = shared_16x32_to_ldmatrix_32x16_layout
100102 shared_offset = lambda tx , stride : stride * (tx % 16 ) + 16 * (tx // 16 )
101103
102- else :
103- assert False , "Unsupported k dim"
104-
105104 assert index_map and shared_offset
106105
107106 if is_b and not transposed :
0 commit comments