Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions examples/gemm_fp8/example_tilelang_gemm_amd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import torch_assert_close
from tilelang.utils import determine_fp8_type, determine_torch_fp8_type
import itertools


Expand All @@ -17,8 +18,9 @@ def supply_prog(args):
a_param, b_param = args
M, K = a_param.shape
N, _ = b_param.shape
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
fp8_dtype = determine_torch_fp8_type()
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype)
return [a, b]


Expand Down Expand Up @@ -53,7 +55,7 @@ def get_configs():
)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
dtype = T.float8_e4m3fnuz
dtype = determine_fp8_type()
accum_dtype = T.float32

@T.prim_func
Expand Down Expand Up @@ -104,8 +106,9 @@ def gemm_fp8_ss(

def test_gemm_fp8(M, N, K):
kernel = fp8_matmul(M, N, K)
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
fp8_dtype = determine_torch_fp8_type()
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype)
c = kernel(a, b)
ref_c = ref_program(a, b)
torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
Expand Down
7 changes: 5 additions & 2 deletions examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tilelang.tileop.base import GemmWarpPolicy
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter
from tilelang.utils import determine_fp8_type

tilelang.testing.set_random_seed(0)

Expand Down Expand Up @@ -45,12 +46,14 @@ def tl_matmul(
num_stages,
k_pack=2,
num_threads=256,
in_dtype=T.float8_e4m3fnuz,
in_dtype=None,
out_dtype=T.float32,
accum_dtype=T.float32,
a_transposed=False,
b_transposed=True,
):
if in_dtype is None:
in_dtype = determine_fp8_type()
b_preshuffle = True
warp_size = 64
num_warps = num_threads // warp_size
Expand Down Expand Up @@ -164,7 +167,7 @@ def shuffle_weight(


def assert_tl_matmul_correctness(M, N, K, k_pack=1, a_transposed=False, b_transposed=True):
in_dtype = T.float8_e4m3fnuz
in_dtype = determine_fp8_type()
out_dtype = T.float32
accum_dtype = T.float32
kernel = tl_matmul(
Expand Down
22 changes: 13 additions & 9 deletions examples/gemm_fp8/example_tilelang_gemm_fp8.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import tilelang
import tilelang.language as T
from tilelang.utils import determine_fp8_type


def calc_diff(x, y):
Expand Down Expand Up @@ -55,21 +56,24 @@ def test_gemm_fp8(M, N, K, dtype):


def main():
test_gemm_fp8(1024, 1024, 1024, T.float8_e4m3fn)
test_gemm_fp8(1024, 1024, 1024, T.float8_e5m2)
test_gemm_fp8(1024, 1024, 1024, determine_fp8_type())
test_gemm_fp8(1024, 1024, 1024, determine_fp8_type("e5m2"))


def run_regression_perf():
M, N, K = 4096, 4096, 4096
dtype = "float8_e4m3"
dtype = determine_fp8_type()
kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype)
profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
dtype = "float8_e5m2"
kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype)
profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e5m2 = profiler_e5m2.do_bench(backend="cupti")
return (latency_e4m3 + latency_e5m2) / 2
if torch.version.hip is None:
latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
dtype = determine_fp8_type("e5m2")
kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype)
profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e5m2 = profiler_e5m2.do_bench(backend="cupti")
return (latency_e4m3 + latency_e5m2) / 2
latency_e4m3 = profiler_e4m3.do_bench()
return latency_e4m3


if __name__ == "__main__":
Expand Down
24 changes: 15 additions & 9 deletions examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import tilelang
import tilelang.language as T
from tilelang.utils import determine_fp8_type


@tilelang.jit(out_idx=[-1])
Expand Down Expand Up @@ -73,21 +74,26 @@ def test_gemm_fp8(M, N, K, dtype):


def main():
test_gemm_fp8(1024, 1024, 8192, T.float8_e4m3fn)
test_gemm_fp8(1024, 1024, 8192, T.float8_e5m2)
test_gemm_fp8(1024, 1024, 8192, determine_fp8_type())
test_gemm_fp8(1024, 1024, 8192, determine_fp8_type("e5m2"))


