Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/sglang/srt/layers/attention/wave_ops/decode_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
is_hip_ = is_hip()

logger = logging.getLogger(__name__)
import os
dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))

# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
logger.warning(
Expand Down Expand Up @@ -709,6 +711,11 @@ def decode_attention_wave(
attn_logits,
attn_logits_max,
)
if dump_generated_mlir:
filename = f"wave_decode_attention_phase0_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb_qk.module_op.get_asm())


with tk.gen.TestLaunchContext(
hyperparams_1,
Expand All @@ -720,6 +727,11 @@ def decode_attention_wave(
use_scheduling_barriers=False,
):
mb_sv = phase_1(attn_logits, attn_logits_max, o)
if dump_generated_mlir:
filename = f"wave_decode_attention_phase1_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb_sv.module_op.get_asm())



def decode_attention_fwd(
Expand Down
101 changes: 101 additions & 0 deletions python/sglang/srt/layers/attention/wave_ops/prefill_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for prefill.
It supports page size = 1.
"""

# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1

import torch
import math

import iree.turbine.kernel as tk
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.utils import (
get_default_run_config,
get_default_scheduling_params,
)
from iree.turbine.kernel.wave.constraints import MMAType

from iree.turbine.kernel.wave.templates.attention_common import AttentionShape
from iree.turbine.kernel.wave.templates.prefill_attention import (
get_prefill_attention_kernel,
)

import os


def _read_dump_mlir_flag(raw=None):
    """Interpret the ``WAVE_DUMP_MLIR`` setting as an integer flag.

    Args:
        raw: Value to interpret. When ``None``, the ``WAVE_DUMP_MLIR``
            environment variable is read, falling back to ``0`` if unset.

    Returns:
        The integer value for numeric settings. For non-numeric settings,
        ``1`` when the text is ``true``/``yes``/``on`` (case-insensitive)
        and ``0`` otherwise, including for an empty string.
    """
    if raw is None:
        raw = os.environ.get("WAVE_DUMP_MLIR", 0)
    try:
        return int(raw)
    except (TypeError, ValueError):
        # A malformed debug switch must not make importing this module fail.
        return int(str(raw).strip().lower() in ("true", "yes", "on"))


# Truthy when the generated MLIR should be written to disk for debugging.
dump_generated_mlir = _read_dump_mlir_flag()

def prefill_attention_wave(
    q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=True
):
    """Launch the wave prefill attention kernel, passing ``o`` as its output.

    Args:
        q: Query tensor. Dims 0, 1 and 2 are used as the total sequence
            length, the number of query heads and the head size.
        k: Key tensor. Dims 1 and 2 are used as the number of KV heads and
            the KV head size.
        v: Value tensor. It is handed to the kernel as ``v.permute(1, 2, 0)``.
        o: Output tensor passed as the last kernel argument.
        b_start_loc: Cast to ``torch.int32`` before being passed to the
            kernel (presumably per-sequence start offsets; confirm with
            the kernel template).
        b_seq_len: Cast to ``torch.int32`` before being passed to the
            kernel; its dim 0 is used as the number of sequences.
        max_seq_len: Forwarded as ``AttentionShape.max_seq_len``.
        is_causal: Currently not consulted by this function.

    When the module-level ``dump_generated_mlir`` flag is truthy, the
    generated module's assembly is written to a ``.mlir`` file in the
    current working directory.
    """
    # NOTE: the NotImplementedError guard for the non-causal case is
    # disabled, so ``is_causal`` has no effect here.
    attn_shape = AttentionShape(
        total_seq_len=q.shape[0],
        num_query_heads=q.shape[1],
        head_size=q.shape[2],
        num_kv_heads=k.shape[1],
        head_size_kv=k.shape[2],
        num_seqs=b_seq_len.shape[0],
        max_seq_len=max_seq_len,
    )

    assert attn_shape.num_query_heads % attn_shape.num_kv_heads == 0

    # The kernel consumes V in a permuted (1, 2, 0) layout.
    v_permuted = v.permute(1, 2, 0)
    o_shape = (
        attn_shape.total_seq_len,
        attn_shape.num_query_heads,
        attn_shape.head_size_kv,
    )

    # Build the wave kernel; the same MMA variant is used for both entries.
    mma = MMAType.F32_16x16x16_F16
    prefill_kernel, kernel_params = get_prefill_attention_kernel(
        attn_shape,
        (mma, mma),
        q.shape,
        k.shape,
        v_permuted.shape,
        o_shape,
    )

    kernel_params.update(get_default_scheduling_params())
    run_config = get_default_run_config()

    # Q is pre-scaled by sqrt(1 / head_size) and by log2(e) before launch.
    log2e = 1.44269504089
    dk_sqrt = math.sqrt(1.0 / attn_shape.head_size)

    with tk.gen.TestLaunchContext(
        kernel_params,
        canonicalize=True,
        run=True,
        run_bench=False,
        run_config=run_config,
        schedule=False,
        use_scheduling_barriers=False,
    ):
        # TODO: Add scaling of QK as part of kernel.
        # TODO: Add variant of non-transposed V attention kernel.
        mb = prefill_kernel(
            q * dk_sqrt * log2e,
            k,
            v_permuted,
            b_start_loc.to(torch.int32),
            b_seq_len.to(torch.int32),
            o,
        )
        if dump_generated_mlir:
            # The file name encodes the problem dimensions taken from q and k.
            shape_list = [
                q.shape[0],
                q.shape[1],
                k.shape[1],
                q.shape[2],
                k.shape[2],
            ]
            filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
            with open(filename, "w") as f:
                f.write(mb.module_op.get_asm())
57 changes: 53 additions & 4 deletions test/srt/test_wave_attention_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
)


from sglang.srt.layers.attention.wave_ops.prefill_attention import (
prefill_attention_wave,
)
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
context_attention_fwd,
)


class TestWaveAttention(unittest.TestCase):

def _set_all_seeds(self, seed):
Expand Down Expand Up @@ -42,7 +50,7 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V):

# o will have the same shape as q
o_triton = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
o = torch.zeros(B, H_Q, D_V, dtype=torch.float32, device="cuda")
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

req_to_token = torch.arange(
total_tokens, device="cuda", dtype=torch.int32
Expand Down Expand Up @@ -102,14 +110,14 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V):
)
print(cos_sim.item())
self.assertTrue(cos_sim.item() > 0.99)
self.assertTrue(torch.allclose(o, o_triton.to(torch.float32), atol=3e-2))
self.assertTrue(torch.allclose(o, o_triton, atol=3e-2))

def test_grouped_decode_attention(self):
# seq_lens = [5, 100, 128, 500]
seq_lens = [100, 128, 500]
seq_lens = [128,]
configs = [
# (2, 16, 16, 64, 64),
(2, 16, 1, 64, 64),
# (2, 16, 1, 64, 64), uncomment this
# (2, 64, 1, 13, 13),
(2, 128, 1, 80, 80),
# (2, 128, 2, 512, 512),
Expand All @@ -120,6 +128,47 @@ def test_grouped_decode_attention(self):
for B, H_Q, H_KV, D, D_V in configs:
self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)

def _test_context_attention_once(self, head_dim, is_causal):
    """Check the wave prefill kernel against the Triton prefill kernel once.

    Random fp16 q/k/v are built for two packed sequences (64 and 128
    tokens, 4 heads), both kernels are run on them, and the outputs are
    compared elementwise and by cosine similarity.

    Args:
        head_dim: Size of the last dimension of q, k, v and the outputs.
        is_causal: Forwarded to both attention implementations.
    """
    dtype = torch.float16
    num_heads = 4
    seq_lens = [64, 128]
    max_seq_len = max(seq_lens)
    packed_shape = (sum(seq_lens), num_heads, head_dim)

    # Inputs are drawn in q, k, v order.
    q, k, v = (
        torch.randn(packed_shape, dtype=dtype, device="cuda") for _ in range(3)
    )

    # The Triton output uses fp16, the wave output uses fp32.
    o_triton = torch.zeros(packed_shape, dtype=dtype, device="cuda")
    o = torch.zeros(packed_shape, dtype=torch.float32, device="cuda")

    # Start offsets and lengths of the two packed sequences.
    b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda")
    b_seq_len = torch.tensor(seq_lens, device="cuda")

    context_attention_fwd(
        q, k, v, o_triton, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
    )
    prefill_attention_wave(
        q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
    )

    expected = o_triton.to(torch.float32)
    cos_sim = torch.nn.functional.cosine_similarity(
        o.flatten(), expected.flatten(), dim=0
    )

    print(cos_sim.item())
    self.assertTrue(torch.allclose(o, o_triton.to(torch.float32), atol=3e-2))
    self.assertTrue(cos_sim.item() > 1 - (1e-5))

def test_context_attention(self):
# head_dim = [128, 96, 80, 13]
# for is_causal in [False, True]:

head_dim = [128]

for dim in head_dim:
for is_causal in [False]:
self._test_context_attention_once(dim, is_causal)



if __name__ == "__main__":
unittest.main()