@@ -31,6 +31,7 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
return dense_mask


@tilelang.jit(out_idx=[4])
def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
block_M = 64
block_N = 64
@@ -193,9 +194,8 @@ def test_topk_sparse_attention():
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

# Run Triton kernel
program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
kernel = tilelang.compile(program, out_idx=[4])
# Run tilelang kernel
kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)

tilelang_output = kernel(q, k, v, block_mask)

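Every hunk in this diff applies the same migration: the kernel factory gains a `@tilelang.jit(out_idx=...)` decorator, so calling it returns a compiled kernel directly, and the separate `program = ...; kernel = tilelang.compile(program, out_idx=...)` step disappears. A minimal sketch of the pattern with a quickstart-style matmul rather than any kernel from this PR (DSL names such as `T.Tensor`, `T.Pipelined`, and `T.gemm` follow current tilelang usage and may differ slightly across versions):

```python
import torch
import tilelang
import tilelang.language as T


# New style: decorate the factory; out_idx marks which argument the runtime
# allocates and returns (here the last one, C).
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


# Old style (what this PR removes):
#   program = matmul(1024, 1024, 1024)
#   kernel = tilelang.compile(program, out_idx=-1)
# New style: the decorated factory already returns a compiled kernel.
kernel = matmul(1024, 1024, 1024)
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
c = kernel(a, b)  # C is allocated and returned because of out_idx=[-1]
```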
@@ -19,6 +19,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
accum_dtype = "float"
kv_group_num = heads // heads_kv

@tilelang.jit(out_idx=[-1])
def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen,
max_selected_blocks):
shape_q = [batch, heads, dim]
@@ -203,7 +204,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):

self.block_H = 64

program = flashattn(batch, heads, heads_kv, dim, dim_v)(
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=self.block_H,
num_split=T.symbolic("num_split"),
@@ -212,9 +213,6 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
max_selected_blocks=T.symbolic("max_selected_blocks"))

self.kernel = tilelang.compile(
program, out_idx=-1, target='cuda', execution_backend="cython")

props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count

@@ -308,20 +306,28 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
is_causal_or_local=True,
max_splits=128)

program = flashattn(batch, heads, heads_kv, dim, dim_v)(
# program = flashattn(batch, heads, heads_kv, dim, dim_v)(
# block_N=block_size,
# block_H=block_H,
# num_split=T.symbolic("num_split"),
# num_stages=2,
# threads=128,
# max_cache_seqlen=T.symbolic("max_cache_seqlen"),
# max_selected_blocks=T.symbolic("max_selected_blocks"))

glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
# kernel = tilelang.compile(program, out_idx=-1, target='cuda', execution_backend="cython")
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=block_H,
num_split=T.symbolic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
max_selected_blocks=T.symbolic("max_selected_blocks"))

glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
kernel = tilelang.compile(program, out_idx=-1, target='cuda', execution_backend="cython")
# print(kernel.get_kernel_source())

# output = kernel(query, key, value, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial)
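A detail worth noting in the decode examples above: several factory arguments are passed as `T.symbolic(...)` (e.g. `num_split`, `max_cache_seqlen`), so a single compiled kernel covers any runtime value of those dimensions. A hedged sketch of the idea with a toy kernel (the names below are illustrative, not from this PR):

```python
import torch
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def row_scale(M, N, block_M=64):
    # M may be a concrete int or a symbolic var; symbolic dims stay runtime
    # parameters instead of being baked into the generated code.
    @T.prim_func
    def main(
            X: T.Tensor((M, N), "float16"),
            S: T.Tensor((M,), "float16"),
            Y: T.Tensor((M, N), "float16"),
    ):
        with T.Kernel(T.ceildiv(M, block_M), threads=128) as bx:
            for i, j in T.Parallel(block_M, N):
                Y[bx * block_M + i, j] = X[bx * block_M + i, j] * S[bx * block_M + i]

    return main


# Compile once with a symbolic row count, then reuse across batch sizes.
kernel = row_scale(T.symbolic("M"), 256)
for m in (64, 512):
    x = torch.randn(m, 256, device="cuda", dtype=torch.float16)
    s = torch.randn(m, device="cuda", dtype=torch.float16)
    y = kernel(x, s)
```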
@@ -20,6 +20,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
accum_dtype = "float"
kv_group_num = heads // heads_kv

@tilelang.jit(out_idx=[-1])
def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks):
shape_q = [batch, heads, dim]
shape_k = [batch, max_cache_seqlen, heads_kv, dim]
@@ -189,7 +190,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):

self.block_H = 64

program = flashattn(batch, heads, heads_kv, dim, dim_v)(
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=self.block_H,
num_split=T.symbolic("num_split"),
@@ -198,9 +199,6 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
num_blocks=T.symbolic("num_blocks"))

self.kernel = tilelang.compile(
program, out_idx=-1, target='cuda', execution_backend="cython")

props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count

@@ -281,7 +279,7 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
is_causal_or_local=True,
max_splits=128)

program = flashattn(batch, heads, heads_kv, dim, dim_v)(
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=block_H,
num_split=T.symbolic("num_split"),
@@ -293,7 +291,6 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
kernel = tilelang.compile(program, out_idx=-1, target='cuda', execution_backend="cython")
# print(kernel.get_kernel_source())

output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
5 changes: 2 additions & 3 deletions examples/blocksparse_gemm/example_blocksparse_gemm.py
@@ -141,7 +141,7 @@ def kernel(block_M=None,
# Run the tuning process
return autotuner.run(warmup=3, rep=20)


@tilelang.jit(out_idx=[-1])
def blocksparse_matmul(M,
N,
K,
@@ -211,10 +211,9 @@ def main():
print(f"Best Kernel Latency: {best_latency:.6f} ms")
print(f"Reference Latency: {ref_latency:.6f} ms")
else:
func = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K,
kernel = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K,
DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM,
DEFAULT_ENABLE_RASTERIZATION)
kernel = tilelang.compile(func, out_idx=-1)
block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")

10 changes: 2 additions & 8 deletions examples/cast/example_group_per_split_token_cast_to_fp8.py
@@ -11,7 +11,7 @@
dtype = "bfloat16"
accum_dtype = "float"


@tilelang.jit(out_idx=[2, 3])
def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
group_size = 128
fp8_min = -448.0
@@ -179,13 +179,7 @@ def main():
print("batch_sizes:", batch_sizes)
print("M_max:", M_max)

program = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m)
kernel = tilelang.compile(
program,
out_idx=[2, 3],
target="cuda",
execution_backend="cython",
pass_configs={"tl.disable_tma_lower": True})
kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m)
print(kernel.get_kernel_source())
# profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)

10 changes: 2 additions & 8 deletions examples/cast/example_per_token_cast_to_fp8.py
@@ -9,7 +9,7 @@

tilelang.disable_cache()


@tilelang.jit(out_idx=[1, 2])
def per_token_cast_to_fp8(M, N, blk_m):
dtype = "float"
group_size = 128
@@ -83,13 +83,7 @@ def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

def main():
M, N, blk_m = 8192, 8192, 8
program = per_token_cast_to_fp8(M, N, blk_m)
kernel = tilelang.compile(
program,
out_idx=[1, 2],
target="cuda",
execution_backend="cython",
pass_configs={"tl.disable_tma_lower": True})
kernel = per_token_cast_to_fp8(M, N, blk_m)
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)

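In the two cast examples the decorator carries more than one index (`out_idx=[2, 3]` and `out_idx=[1, 2]`), so the kernel call hands back a tuple of output tensors in that order. A small hedged sketch of the multi-output convention (toy kernel, not from this PR):

```python
import torch
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[1, 2])  # args 1 and 2 are allocated by the runtime and returned
def split_halves(M, N, block_M=64):

    @T.prim_func
    def main(
            X: T.Tensor((M, N), "float16"),
            Left: T.Tensor((M, N // 2), "float16"),
            Right: T.Tensor((M, N // 2), "float16"),
    ):
        with T.Kernel(T.ceildiv(M, block_M), threads=128) as bx:
            for i, j in T.Parallel(block_M, N // 2):
                Left[bx * block_M + i, j] = X[bx * block_M + i, j]
                Right[bx * block_M + i, j] = X[bx * block_M + i, j + N // 2]

    return main


kernel = split_halves(1024, 256)
x = torch.randn(1024, 256, device="cuda", dtype=torch.float16)
left, right = kernel(x)  # two tensors back, matching out_idx=[1, 2]
```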
5 changes: 2 additions & 3 deletions examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
@@ -11,7 +11,7 @@

tilelang.testing.set_random_seed(42)


@tilelang.jit(out_idx=[2])
def tl_gemm(
M,
N,
@@ -147,8 +147,7 @@ def calc_diff(x, y):


def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtype):
gemm = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype)
kernel = TL.compile(gemm, out_idx=[])
kernel = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype)
src_code = kernel.get_kernel_source()

# src_code is the generated cuda source
@@ -11,7 +11,7 @@

tilelang.disable_cache()


@tilelang.jit(out_idx=[6])
def flashmla_decode(batch,
heads,
kv_head_num,
@@ -290,9 +290,7 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
BLOCK_H = 64
num_split = 4

program = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H,
num_split)
kernel = tilelang.compile(program, out_idx=[6])
kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
input_tensors = profiler._get_inputs()
tilelang_output = kernel(*input_tensors)
3 changes: 1 addition & 2 deletions examples/deepseek_mla/benchmark_mla.py
@@ -436,9 +436,8 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size

out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
num_kv_splits, block_size)
kernel = tilelang.compile(program, out_idx=[8])

def flash_mla_tilelang():
out = kernel(
5 changes: 2 additions & 3 deletions examples/deepseek_mla/example_mla_decode.py
@@ -8,7 +8,7 @@
from einops import rearrange, einsum
import argparse


@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
@@ -289,8 +289,7 @@ def main():
BLOCK_H = 64
num_split = 1

program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
kernel = tilelang.compile(program, out_idx=[6])
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500)
5 changes: 2 additions & 3 deletions examples/deepseek_mla/example_mla_decode_paged.py
@@ -6,7 +6,7 @@
from tilelang.profiler import do_bench
import math


@tilelang.jit(out_idx=[8])
def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split,
block_size):
scale = (1.0 / (dv + dpe))**0.5 * 1.44269504 # log2(e)
@@ -323,9 +323,8 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s

out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
num_kv_splits, block_size)
kernel = tilelang.compile(program, out_idx=[8])
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)

def flash_mla_tilelang():
5 changes: 2 additions & 3 deletions examples/deepseek_mla/example_mla_decode_persistent.py
@@ -9,7 +9,7 @@
from einops import rearrange, einsum
import argparse


@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
@@ -209,8 +209,7 @@ def main():
BLOCK_H = 64
num_split = 2

program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
kernel = tilelang.compile(program, out_idx=[6])
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
@@ -8,7 +8,7 @@
from einops import rearrange, einsum
import argparse


@tilelang.jit(out_idx=[-1])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
@@ -148,9 +148,7 @@ def ref_program(q, q_pe, kv, k_pe):
BLOCK_N = 64
BLOCK_H = 64

program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H)
print(program)
kernel = tilelang.compile(program, out_idx=-1)
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms")
5 changes: 2 additions & 3 deletions examples/deepseek_nsa/example_tilelang_nsa_decode.py
@@ -9,7 +9,7 @@

tilelang.testing.set_random_seed(42)


@tilelang.jit(out_idx=[-1])
def native_sparse_attention(
batch,
heads,
@@ -132,7 +132,7 @@ def main():
B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16
groups = HQ // H
SEQ_LEN_Q = 1
program = native_sparse_attention(
kernel = native_sparse_attention(
batch=B,
heads=HQ,
seq_len=SEQ_LEN,
@@ -142,7 +142,6 @@ def main():
selected_blocks=S,
)

kernel = tilelang.compile(program, out_idx=-1)
Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
5 changes: 2 additions & 3 deletions examples/deepseek_nsa/example_tilelang_nsa_fwd.py
@@ -9,7 +9,7 @@

tilelang.testing.set_random_seed(0)


@tilelang.jit(out_idx=[-1])
def native_sparse_attention(batch,
heads,
seq_len,
@@ -130,7 +130,7 @@ def native_sparse_attention(
def main():
B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1

program = native_sparse_attention(
kernel = native_sparse_attention(
batch=B,
heads=HQ,
seq_len=SEQ_LEN,
@@ -141,7 +141,6 @@ def main():
selected_blocks=S,
scale=scale,
)
kernel = tilelang.compile(program, out_idx=-1)
print(kernel.get_kernel_source())
torch.random.manual_seed(0)
Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
Expand Down