Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
import triton.language as tl

from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.utils import is_hip
from sglang.srt.utils import get_bool_env_var, is_hip

_is_hip = is_hip()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
Comment thread
1am9trash marked this conversation as resolved.

if _use_aiter:
from aiter.ops.cache import cp_gather_indexer_k_quant_cache

if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
Expand Down Expand Up @@ -163,8 +167,52 @@ def triton(
class GetKAndS:
@classmethod
def execute(cls, *args, **kwargs):
if _use_aiter:
return cls.aiter(*args, **kwargs)
return cls.triton(*args, **kwargs)

    @classmethod
    def aiter(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        page_indices: torch.Tensor,
        seq_len_tensor: torch.Tensor,
        seq_len_sum: int,
        max_seq_len: int,
    ):
        """Gather quantized index-K values and their scales with the AITER kernel.

        Reads the paged index-K cache ``buf`` and gathers, for every token of
        every sequence, the fp8 key bytes and the per-block scale bytes into
        two flat output tensors.

        Args:
            pool: KV pool providing ``page_size``, ``index_head_dim`` and
                ``quant_block_size`` for the cache layout.
            buf: Flat uint8 cache buffer holding, per page,
                ``page_size * (index_head_dim + scale_bytes)`` bytes.
            page_indices: Page table entries for the sequences to gather
                (cast to int32 before the kernel call).
            seq_len_tensor: Per-sequence token counts.
            seq_len_sum: Total number of tokens across all sequences; sizes
                the output tensors.
            max_seq_len: Unused here; kept for signature parity with the
                Triton backend.

        Returns:
            Tuple ``(dst_k, dst_scale)`` of uint8 tensors with shapes
            ``(seq_len_sum, index_head_dim)`` and
            ``(seq_len_sum, scale_elems * 4)``.
        """
        from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype

        page_size = pool.page_size
        index_head_dim = pool.index_head_dim
        quant_block_size = pool.quant_block_size
        # One scale per quant block; each scale occupies 4 bytes in the cache
        # (presumably an fp32 scale — confirm against the writer side).
        scale_elems = index_head_dim // quant_block_size

        # Reinterpret the flat uint8 buffer as
        # (num_pages, page_size, head_dim + scale bytes) viewed as fp8.
        kv_cache = buf.view(-1, page_size, index_head_dim + scale_elems * 4).view(
            fp8_dtype
        )
        dst_k = torch.empty(
            (seq_len_sum, index_head_dim), dtype=torch.uint8, device=buf.device
        )
        dst_scale = torch.empty(
            (seq_len_sum, scale_elems * 4), dtype=torch.uint8, device=buf.device
        )

        # Exclusive prefix sums of sequence lengths: cu_seq_lens[0] stays 0,
        # cu_seq_lens[i+1] = sum of the first i+1 lengths.
        cu_seq_lens = torch.zeros(
            seq_len_tensor.shape[0] + 1, dtype=torch.int32, device=buf.device
        )
        torch.cumsum(seq_len_tensor.to(torch.int32), dim=0, out=cu_seq_lens[1:])

        cp_gather_indexer_k_quant_cache(
            kv_cache,
            # dst_k is reinterpreted as fp8 for the kernel but returned as uint8.
            dst_k.view(fp8_dtype),
            dst_scale,
            page_indices.to(torch.int32),
            cu_seq_lens,
            # NOTE(review): assumes the cache was written in preshuffled
            # layout by the store path — confirm with the producer kernel.
            preshuffle=True,
        )
        return dst_k, dst_scale

@classmethod
def triton(
cls,
Expand Down Expand Up @@ -364,15 +412,14 @@ def _set_k_and_s_triton(
raise ValueError(
f"index_k_scale must be 1D or 2D, got shape {index_k_scale.shape}"
)
if _is_hip:
assert buf_numel_per_page == 1 * (128 + 4)
else:
assert buf_numel_per_page == 64 * (128 + 4)
assert buf_numel_per_page == page_size * (128 + 4)
assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__
assert index_head_dim == 128
assert scale_dim == 1
if _is_hip:
assert page_size == 1
assert (
page_size % 16 == 0
), f"HIP preshuffle requires page_size to be a multiple of 16, got {page_size}"
else:
assert page_size == 64

Expand Down
51 changes: 25 additions & 26 deletions python/sglang/srt/layers/attention/nsa/nsa_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,12 +430,13 @@ def _get_topk_paged(
page_size = forward_batch.token_to_kv_pool.page_size
# NOTE(dark): blocksize = 64 is hardcoded in deep_gemm
if _is_hip:
assert page_size == 1, "only support page size 1"
block_tables = metadata.get_page_table_1()
assert (
page_size % 16 == 0
), f"HIP preshuffle requires page_size to be a multiple of 16, got {page_size}"
else:
assert page_size == 64, "only support page size 64"
# NOTE(dark): this support extend/decode/decode+graph
block_tables = metadata.get_page_table_64()
# NOTE(dark): this support extend/decode/decode+graph
block_tables = metadata.get_page_table_64()

max_seq_len = block_tables.shape[1] * page_size
kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
Expand All @@ -462,17 +463,12 @@ def _get_topk_paged(
assert len(q_fp8.shape) == 3
q_fp8 = q_fp8.unsqueeze(1) # the next_n dim is 1 now
assert len(kv_cache_fp8.shape) == 2
block_kv = 1 if _is_hip else 64
block_kv = page_size
num_heads_kv = 1
head_dim_with_sf = 132
if _is_hip:
kv_cache_fp8 = kv_cache_fp8.view(
-1, block_kv, num_heads_kv, head_dim_with_sf
)
else:
kv_cache_fp8 = kv_cache_fp8.view(
kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
)
kv_cache_fp8 = kv_cache_fp8.view(
kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
)
assert len(weights.shape) == 3
weights = weights.squeeze(2)

Expand All @@ -483,9 +479,8 @@ def _get_topk_paged(
from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits

batch_size, next_n, heads, _ = q_fp8.shape
logits = torch.full(
logits = torch.empty(
(batch_size * next_n, max_seq_len),
float("-inf"),
device=q_fp8.device,
dtype=torch.float32,
)
Expand All @@ -497,7 +492,7 @@ def _get_topk_paged(
seqlens_32,
block_tables,
max_seq_len,
Preshuffle=False,
Preshuffle=True,
KVBlockSize=block_kv,
)
else:
Expand Down Expand Up @@ -561,7 +556,9 @@ def _get_topk_ragged(

page_size = forward_batch.token_to_kv_pool.page_size
if _is_hip:
assert page_size == 1, "only support page size 1"
assert (
page_size % 16 == 0
), f"HIP preshuffle requires page_size to be a multiple of 16, got {page_size}"
else:
assert page_size == 64, "only support page size 64"

Expand All @@ -572,10 +569,7 @@ def _get_topk_ragged(
)
weights = weights.squeeze(-1)

if _is_hip:
block_tables = metadata.get_page_table_1()
else:
block_tables = metadata.get_page_table_64()
block_tables = metadata.get_page_table_64()

assert (
forward_batch.seq_lens_cpu is not None
Expand Down Expand Up @@ -1031,19 +1025,24 @@ def _store_index_k_cache(
)
return

# Fast path: AITER fused quant + cache store (HIP, page_size=1)
# Fast path: AITER fused quant + cache store (HIP, preshuffle)
if _use_aiter:
page_size = forward_batch.token_to_kv_pool.page_size
buf = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
layer_id=layer_id
)
# Reshape from (num_pages, 132) uint8 to (num_pages, 1, 132) fp8
# to match kernel's (num_blocks, block_size, head_dim + scale_bytes) layout
kv_cache = buf.unsqueeze(1).view(fp8_dtype)
# Reshape from (num_pages, page_size*(128+4)) uint8 to (num_pages, page_size, 132) fp8
kv_cache = buf.view(-1, page_size, 132).view(fp8_dtype)
out_loc = forward_batch.out_cache_loc
if not out_loc.is_contiguous():
out_loc = out_loc.contiguous()
indexer_k_quant_and_cache(
key, kv_cache, out_loc, self.block_size, self.scale_fmt
key,
kv_cache,
out_loc,
self.block_size,
self.scale_fmt,
preshuffle=True,
)
return

Expand Down
4 changes: 3 additions & 1 deletion python/sglang/srt/mem_cache/memory_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1908,7 +1908,9 @@ def __init__(
assert index_head_dim == 128

if _is_hip:
assert self.page_size == 1
assert (
self.page_size % 16 == 0
), f"HIP preshuffle requires page_size to be a multiple of 16, got {self.page_size}"
else:
assert self.page_size == 64
with (
Expand Down
11 changes: 2 additions & 9 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1741,15 +1741,8 @@ def _handle_model_specific_adjustments(self):
f"attn_tp_size={self.tp_size}, attention weights will be sharded across {self.tp_size} ranks."
)

if is_hip():
self.page_size = 1
logger.warning(
"Setting page size to 1 for DeepSeek DSA on ROCm."
)
else:
# For CUDA GPU
self.page_size = 64
logger.warning("Setting page size to 64 for DeepSeek DSA.")
self.page_size = 64
logger.warning("Setting page size to 64 for DeepSeek DSA.")

import torch

Expand Down
Loading