46 changes: 46 additions & 0 deletions python/sglang/srt/layers/attention/merge_state.py
@@ -0,0 +1,46 @@
from typing import Optional, Tuple

import torch
from sgl_kernel import merge_state_v2

from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()


# Automatically fall back to the Triton kernel in unsupported cases
# (e.g., on AMD GPUs, when the head dimension is not a multiple of
# 4 for fp32 or 8 for fp16/bf16, or for FP8 inputs).
def _supported_dtypes(o: torch.Tensor) -> bool:
return o.dtype in [torch.float32, torch.half, torch.bfloat16]


def _supported_headdim(o: torch.Tensor) -> bool:
headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
if o.dtype == torch.float32:
return headdim % 4 == 0
return headdim % 8 == 0


def merge_state(
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output: Optional[torch.Tensor] = None,
output_lse: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if (
_is_cuda
and _supported_dtypes(prefix_output)
and _supported_headdim(prefix_output)
):
return merge_state_v2(
prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
)
else:
# Fallback to Triton kernel
return merge_state_triton(
prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
)
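For reference, a minimal usage sketch of the dispatcher above (shapes follow the kernel comments: attention outputs are [NUM_TOKENS, NUM_HEADS, HEAD_SIZE], LSEs are [NUM_TOKENS, NUM_HEADS]; the sizes and dtypes here are illustrative assumptions, not part of the PR):

import torch

from sglang.srt.layers.attention.merge_state import merge_state

num_tokens, num_heads, head_size = 4, 8, 128  # hypothetical sizes
prefix_out = torch.randn(
    num_tokens, num_heads, head_size, dtype=torch.bfloat16, device="cuda"
)
suffix_out = torch.randn_like(prefix_out)
prefix_lse = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn_like(prefix_lse)

# head_size=128 is a multiple of 8, so on NVIDIA GPUs this dispatches to
# merge_state_v2; otherwise it falls back to the Triton kernel.
out, lse = merge_state(prefix_out, prefix_lse, suffix_out, suffix_lse)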
96 changes: 96 additions & 0 deletions python/sglang/srt/layers/attention/triton_ops/merge_state.py
@@ -0,0 +1,96 @@
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl


@triton.jit
def merge_state_kernel(
output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
output_lse, # [NUM_TOKENS, NUM_HEADS] s_merged
prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
prefix_lse, # [NUM_TOKENS, NUM_HEADS] s_a
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
suffix_lse, # [NUM_TOKENS, NUM_HEADS] s_b
HEAD_SIZE: tl.constexpr,
PADDED_HEAD_SIZE: tl.constexpr,
OUTPUT_LSE: tl.constexpr,
):
token_idx = tl.program_id(0)
num_tokens = tl.num_programs(0)
head_idx = tl.program_id(1)
num_heads = tl.num_programs(1)

p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
    # A +inf LSE marks an invalid/empty branch (e.g., a zero-length
    # sequence); map it to -inf so exp() zeroes its contribution below.
    p_lse = float("-inf") if p_lse == float("inf") else p_lse
    s_lse = float("-inf") if s_lse == float("inf") else s_lse

max_lse = tl.maximum(p_lse, s_lse)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse
out_se = tl.exp(p_lse) + tl.exp(s_lse)

if OUTPUT_LSE:
out_lse = tl.log(out_se) + max_lse
tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)

head_arange = tl.arange(0, PADDED_HEAD_SIZE)
head_mask = head_arange < HEAD_SIZE
p_out = tl.load(
prefix_output
+ token_idx * num_heads * HEAD_SIZE
+ head_idx * HEAD_SIZE
+ head_arange,
mask=head_mask,
)
s_out = tl.load(
suffix_output
+ token_idx * num_heads * HEAD_SIZE
+ head_idx * HEAD_SIZE
+ head_arange,
mask=head_mask,
)

p_scale = tl.exp(p_lse) / out_se
s_scale = tl.exp(s_lse) / out_se
out = p_out * p_scale + s_out * s_scale
tl.store(
output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
out,
mask=head_mask,
)


def merge_state_triton(
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output: Optional[torch.Tensor] = None,
output_lse: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Avoid creating new tensors if they are already provided
if output is None:
output = torch.empty_like(prefix_output)
    if output_lse is None:
        # Note: an LSE buffer is always allocated here, so the kernel's
        # OUTPUT_LSE flag below is always True and the merged LSE is
        # always computed and returned.
        output_lse = torch.empty_like(prefix_lse)

num_tokens = output.shape[0]
num_query_heads = output.shape[1]
head_size = output.shape[2]
padded_head_size = triton.next_power_of_2(head_size)

merge_state_kernel[(num_tokens, num_query_heads)](
output,
output_lse,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
head_size,
padded_head_size,
output_lse is not None,
)
return output, output_lse
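For intuition, the kernel implements the standard log-sum-exp merge of two partial attention results: given branch outputs v_a, v_b with log-sum-exps s_a, s_b, it computes v = (e^{s_a} v_a + e^{s_b} v_b) / (e^{s_a} + e^{s_b}) and s = log(e^{s_a} + e^{s_b}), subtracting max(s_a, s_b) first for numerical stability. Below is a pure-PyTorch sketch of the same math (a hypothetical reference for unit-testing merge_state_triton, not part of the PR; it omits the kernel's +inf sentinel handling):

import torch


def merge_state_reference(prefix_output, prefix_lse, suffix_output, suffix_lse):
    # prefix/suffix_output: [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    # prefix/suffix_lse:    [NUM_TOKENS, NUM_HEADS]
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    p_se = torch.exp(prefix_lse - max_lse)
    s_se = torch.exp(suffix_lse - max_lse)
    out_se = p_se + s_se
    # Weight each branch by its share of the combined softmax mass.
    output = (
        prefix_output * (p_se / out_se).unsqueeze(-1)
        + suffix_output * (s_se / out_se).unsqueeze(-1)
    )
    return output.to(prefix_output.dtype), torch.log(out_se) + max_lse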