def run_regression_perf():
M, N, K = 1024, 1024, 8192
dtype = "float8_e4m3"
dtype = determine_fp8_type()
kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype)
profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
dtype = "float8_e5m2"
kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype)
profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e5m2 = profiler_e5m2.do_bench(backend="cupti")
return (latency_e4m3 + latency_e5m2) / 2
if torch.version.hip is None:
latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
else:
latency_e4m3 = profiler_e4m3.do_bench()
if torch.version.hip is None:
dtype = determine_fp8_type("e5m2")
kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype)
profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e5m2 = profiler_e5m2.do_bench(backend="cupti")
return (latency_e4m3 + latency_e5m2) / 2
return latency_e4m3


if __name__ == "__main__":
Expand Down
105 changes: 63 additions & 42 deletions examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,
)
from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter
from tilelang.intrinsics.mfma_macro_generator import MatrixCoreIntrinEmitter
from tilelang.utils.tensor import map_torch_type
from tilelang.utils import determine_fp8_type

tilelang.testing.set_random_seed(0)

Expand Down Expand Up @@ -39,26 +39,17 @@ def tl_matmul(
assert in_dtype in [
T.float16,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
T.int8,
], "Currently only float16 and int8 are supported"
], "Currently only float16, float8, and int8 are supported"
assert out_dtype in [
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"

micro_size_x = micro_size_y = micro_size_k = 16

is_float8 = in_dtype in [
T.float8_e4m3fn,
T.float8_e5m2,
T.float8_e4m3fn,
T.float8_e5m2fnuz,
]
if out_dtype == T.int32 or is_float8:
micro_size_k = 32

# This is a debug config
block_row_warps = 2
block_col_warps = 2
Expand All @@ -78,34 +69,51 @@ def tl_matmul(
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
is_hip = torch.version.hip is not None
# MMA Wrapper to Auto Generate Code for MMA/MFMA
if is_hip:
mma_emitter = MatrixCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
else:
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)

micro_size_x = mma_emitter.M_DIM
micro_size_y = getattr(mma_emitter, "n_dim", getattr(mma_emitter, "N_DIM", micro_size_x))
micro_size_k = mma_emitter.k_dim
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)

warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y

# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
threads = mma_emitter.threads
local_size_a = mma_emitter.local_size_a
local_size_b = mma_emitter.local_size_b
local_size_c = mma_emitter.local_size_out
warp_rows = mma_emitter.warp_rows
warp_cols = mma_emitter.warp_cols

@T.prim_func
def gemm_fp8_intrinsic(
Expand Down Expand Up @@ -158,7 +166,10 @@ def gemm_fp8_intrinsic(
)

# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
if is_hip:
mma_emitter.mfma(A_local, B_local, C_local, ki)
else:
mma_emitter.mma(A_local, B_local, C_local)

# Perform STMatrix
mma_emitter.stmatrix(
Expand Down Expand Up @@ -192,7 +203,12 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
if in_dtype in {torch.int8, torch.int32}:
A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
elif in_dtype in {
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
}:
A = torch.randn(M, K).to(in_dtype).cuda()
B = torch.randn(N, K).to(in_dtype).cuda()
else:
Expand All @@ -218,18 +234,23 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):


def main():
assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32)
assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32)
e4m3_dtype = determine_fp8_type()
assert_tl_matmul_correctness(128, 128, 128, e4m3_dtype, T.float32, T.float32)
e5m2_dtype = determine_fp8_type("e5m2")
assert_tl_matmul_correctness(128, 128, 128, e5m2_dtype, T.float32, T.float32)


def run_regression_perf():
M, N, K = 4096, 4096, 4096
out_dtype, accum_dtype = "float32", "float32"
in_dtype = T.float8_e4m3fn
in_dtype = determine_fp8_type()
kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
print(kernel_e4m3.get_kernel_source())
profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
if torch.version.hip is None:
latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
else:
latency_e4m3 = profiler_e4m3.do_bench()
return latency_e4m3


Expand Down
2 changes: 2 additions & 0 deletions src/target/codegen_hip.cc
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
{"float8_e4m3fnx8", "long"},
{"float8_e5m2fnuzx4", "fp8_e5_4_t"},
{"float8_e5m2fnuzx8", "long"},
{"float8_e5m2x4", "fp8_e5_4_t"},
{"float8_e5m2x8", "long"},
{"float32x16", "float32x16"}};
std::string call_mfma_code = R"({
*((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
Expand Down
Loading
Loading