Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 10 additions & 49 deletions python/tutorials/09-persistent-matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ def _matmul_launch_metadata(grid, kernel, args):
ret = {}
M, N, K = args["M"], args["N"], args["K"]
ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
if "tiles_per_update" in args:
ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}, tiles_per_update={args['tiles_per_update']:02}]"
if "c_ptr" in args:
bytes_per_elem = args["c_ptr"].element_size()
else:
Expand Down Expand Up @@ -376,8 +374,7 @@ def matmul_tma_persistent(a, b):


@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, #
a_ptr, b_ptr, c_ptr, #
def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, #
M, N, K, #
BLOCK_SIZE_M: tl.constexpr, #
BLOCK_SIZE_N: tl.constexpr, #
Expand Down Expand Up @@ -417,7 +414,6 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, #

tile_id = start_pid - NUM_SMS
ki = -1
ni = -1

pid_m = 0
pid_n = 0
Expand All @@ -427,36 +423,10 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, #
num_pid_in_group = GROUP_SIZE_M * num_pid_n

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# Create an opaque value to prevent the descriptor creation from being
# hoisted out of the loop
zero = tl.inline_asm_elementwise("mov.b32 $0, 0;", "=r", [], dtype=tl.int32, is_pure=True, pack=1)

for _ in range(0, k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
if ki == 0:
ni += 1

# Simulate a grouped gemm
if ni == tiles_per_update:
a_desc = tl._experimental_make_tensor_descriptor(
a_ptr + zero,
shape=[M, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
)
b_desc = tl._experimental_make_tensor_descriptor(
b_ptr + zero,
shape=[N, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
)
c_desc = tl._experimental_make_tensor_descriptor(
c_ptr + zero,
shape=[M, N],
strides=[N, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
)
ni = 0

tile_id += NUM_SMS
group_id = tile_id // num_pid_in_group
Expand All @@ -482,8 +452,7 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, #
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)


def matmul_descriptor_persistent(a, b, tiles_per_update):
# Autotuner does not work with TMA. Use manual config.
def matmul_descriptor_persistent(a, b):
configs = {
torch.float8_e4m3fn: {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4,
Expand Down Expand Up @@ -513,7 +482,6 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):

grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
matmul_kernel_descriptor_persistent[grid](
tiles_per_update, #
a, b, c, #
M, N, K, #
BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], #
Expand Down Expand Up @@ -570,7 +538,7 @@ def bench_fn(reps, warmup_reps, fn, *args):
fn(*args)


def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000):
def bench(K, dtype, reps=1000, warmup_reps=10000):
M = 8192
N = 8192
a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
Expand All @@ -586,10 +554,10 @@ def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000):
bench_fn(reps, warmup_reps, matmul_persistent, a, b.T)
if supports_tma():
bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b)
bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b, tiles_per_update)
bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b)


def validate(M, N, K, dtype, tiles_per_update):
def validate(M, N, K, dtype):
a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
b = b.T.contiguous()
Expand All @@ -599,7 +567,7 @@ def validate(M, N, K, dtype, tiles_per_update):
naive_result = matmul(a, b.T)
persistent_result = matmul_persistent(a, b.T)
tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None
descriptor_persistent_result = matmul_descriptor_persistent(a, b, tiles_per_update) if supports_tma() else None
descriptor_persistent_result = matmul_descriptor_persistent(a, b) if supports_tma() else None

if torch_result is not None:
naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16),
Expand All @@ -624,7 +592,7 @@ def validate(M, N, K, dtype, tiles_per_update):
if tma_persistent_result is not None:
print(f"TMA persistent: {naive_vs_tma_persistent} ", end="")
if descriptor_persistent_result is not None:
print(f"Device TMA persistent: {naive_vs_descriptor_persistent} ", end="")
print(f"Tensor descriptor persistent: {naive_vs_descriptor_persistent} ", end="")
print()


Expand All @@ -644,13 +612,6 @@ def show_profile(precision, profile_name):
parser.add_argument("-K", type=int, required=False, default=512)
parser.add_argument("--K_range", type=int, nargs=2)
parser.add_argument("--K_step", type=int, default=512)
parser.add_argument(
"--tiles_per_update",
type=int,
default=1,
help=
"Number of output tiles calculated for each update of the tma descriptor in matmul_descriptor_persistent_kernel",
)
parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16")
args = parser.parse_args()

Expand All @@ -666,11 +627,11 @@ def show_profile(precision, profile_name):

torch.manual_seed(0)

validate(32, 32, 32, dtype, args.tiles_per_update)
validate(8192, 8192, 512, dtype, args.tiles_per_update)
validate(32, 32, 32, dtype)
validate(8192, 8192, 512, dtype)

proton.start("matmul", hook="triton")
for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
bench(K, dtype, args.tiles_per_update)
bench(K, dtype)
proton.finalize()
show_profile(args.prec, "matmul")