diff --git a/python/sglang/srt/layers/attention/wave_ops/decode_attention.py b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py index 9e1c277378fa..6273e55ed743 100644 --- a/python/sglang/srt/layers/attention/wave_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py @@ -43,6 +43,8 @@ is_hip_ = is_hip() logger = logging.getLogger(__name__) +import os +dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0)) # TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy. logger.warning( @@ -709,6 +711,11 @@ def decode_attention_wave( attn_logits, attn_logits_max, ) + if dump_generated_mlir: + filename = f"wave_decode_attention_phase0_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + f.write(mb_qk.module_op.get_asm()) + with tk.gen.TestLaunchContext( hyperparams_1, @@ -720,6 +727,11 @@ def decode_attention_wave( use_scheduling_barriers=False, ): mb_sv = phase_1(attn_logits, attn_logits_max, o) + if dump_generated_mlir: + filename = f"wave_decode_attention_phase1_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + f.write(mb_sv.module_op.get_asm()) + def decode_attention_fwd( diff --git a/python/sglang/srt/layers/attention/wave_ops/prefill_attention.py b/python/sglang/srt/layers/attention/wave_ops/prefill_attention.py new file mode 100644 index 000000000000..2c669da67d0e --- /dev/null +++ b/python/sglang/srt/layers/attention/wave_ops/prefill_attention.py @@ -0,0 +1,101 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Memory-efficient attention for prefill. +It supporst page size = 1. +""" + +# Adapted from +# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1 + +import torch +import math + +import iree.turbine.kernel as tk +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.utils import ( + get_default_run_config, + get_default_scheduling_params, +) +from iree.turbine.kernel.wave.constraints import MMAType + +from iree.turbine.kernel.wave.templates.attention_common import AttentionShape +from iree.turbine.kernel.wave.templates.prefill_attention import ( + get_prefill_attention_kernel, +) + +import os +dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0)) + +def prefill_attention_wave( + q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=True +): + + # if not is_causal: + # raise NotImplementedError("non causal mask not supported yet on prefill_attention wave backend.") + shape = AttentionShape( + num_query_heads=q.shape[1], + num_kv_heads=k.shape[1], + head_size=q.shape[2], + head_size_kv=k.shape[2], + num_seqs=b_seq_len.shape[0], + max_seq_len=max_seq_len, + total_seq_len=q.shape[0], + ) + + assert shape.num_query_heads % shape.num_kv_heads == 0 + + output_shape = (shape.total_seq_len, shape.num_query_heads, shape.head_size_kv) + permuted_value = v.permute(1, 2, 0) + # Run the wave kernel. + mfma_variant =(MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16) + (prefill, hyperparams) = get_prefill_attention_kernel( + shape, + mfma_variant, + q.shape, + k.shape, + permuted_value.shape, + output_shape, + ) + + hyperparams.update(get_default_scheduling_params()) + config = get_default_run_config() + + log2e = 1.44269504089 + dk_sqrt = math.sqrt(1.0 / shape.head_size) + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=False, + run_config=config, + schedule=False, + use_scheduling_barriers=False, + ): + # TODO: Add scaling of QK as part of kernel. + # TODO: Add variant of non-transposed V attention kernel. + mb = prefill( + q * dk_sqrt * log2e, + k, + permuted_value, + b_start_loc.to(torch.int32), + b_seq_len.to(torch.int32), + o, + ) + if dump_generated_mlir: + shape_list = [q.shape[0], q.shape[1], k.shape[1], q.shape[2], k.shape[2]] + filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir" + with open(filename, "w") as f: + f.write(mb.module_op.get_asm()) diff --git a/test/srt/test_wave_attention_kernels.py b/test/srt/test_wave_attention_kernels.py index 9ddc582ead95..97f71f250936 100644 --- a/test/srt/test_wave_attention_kernels.py +++ b/test/srt/test_wave_attention_kernels.py @@ -11,6 +11,14 @@ ) +from sglang.srt.layers.attention.wave_ops.prefill_attention import ( + prefill_attention_wave, +) +from sglang.srt.layers.attention.triton_ops.prefill_attention import ( + context_attention_fwd, +) + + class TestWaveAttention(unittest.TestCase): def _set_all_seeds(self, seed): @@ -42,7 +50,7 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): # o will have the same shape as q o_triton = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") - o = torch.zeros(B, H_Q, D_V, dtype=torch.float32, device="cuda") + o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") req_to_token = torch.arange( total_tokens, device="cuda", dtype=torch.int32 @@ -102,14 +110,14 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): ) print(cos_sim.item()) self.assertTrue(cos_sim.item() > 0.99) - self.assertTrue(torch.allclose(o, o_triton.to(torch.float32), atol=3e-2)) + self.assertTrue(torch.allclose(o, o_triton, atol=3e-2)) def test_grouped_decode_attention(self): # seq_lens = [5, 100, 128, 500] - seq_lens = [100, 128, 500] + seq_lens = [128,] configs = [ # (2, 16, 16, 64, 64), - (2, 16, 1, 64, 64), + # (2, 16, 1, 64, 64), uncomment this # (2, 64, 1, 13, 13), (2, 128, 1, 80, 80), # (2, 128, 2, 512, 512), @@ -120,6 +128,47 @@ def test_grouped_decode_attention(self): for B, H_Q, H_KV, D, D_V in configs: self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V) + def _test_context_attention_once(self, head_dim, is_causal): + # Set up a simple test case + dtype = torch.float16 + num_heads = 4 + seq_lens = [64, 128] + max_seq_len = max(seq_lens) + + # Create random input tensors + q = torch.randn(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda") + k = torch.randn(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda") + v = torch.randn(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda") + o_triton = torch.zeros(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda") + o = torch.zeros(sum(seq_lens), num_heads, head_dim, dtype=torch.float32, device="cuda") + + # Create b_start_loc and b_seq_len tensors + b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda") + b_seq_len = torch.tensor(seq_lens, device="cuda") + + context_attention_fwd( + q, k, v, o_triton, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal + ) + prefill_attention_wave(q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal) + cos_sim = torch.nn.functional.cosine_similarity( + o.flatten(), o_triton.to(torch.float32).flatten(), dim=0 + ) + + print(cos_sim.item()) + self.assertTrue(torch.allclose(o, o_triton.to(torch.float32), atol=3e-2)) + self.assertTrue(cos_sim.item() > 1 - (1e-5)) + + def test_context_attention(self): + # head_dim = [128, 96, 80, 13] + # for is_causal in [False, True]: + + head_dim = [128] + + for dim in head_dim: + for is_causal in [False]: + self._test_context_attention_once(dim, is_causal) + + if __name__ == "__main__": unittest.main()