164 changes: 32 additions & 132 deletions examples/flash_decoding/example_gqa_decode_varlen_logits.py
@@ -5,7 +5,6 @@
import argparse
import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune

torch.manual_seed(0)
tilelang.disable_cache()
@@ -198,7 +197,7 @@ def get_configs():
return configs


@autotune(configs=get_configs(), warmup=10, rep=10)
# @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding")
def flashattn(
batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128
@@ -438,127 +437,6 @@ def grid(META):
return O, S


def test_equal_seqlen_decode_main(args):
"""Test decode kernel with equal sequence lengths"""
print("Testing decode kernel with equal sequence lengths")

batch_size = args.batch_size
q_heads = args.q_heads
kv_heads = args.kv_heads
k_seqlen = args.k_seqlen
real_max_k_seqlen = args.k_seqlen
head_size = args.head_size
block_size = args.block_size
dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16

# For decode, query is just 1 token per batch
q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)

# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}")

# Convert to varlen format for K, V
k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size)
v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size)

# Generate cumulative sequence lengths
cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32)
max_seqlen_k = k_seqlen

print(f"q shape: {q.shape}")
print(f"k_varlen shape: {k_varlen.shape}")
print(f"v_varlen shape: {v_varlen.shape}")

num_tokens, q_h, head_size = q.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)

# Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode(
q,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
tl_kernel=tl_kernel,
)
for i in range(batch_size):
S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0

# Compute torch reference
q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size]
k_repeat = repeat_kv(k, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size]
v_repeat = repeat_kv(v, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size]

if sink is None:
# Standard scaled dot-product attention
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
attn_weights = torch.softmax(logits, dim=-1)
O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size]
else:
# s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]

sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(logits_max, sink_expanded)
sinks = torch.exp(sink_expanded - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
attn_weights = unnormalized_scores / normalizer
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size]

# Compute attention score pooling
attn_score_pooled = torch.max_pool2d(
attn_weights.squeeze(2), # [b, q_heads, k_seqlen]
kernel_size=(q_heads, block_size),
stride=(q_heads, block_size),
ceil_mode=True,
).to(torch.float16)

print("S_tilelang", S_tilelang)
print("attn_score_pooled", attn_score_pooled)

max_diff_o = torch.max(torch.abs(O_triton - O_torch))
max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch))
max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled))

print(f"Max difference in O: {max_diff_o.item()}")
print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
print("✅ All tests passed!")


def test_varlen_decode_main(args):
"""Test decode kernel with variable sequence lengths"""
batch_size = args.batch_size
@@ -742,16 +620,23 @@ def test_varlen_decode_main(args):
print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}")

max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
max_diff_s_tl = torch.max(
torch.abs(
S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled[:, :, : math.ceil(max_seqlen_k / block_size)]
)
)
print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}")

assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), (
f"Score mismatch: {max_diff_s_tl.item()}"
)
assert torch.allclose(
S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)],
attn_score_pooled[:, :, : math.ceil(max_seqlen_k / block_size)],
atol=1e-2,
rtol=1e-2,
), f"Score mismatch: {max_diff_s_tl.item()}"

print("✅ All tests passed!")

@@ -882,28 +767,43 @@ def speed_benchmark_decode_comparison(args):
print(f"Speedup: {(triton_time / tilelang_time):.3f}")


def main():
args = argparse.Namespace(
batch_size=1,
q_heads=32,
kv_heads=8,
k_seqlen=8192,
head_size=128,
block_size=128,
dtype=T.float16,
)
args.test_sink = True
args.test_varlen = True
args.dtype = T.float16
args.num_split = 1
test_varlen_decode_main(args)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads")
parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads")
parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
parser.add_argument("--block_size", type=int, default=64, help="Block size for computation")
parser.add_argument("--block_size", type=int, default=128, help="Block size for computation")
Review comment (Contributor):

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if block_size=64 is used elsewhere in the codebase
rg -nP --type=py 'block_size\s*=\s*64' examples/

🏁 Script executed:

#!/bin/bash
# Check the broader context of block_size usage
rg -nP --type=py 'block_size' examples/flash_decoding/ | head -20

🏁 Script executed:

#!/bin/bash
# Check the git history of this file to confirm the change from 64 to 128
git log -p --follow examples/flash_decoding/example_gqa_decode_varlen_logits.py | grep -A 5 -B 5 "block_size.*default" | head -30

Parameter default changed from 64 to 128, inconsistent with codebase patterns.

The block_size default has been changed from 64 to 128, creating an inconsistency with the rest of the codebase where block_size=64 is the standard default across multiple modules (deepseek_nsa, deepseek_mla, and even the related example_gqa_decode_varlen_logits_paged.py file). Document this deviation and ensure the change is intentional and doesn't break existing workflows.

🤖 Prompt for AI Agents
In examples/flash_decoding/example_gqa_decode_varlen_logits.py around line 794, the parser default for --block_size was changed from 64 to 128, which is inconsistent with the rest of the codebase. Either revert the default to 64 to match deepseek_nsa/deepseek_mla and example_gqa_decode_varlen_logits_paged.py, or, if 128 is intentional, add a clear inline comment, update documentation/examples to explain the deviation, and run a quick smoke test to ensure no workflows break.
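
If the 128 default is intentional, a minimal sketch of the inline documentation the comment asks for could look like the following. This is a hedged illustration, not part of the PR: the NOTE text and the block_N=128 rationale are assumptions drawn from the kernel's default, not something the author states.

import argparse

parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
# Sketch only, not in the PR: record why this example deviates from the block_size=64
# default used by the other decode examples (deepseek_nsa, deepseek_mla,
# example_gqa_decode_varlen_logits_paged.py).
parser.add_argument(
    "--block_size",
    type=int,
    default=128,  # assumption: 128 mirrors the kernel's block_N=128 default above
    help="Block size for attention-score pooling; related examples default to 64",
)
args = parser.parse_args()

A quick smoke run with both values (e.g. --block_size 64 and the new default 128) would also cover the request that existing workflows keep working.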

parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type")
parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits")
args = parser.parse_args()
args.test_sink = True
args.test_varlen = False
args.test_varlen = True
args.dtype = T.float16
args.num_split = 1

if args.benchmark:
speed_benchmark_decode_comparison(args)
elif args.test_varlen:
test_varlen_decode_main(args)
else:
test_equal_seqlen_decode_main(args)
test_varlen_decode_main(args)
154 changes: 19 additions & 135 deletions examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py
@@ -199,138 +199,6 @@ def flash_attn_with_attn_pool_decode_tilelang(
return O_tl, S_tl


def test_equal_seqlen_decode_main(args):
"""Test decode kernel with equal sequence lengths"""
print("Testing decode kernel with equal sequence lengths")

batch_size = args.batch_size
q_heads = args.q_heads
kv_heads = args.kv_heads
k_seqlen = args.k_seqlen
real_max_k_seqlen = args.k_seqlen
head_size = args.head_size
block_size = args.block_size
page_block_size = args.page_block_size
dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16

# For decode, query is just 1 token per batch
q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)

# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}")

# Convert to varlen format for K, V
k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous()
v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous()

# Generate cumulative sequence lengths
cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32)
max_seqlen_k = k_seqlen

print(f"q shape: {q.shape}")
print(f"k_varlen shape: {k_varlen.shape}")
print(f"v_varlen shape: {v_varlen.shape}")

num_tokens, q_h, head_size = q.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)

block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
block_cnt = 0
for i in range(batch):
cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
for j in range(math.ceil(cur_seqlen / page_block_size)):
block_table[i, j] = block_cnt
block_cnt += 1
block_cnt = 0

# Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode(
q,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
tl_kernel=tl_kernel,
block_table=block_table,
)
for i in range(batch_size):
S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0

# Compute torch reference
q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size]
k_repeat = repeat_kv(k, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size]
v_repeat = repeat_kv(v, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size]

if sink is None:
# Standard scaled dot-product attention
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
attn_weights = torch.softmax(logits, dim=-1)
O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size]
else:
# s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]

sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(logits_max, sink_expanded)
sinks = torch.exp(sink_expanded - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
attn_weights = unnormalized_scores / normalizer
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size]

# Compute attention score pooling
attn_score_pooled = torch.max_pool2d(
attn_weights.squeeze(2), # [b, q_heads, k_seqlen]
kernel_size=(q_heads, block_size),
stride=(q_heads, block_size),
ceil_mode=True,
).to(torch.float16)

print("S_tilelang", S_tilelang)
print("attn_score_pooled", attn_score_pooled)

max_diff_o = torch.max(torch.abs(O_triton - O_torch))
max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch))
max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled))

print(f"Max difference in O: {max_diff_o.item()}")
print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
print("✅ All tests passed!")


def test_varlen_decode_main(args):
"""Test decode kernel with variable sequence lengths"""
batch_size = args.batch_size
@@ -651,6 +519,24 @@ def speed_benchmark_decode_comparison(args):
print(f"Speedup: {(triton_time / tilelang_time):.3f}")


def main():
args = argparse.Namespace(
batch_size=1,
q_heads=32,
kv_heads=8,
k_seqlen=8192,
head_size=128,
block_size=128,
dtype=T.float16,
)
args.test_sink = True
args.test_varlen = True
args.dtype = T.float16
args.num_split = 1
args.page_block_size = 128
test_varlen_decode_main(args)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
@@ -673,7 +559,5 @@ def speed_benchmark_decode_comparison(args):

if args.benchmark:
speed_benchmark_decode_comparison(args)
elif args.test_varlen:
test_varlen_decode_main(args)
else:
test_equal_seqlen_decode_main(args)
test_varlen_decode_main(args)