Merged
321 changes: 321 additions & 0 deletions testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
@@ -0,0 +1,321 @@
import torch
import tilelang.testing
from tilelang import tvm as tvm
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import (
    MatrixCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)


@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    a_transposed=False,
    b_transposed=True,
    k_pack=1,
    b_preshuffle=False,
):
    assert in_dtype in [
        "float16",
        "int8",
    ], "Currently only float16 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    if in_dtype in {"float8_e4m3fnuz", "int8"}:
        micro_size_k = 32

    block_row_warps = 2
    block_col_warps = 2
    warp_row_tiles = 32
    warp_col_tiles = 32

    # for preshuffle_b, warp_layout = {1, 4}
    if b_preshuffle:
        block_row_warps = 1
        block_col_warps = 4
        warp_row_tiles = 128
        warp_col_tiles = 32

    chunk = 32 * k_pack

    pack_size_k = micro_size_k * k_pack

    shared_scope = "shared"
    cache_write_shared = False

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    A_shape = (K, M) if a_transposed else (M, K)
    if b_preshuffle:
        B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y,
                   pack_size_k) if b_transposed else (K // pack_size_k, N // micro_size_y,
                                                      pack_size_k, micro_size_y)
    else:
        B_shape = (N, K) if b_transposed else (K, N)
    A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
    if b_preshuffle:
        B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
                          pack_size_k) if b_transposed else (block_K // pack_size_k,
                                                             block_N // micro_size_y, pack_size_k,
                                                             micro_size_y)
    else:
        B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 64
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size
    local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMA Wrapper to Auto Generate Code for MMA
    mfma_emitter = MatrixCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=a_transposed,
        b_transposed=b_transposed,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
        k_pack=k_pack,
        b_preshuffle=b_preshuffle,
    )

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
            })

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=0):

                # Load A into shared memory
                if a_transposed:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, ko * block_K], A_shared)

                # Load B into shared memory
                if b_preshuffle:
                    if b_transposed:
                        for j, k, jj, kk in T.Parallel(block_N // micro_size_y,
                                                       block_K // pack_size_k, micro_size_y,
                                                       pack_size_k):
                            B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j,
                                                       ko * block_K // pack_size_k + k, jj, kk]
                    else:
                        for k, j, kk, jj in T.Parallel(block_K // pack_size_k,
                                                       block_N // micro_size_y, pack_size_k,
                                                       micro_size_y):
                            B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
                                                       bx * block_N // micro_size_y + j, kk, jj]
                else:
                    if b_transposed:
                        T.copy(B[bx * block_N, ko * block_K], B_shared)
                    else:
                        T.copy(B[ko * block_K, bx * block_N], B_shared)

                for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):

                    # Load A into fragment
                    mfma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mfma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local)

            # Perform STMatrix
            if cache_write_shared:
                mfma_emitter.stmatrix(
                    C_local,
                    C_shared,
                )

                # Store shared into global
                for i, j in T.Parallel(block_M, block_N):
                    C[by * block_M + i, bx * block_N + j] = C_shared[
                        i // micro_size_x,
                        j // micro_size_y,
                        i % micro_size_x,
                        j % micro_size_y,
                    ]
            else:
                mfma_emitter.stmatrix(
                    C_local,
                    C,
                    pid_m=by,
                    pid_n=bx,
                )

    return main


def shuffle_weight(
        x: torch.Tensor,
        layout=(16, 32),
        k_pack=1,
        is_transpose=False,
) -> torch.Tensor:
    IN, IK = layout
Comment on lines +207 to +213
Contributor

medium

The shuffle_weight function uses a hardcoded default layout=(16, 32). These values are tightly coupled with micro_size_y and micro_size_k from the tl_matmul kernel definition. This makes the function less flexible and could lead to errors if the kernel parameters change.

To improve modularity and reduce this coupling, consider passing micro_size_n and micro_size_k directly to the function instead of the layout tuple.

You would then update the call site in assert_tl_matmul_correctness like this:

if b_preshuffle:
    micro_size_k = 32 if in_dtype == "int8" else 16
    micro_size_y = 16
    B_preshuffle = shuffle_weight(
        B_preshuffle, 
        micro_size_n=micro_size_y, 
        micro_size_k=micro_size_k, 
        k_pack=k_pack, 
        is_transpose=b_transposed
    )
    kernel(A, B_preshuffle, C)
Suggested change
def shuffle_weight(
        x: torch.Tensor,
        layout=(16, 32),
        k_pack=1,
        is_transpose=False,
) -> torch.Tensor:
    IN, IK = layout
def shuffle_weight(
        x: torch.Tensor,
        micro_size_n: int,
        micro_size_k: int,
        k_pack: int = 1,
        is_transpose: bool = False,
) -> torch.Tensor:
    IN, IK = micro_size_n, micro_size_k
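
An editorial sketch, not part of the PR or of this review: instead of threading the micro sizes through every call site, the layout tuple could be derived from the input dtype the same way tl_matmul picks micro_size_k. The helper name layout_for_dtype below is hypothetical.

def layout_for_dtype(in_dtype: str) -> tuple:
    # Mirrors tl_matmul above: micro_size_y is always 16, while
    # micro_size_k is 32 for int8/float8 inputs and 16 otherwise.
    micro_size_y = 16
    micro_size_k = 32 if in_dtype in ("float8_e4m3fnuz", "int8") else 16
    return (micro_size_y, micro_size_k)

# Call-site sketch (k_pack is still passed separately, as in the current test):
# B_preshuffle = shuffle_weight(B, layout=layout_for_dtype(in_dtype),
#                               k_pack=k_pack, is_transpose=b_transposed)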

    BK = IK * k_pack
    BN = IN

    N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
    assert N % BN == 0
    assert K % BK == 0

    x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
    x = x.permute(0, 2, 1, 3)
    return x.contiguous()
Comment on lines +217 to +223
Contributor

🛠️ Refactor suggestion

Add validation for tensor dimensions in shuffle_weight.

The function assumes 2D input tensors but doesn't validate this assumption. Consider adding a check to ensure the input tensor has exactly 2 dimensions.

Apply this diff to add dimension validation:

 def shuffle_weight(
         x: torch.Tensor,
         layout=(16, 32),
         k_pack=1,
         is_transpose=False,
 ) -> torch.Tensor:
+    if x.ndim != 2:
+        raise ValueError(f"Expected 2D tensor, got {x.ndim}D tensor")
     IN, IK = layout
     BK = IK * k_pack
     BN = IN
📝 Committable suggestion

Suggested change
    N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
    assert N % BN == 0
    assert K % BK == 0
    x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
    x = x.permute(0, 2, 1, 3)
    return x.contiguous()
def shuffle_weight(
        x: torch.Tensor,
        layout=(16, 32),
        k_pack=1,
        is_transpose=False,
) -> torch.Tensor:
    if x.ndim != 2:
        raise ValueError(f"Expected 2D tensor, got {x.ndim}D tensor")
    IN, IK = layout
    BK = IK * k_pack
    BN = IN
    N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
    assert N % BN == 0
    assert K % BK == 0
    x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
    x = x.permute(0, 2, 1, 3)
    return x.contiguous()
🤖 Prompt for AI Agents
In testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py around lines 217 to
223, the shuffle_weight code assumes a 2D input but does not validate that; add
an explicit check at the start of this block to ensure x.dim() == 2 (or x.ndim
== 2) and raise a clear ValueError if not (include actual ndim in the message),
then proceed with the existing asserts and reshaping — this prevents confusing
errors later when non-2D tensors are passed.
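
For orientation, a minimal usage sketch (editorial, not part of the PR; the sizes are illustrative) of how the shuffle_weight above repacks a transposed int8 weight into the 4-D block layout the preshuffled kernel consumes:

import torch

# Assumes the shuffle_weight defined in the diff above is in scope.
N, K = 128, 256  # illustrative sizes, divisible by the (16, 32) layout
B = torch.randint(-128, 127, (N, K), dtype=torch.int8)  # b_transposed weight of shape (N, K)
B_shuffled = shuffle_weight(B, layout=(16, 32), k_pack=1, is_transpose=True)
print(B_shuffled.shape)  # torch.Size([8, 8, 16, 32]) == (N // 16, K // 32, 16, 32)
# This matches B_shape in tl_matmul when b_preshuffle=True and b_transposed=True.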



def assert_tl_matmul_correctness(M,
                                 N,
                                 K,
                                 in_dtype,
                                 out_dtype,
                                 accum_dtype="float32",
                                 a_transposed=False,
                                 b_transposed=True,
                                 k_pack=1,
                                 b_preshuffle=False):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
                       k_pack, b_preshuffle)
    print(matmul)
    kernel = tilelang.compile(matmul)
    src_code = kernel.get_kernel_source()
    # src_code is the generated cuda source
    assert src_code is not None
    A_shape = (K, M) if a_transposed else (M, K)
    B_shape = (N, K) if b_transposed else (K, N)
    if in_dtype == "int8":
        A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
        B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
    else:
        A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
        B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))

    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    B_preshuffle = B
    if b_preshuffle:
        B_preshuffle = shuffle_weight(B_preshuffle, k_pack=k_pack, is_transpose=b_transposed)
        kernel(A, B_preshuffle, C)
    else:
        kernel(A, B, C)

    print(kernel.get_kernel_source())

    profiler = kernel.get_profiler()

    latency = profiler.do_bench()

    # Ensure that the latency is not None
    assert latency is not None

    if a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.Tto(torch.float32),
Comment on line +276
Contributor

⚠️ Potential issue

Fix syntax error: missing dot operator.

There's a syntax error on line 276 where the dot operator is missing between A.T and to.

Apply this diff to fix the syntax error:

-        ref_c = torch.matmul(A.Tto(torch.float32),
+        ref_c = torch.matmul(A.T.to(torch.float32),
📝 Committable suggestion

Suggested change
        ref_c = torch.matmul(A.Tto(torch.float32),
        ref_c = torch.matmul(A.T.to(torch.float32),
🤖 Prompt for AI Agents
In testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py around line 276,
there's a syntax error where the dot operator is missing between A.T and to;
replace the incorrect call `A.Tto(torch.float32,` with the correct chained
attribute call `A.T.to(torch.float32` so the transpose is followed by
`.to(...)`.

                             B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
Comment on lines +270 to +284
Contributor

critical

There's a typo on line 276: A.Tto(torch.float32) should be A.T.to(torch.float32).

Although this code path is not exercised by the current tests (no test case sets a_transposed=True), this is a critical bug that should be fixed.

Additionally, this entire conditional block can be simplified to improve readability and maintainability, which would also help prevent such typos.

    A_ref = A.T if a_transposed else A
    B_ref = B.T if b_transposed else B
    ref_c = torch.matmul(A_ref.to(torch.float32), B_ref.to(torch.float32)).to(getattr(torch, out_dtype))
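
As an editorial aside, a call of the following shape would exercise the fixed a_transposed path, assuming the MFMA emitter supports a transposed A input (the current suite does not verify this):

assert_tl_matmul_correctness(
    128, 128, 128, "int8", "int32", accum_dtype="int32", a_transposed=True)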


    print(C)
    print(ref_c)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)

    assert_tl_matmul_correctness(
        128, 128, 128, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)

    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
    assert_tl_matmul_correctness(
        128,
        256,
        256,
        "int8",
        "int32",
        b_transposed=False,
        accum_dtype="int32",
        k_pack=2,
        b_preshuffle=True)


if __name__ == "__main__":
    tilelang.testing.main()