examples/autodd/tilelang_buggy.py (9 changes: 3 additions & 6 deletions)

@@ -73,14 +73,10 @@ def get_grid_size(self):
         return grid_x, grid_y
 
     def get_shared_memory_size(self):
-        return get_memory_requirements(
-            self.M, self.N, self.K, self.block_M, self.block_N, self.block_K
-        )
+        return get_memory_requirements(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K)
 
     def validate(self):
-        return validate_parameters(
-            self.M, self.N, self.K, self.block_M, self.block_N, self.block_K
-        )
+        return validate_parameters(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K)
 
 
 def create_reference_output(a, b, activation="relu"):
@@ -107,6 +103,7 @@ def benchmark_pytorch(M, N, K, num_iters=10, warmup=5):
 
     # Benchmark
     import time
+
     start = time.time()
     for _ in range(num_iters):
         _ = a @ b
examples/autodd/tilelang_minimized_expected.py (4 changes: 1 addition & 3 deletions)

@@ -13,7 +13,6 @@
 
 
 class MatmulConfig:
-
     def __init__(self, *args, **kwargs):
         self.M = 1
         self.N = 1
@@ -24,7 +23,6 @@ def __init__(self, *args, **kwargs):
 
 
 def buggy_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, *args, **kwargs):
-
     @T.prim_func
     def matmul_kernel():
         with T.Kernel():
@@ -45,7 +43,7 @@ def main(*args, **kwargs):
     try:
         run_kernel(config)
     except Exception as e:
-        print(f'{e}')
+        print(f"{e}")
 
 
 main()
examples/flash_attention/example_gqa_bwd_tma_reduce.py (11 changes: 1 addition & 10 deletions)

@@ -209,14 +209,6 @@ def flash_bwd(
             dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
             dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
 
-            T.annotate_layout(
-                {
-                    dQ: make_dq_layout(dQ),
-                    dK: make_dq_layout(dK),
-                    dV: make_dq_layout(dV),
-                }
-            )
-
             T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
             T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
             T.clear(dv)
@@ -387,7 +379,6 @@ def maybe_contiguous(x):
         block_M = 128
         block_N = 32
         mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
-        mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V)
         delta = mod_prep(o, do)
 
         if ctx.use_atomic:
@@ -401,11 +392,11 @@ def maybe_contiguous(x):
             dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
             dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
             kernel(q, k, v, do, lse, delta, dq, dk, dv)
-            dq, dk, dv = mod_post(dq, dk, dv)
         else:
             kernel = flashattn_bwd_split_novarlen(
                 BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
             )
+            mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V)
             shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
             shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
             shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
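The non-varlen diff above moves the postprocess step entirely onto the split (non-atomic) path: flashattn_bwd_postprocess is now constructed only inside the else branch, and the atomic path returns its fp32 accumulators without a mod_post pass or dQ/dK/dV layout annotation. The toy sketch below mirrors only that branch structure; make_postprocess, toy_backward, and all shapes are hypothetical stand-ins for the example's TileLang kernels, not code from this PR.

import torch


def make_postprocess():
    # Stand-in for flashattn_bwd_postprocess(...): cast the fp32
    # accumulators back to the kernel's output dtype.
    def mod_post(dq, dk, dv):
        return dq.half(), dk.half(), dv.half()

    return mod_post


def toy_backward(use_atomic: bool, groups: int = 2, shape=(8, 16)):
    if use_atomic:
        # Atomic path: the kernel accumulates full fp32 gradients in place,
        # so there is no postprocess pass (mirrors the deleted mod_post call).
        dq = torch.zeros(shape, dtype=torch.float32)
        dk = torch.zeros(shape, dtype=torch.float32)
        dv = torch.zeros(shape, dtype=torch.float32)
        # kernel(q, k, v, do, lse, delta, dq, dk, dv) would run here.
        return dq, dk, dv
    # Split path: the postprocess module is built only here, where it is
    # actually used, and per-group partial dk/dv are summed after the kernel.
    mod_post = make_postprocess()
    dq = torch.zeros(shape, dtype=torch.float32)
    dk = torch.zeros((groups, *shape), dtype=torch.float32).sum(dim=0)
    dv = torch.zeros((groups, *shape), dtype=torch.float32).sum(dim=0)
    return mod_post(dq, dk, dv)


print([t.dtype for t in toy_backward(use_atomic=True)])   # all float32, no postprocess
print([t.dtype for t in toy_backward(use_atomic=False)])  # float16 after mod_post

Building mod_post lazily also keeps the atomic path from compiling a postprocess kernel it never calls.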
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (11 changes: 1 addition & 10 deletions)

@@ -286,14 +286,6 @@ def flash_bwd(
             q_current_seqlen = q_end_idx - q_start_idx
             k_current_seqlen = k_end_idx - k_start_idx
 
-            T.annotate_layout(
-                {
-                    dQ: make_dq_layout(dQ),
-                    dK: make_dq_layout(dK),
-                    dV: make_dq_layout(dV),
-                }
-            )
-
             T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared)
             T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared)
 
@@ -541,7 +533,6 @@ def maybe_contiguous(x):
         block_M = 128
         block_N = 32
         mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, ctx.max_seqlen_q, D_HEAD_V)
-        mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V)
         delta = mod_prep(o, do, cu_seqlens_q)
 
         if ctx.use_atomic:
@@ -565,7 +556,6 @@ def maybe_contiguous(x):
             dk = torch.zeros_like(k, dtype=torch.float32)
             dv = torch.zeros_like(v, dtype=torch.float32)
             kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
-            dq, dk, dv = mod_post(dq, dk, dv)
         else:
             kernel = flashattn_bwd_split(
                 BATCH,
@@ -583,6 +573,7 @@ def maybe_contiguous(x):
                 num_stages=2,
                 groups=groups,
             )
+            mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V)
             dq = torch.zeros_like(q, dtype=torch.float32)
             dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device)
             dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device)
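In the varlen variant, the split path allocates dk and dv with a leading groups dimension (torch.empty(groups, *k.shape, ...)), and per the "# sum after kernel" comments they are reduced over that axis before mod_post runs. The snippet below only illustrates that reduction with made-up sizes; the actual sum and postprocess call sit outside the lines shown in this diff, so treat it as an assumption.

import torch

# Assumed sizes, for illustration only.
groups, total_kv, heads_kv, d_qk = 2, 16, 4, 64

# Each group writes its own fp16 partial dK into its slice of this buffer.
dk_parts = torch.randn(groups, total_kv, heads_kv, d_qk, dtype=torch.float16)

# "sum after kernel": reduce the per-group partials into the final dK
# before handing it to the postprocess kernel.
dk = dk_parts.sum(dim=0)
assert dk.shape == (total_kv, heads_kv, d_qk)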