diff --git a/python/sglang/srt/layers/attention/wave_backend.py b/python/sglang/srt/layers/attention/wave_backend.py index 18c193ade493..1d910082b5fe 100644 --- a/python/sglang/srt/layers/attention/wave_backend.py +++ b/python/sglang/srt/layers/attention/wave_backend.py @@ -93,7 +93,7 @@ def __init__( ): # Lazy import to avoid the initialization of cuda context # TODO: Switch to wave decode. - from sglang.srt.layers.attention.wave_ops.decode_attention import ( + from sglang.srt.layers.attention.triton_ops.decode_attention import ( decode_attention_fwd, ) from sglang.srt.layers.attention.wave_ops.extend_attention import ( @@ -210,12 +210,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): bs = kv_indptr.shape[0] - 1 attn_logits = torch.empty( - (self.max_kv_splits, bs, self.v_head_dim, self.num_head), + (bs, self.num_head, self.max_kv_splits, self.v_head_dim), dtype=torch.float32, device=self.device, ) attn_lse = torch.empty( - (self.max_kv_splits, bs, self.num_head), + (bs, self.num_head, self.max_kv_splits), dtype=torch.float32, device=self.device, ) diff --git a/python/sglang/srt/layers/attention/wave_ops/decode_attention.py b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py index 1e250341bbc9..4f15a23426f9 100644 --- a/python/sglang/srt/layers/attention/wave_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py @@ -657,11 +657,10 @@ def decode_attention_wave( sm_scale, logit_cap=0.0, ): - mha = (q.shape[1] // v_buffer.shape[1]) == 1 + mha = (q.shape[1] // v_buffer.shape[2]) == 1 num_seqs, num_query_heads, head_size = q.shape - total_tokens, num_kv_heads, _ = k_buffer.shape - _, _, head_size_kv = v_buffer.shape - seq_len = total_tokens // num_seqs + _, _, num_kv_heads, _ = k_buffer.shape + _, _, _, head_size_kv = v_buffer.shape block_size = 32 shape = paged_decode_attention_shape( num_query_heads, @@ -670,12 +669,9 @@ def decode_attention_wave( head_size_kv, block_size, num_seqs, - seq_len, + k_buffer.shape[1], ) - k_buffer = k_buffer.view(num_seqs, seq_len, num_kv_heads, head_size) - v_buffer = v_buffer.view(num_seqs, seq_len, num_kv_heads, head_size_kv) - # Get the kernels (either compile or load from cache). if mha: mfma_variant = ( @@ -758,17 +754,37 @@ def decode_attention_fwd( logit_cap=0.0, ): assert max_kv_splits == attn_logits.shape[2] - decode_attention_wave( - q, - k_buffer, - v_buffer, - o, - req_to_token, - b_req_idx, - attn_logits, - attn_logits_max, - num_kv_splits, - max_kv_splits, - sm_scale, - logit_cap, - ) + kv_group_num = q.shape[1] // v_buffer.shape[2] + + if kv_group_num == 1: + # MHA + decode_attention_wave( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_req_idx, + attn_logits, + attn_logits_max, + num_kv_splits, + max_kv_splits, + sm_scale, + logit_cap, + ) + else: + # GQA/MQA/MLA + decode_attention_wave( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_req_idx, + attn_logits, + attn_logits_max, + num_kv_splits, + max_kv_splits, + sm_scale, + logit_cap, + ) diff --git a/test/srt/test_wave_attention_kernels.py b/test/srt/test_wave_attention_kernels.py index b4810985d7a9..b4549d3b3c9d 100644 --- a/test/srt/test_wave_attention_kernels.py +++ b/test/srt/test_wave_attention_kernels.py @@ -218,6 +218,8 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): sm_scale, ) + k_buffer = k_buffer.view(B, seq_len, H_KV, D) + v_buffer = v_buffer.view(B, seq_len, H_KV, D_V) attn_logits = torch.empty( (max_kv_splits, B, D_V, H_Q), dtype=torch.float32,