diff --git a/python/sglang/srt/layers/attention/wave_backend.py b/python/sglang/srt/layers/attention/wave_backend.py index 1d910082b5fe..18c193ade493 100644 --- a/python/sglang/srt/layers/attention/wave_backend.py +++ b/python/sglang/srt/layers/attention/wave_backend.py @@ -93,7 +93,7 @@ def __init__( ): # Lazy import to avoid the initialization of cuda context # TODO: Switch to wave decode. - from sglang.srt.layers.attention.triton_ops.decode_attention import ( + from sglang.srt.layers.attention.wave_ops.decode_attention import ( decode_attention_fwd, ) from sglang.srt.layers.attention.wave_ops.extend_attention import ( @@ -210,12 +210,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): bs = kv_indptr.shape[0] - 1 attn_logits = torch.empty( - (bs, self.num_head, self.max_kv_splits, self.v_head_dim), + (self.max_kv_splits, bs, self.v_head_dim, self.num_head), dtype=torch.float32, device=self.device, ) attn_lse = torch.empty( - (bs, self.num_head, self.max_kv_splits), + (self.max_kv_splits, bs, self.num_head), dtype=torch.float32, device=self.device, ) diff --git a/python/sglang/srt/layers/attention/wave_ops/decode_attention.py b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py index 4f15a23426f9..1e250341bbc9 100644 --- a/python/sglang/srt/layers/attention/wave_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py @@ -657,10 +657,11 @@ def decode_attention_wave( sm_scale, logit_cap=0.0, ): - mha = (q.shape[1] // v_buffer.shape[2]) == 1 + mha = (q.shape[1] // v_buffer.shape[1]) == 1 num_seqs, num_query_heads, head_size = q.shape - _, _, num_kv_heads, _ = k_buffer.shape - _, _, _, head_size_kv = v_buffer.shape + total_tokens, num_kv_heads, _ = k_buffer.shape + _, _, head_size_kv = v_buffer.shape + seq_len = total_tokens // num_seqs block_size = 32 shape = paged_decode_attention_shape( num_query_heads, @@ -669,9 +670,12 @@ def decode_attention_wave( head_size_kv, block_size, num_seqs, - k_buffer.shape[1], + seq_len, ) + k_buffer = k_buffer.view(num_seqs, seq_len, num_kv_heads, head_size) + v_buffer = v_buffer.view(num_seqs, seq_len, num_kv_heads, head_size_kv) + # Get the kernels (either compile or load from cache). if mha: mfma_variant = ( @@ -754,37 +758,17 @@ def decode_attention_fwd( logit_cap=0.0, ): assert max_kv_splits == attn_logits.shape[2] - kv_group_num = q.shape[1] // v_buffer.shape[2] - - if kv_group_num == 1: - # MHA - decode_attention_wave( - q, - k_buffer, - v_buffer, - o, - req_to_token, - b_req_idx, - attn_logits, - attn_logits_max, - num_kv_splits, - max_kv_splits, - sm_scale, - logit_cap, - ) - else: - # GQA/MQA/MLA - decode_attention_wave( - q, - k_buffer, - v_buffer, - o, - req_to_token, - b_req_idx, - attn_logits, - attn_logits_max, - num_kv_splits, - max_kv_splits, - sm_scale, - logit_cap, - ) + decode_attention_wave( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_req_idx, + attn_logits, + attn_logits_max, + num_kv_splits, + max_kv_splits, + sm_scale, + logit_cap, + ) diff --git a/test/srt/test_wave_attention_kernels.py b/test/srt/test_wave_attention_kernels.py index b4549d3b3c9d..b4810985d7a9 100644 --- a/test/srt/test_wave_attention_kernels.py +++ b/test/srt/test_wave_attention_kernels.py @@ -218,8 +218,6 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): sm_scale, ) - k_buffer = k_buffer.view(B, seq_len, H_KV, D) - v_buffer = v_buffer.view(B, seq_len, H_KV, D_V) attn_logits = torch.empty( (max_kv_splits, B, D_V, H_Q), dtype=torch.float32,