From ea50be02ff59758c8f51de80c2cb41829af577b7 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 7 May 2025 21:44:41 +0200 Subject: [PATCH 1/2] align shapes Signed-off-by: Ivan Butygin --- python/sglang/srt/layers/attention/wave_backend.py | 6 +++--- .../srt/layers/attention/wave_ops/decode_attention.py | 8 ++++++-- test/srt/test_wave_attention_kernels.py | 2 -- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/layers/attention/wave_backend.py b/python/sglang/srt/layers/attention/wave_backend.py index 1d910082b5fe..18c193ade493 100644 --- a/python/sglang/srt/layers/attention/wave_backend.py +++ b/python/sglang/srt/layers/attention/wave_backend.py @@ -93,7 +93,7 @@ def __init__( ): # Lazy import to avoid the initialization of cuda context # TODO: Switch to wave decode. - from sglang.srt.layers.attention.triton_ops.decode_attention import ( + from sglang.srt.layers.attention.wave_ops.decode_attention import ( decode_attention_fwd, ) from sglang.srt.layers.attention.wave_ops.extend_attention import ( @@ -210,12 +210,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): bs = kv_indptr.shape[0] - 1 attn_logits = torch.empty( - (bs, self.num_head, self.max_kv_splits, self.v_head_dim), + (self.max_kv_splits, bs, self.v_head_dim, self.num_head), dtype=torch.float32, device=self.device, ) attn_lse = torch.empty( - (bs, self.num_head, self.max_kv_splits), + (self.max_kv_splits, bs, self.num_head), dtype=torch.float32, device=self.device, ) diff --git a/python/sglang/srt/layers/attention/wave_ops/decode_attention.py b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py index eb60e865e6e0..3889f3f465ff 100644 --- a/python/sglang/srt/layers/attention/wave_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py @@ -659,8 +659,9 @@ def decode_attention_wave( ): mha = (q.shape[1] // v_buffer.shape[2]) == 1 num_seqs, num_query_heads, head_size = q.shape - _, _, num_kv_heads, _ = k_buffer.shape - _, _, _, head_size_kv = v_buffer.shape + total_tokens, num_kv_heads, _ = k_buffer.shape + _, _, head_size_kv = v_buffer.shape + seq_len = total_tokens // num_seqs block_size = 32 shape = paged_decode_attention_shape( num_query_heads, @@ -672,6 +673,9 @@ def decode_attention_wave( k_buffer.shape[1], ) + k_buffer = k_buffer.view(num_seqs, seq_len, num_kv_heads, head_size) + v_buffer = v_buffer.view(num_seqs, seq_len, num_kv_heads, head_size_kv) + # Get the kernels (either compile or load from cache). if mha: mfma_variant = ( diff --git a/test/srt/test_wave_attention_kernels.py b/test/srt/test_wave_attention_kernels.py index b4549d3b3c9d..b4810985d7a9 100644 --- a/test/srt/test_wave_attention_kernels.py +++ b/test/srt/test_wave_attention_kernels.py @@ -218,8 +218,6 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): sm_scale, ) - k_buffer = k_buffer.view(B, seq_len, H_KV, D) - v_buffer = v_buffer.view(B, seq_len, H_KV, D_V) attn_logits = torch.empty( (max_kv_splits, B, D_V, H_Q), dtype=torch.float32, From 95643c919d00bcaa14df435dad1e5e2e3d2efaae Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 7 May 2025 21:55:11 +0200 Subject: [PATCH 2/2] fix Signed-off-by: Ivan Butygin --- .../attention/wave_ops/decode_attention.py | 52 ++++++------------- 1 file changed, 16 insertions(+), 36 deletions(-) diff --git a/python/sglang/srt/layers/attention/wave_ops/decode_attention.py b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py index 3889f3f465ff..f2f513cecc25 100644 --- a/python/sglang/srt/layers/attention/wave_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py @@ -657,7 +657,7 @@ def decode_attention_wave( sm_scale, logit_cap=0.0, ): - mha = (q.shape[1] // v_buffer.shape[2]) == 1 + mha = (q.shape[1] // v_buffer.shape[1]) == 1 num_seqs, num_query_heads, head_size = q.shape total_tokens, num_kv_heads, _ = k_buffer.shape _, _, head_size_kv = v_buffer.shape @@ -670,7 +670,7 @@ def decode_attention_wave( head_size_kv, block_size, num_seqs, - k_buffer.shape[1], + seq_len, ) k_buffer = k_buffer.view(num_seqs, seq_len, num_kv_heads, head_size) @@ -764,37 +764,17 @@ def decode_attention_fwd( logit_cap=0.0, ): assert max_kv_splits == attn_logits.shape[2] - kv_group_num = q.shape[1] // v_buffer.shape[2] - - if kv_group_num == 1: - # MHA - decode_attention_wave( - q, - k_buffer, - v_buffer, - o, - req_to_token, - b_req_idx, - attn_logits, - attn_logits_max, - num_kv_splits, - max_kv_splits, - sm_scale, - logit_cap, - ) - else: - # GQA/MQA/MLA - decode_attention_wave( - q, - k_buffer, - v_buffer, - o, - req_to_token, - b_req_idx, - attn_logits, - attn_logits_max, - num_kv_splits, - max_kv_splits, - sm_scale, - logit_cap, - ) + decode_attention_wave( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_req_idx, + attn_logits, + attn_logits_max, + num_kv_splits, + max_kv_splits, + sm_scale, + logit_cap, + )