Merged
137 changes: 131 additions & 6 deletions python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py
@@ -43,9 +43,13 @@ class ForwardMetadata:
seq_lens: Optional[torch.Tensor] = None
actual_seq_lengths_q: Optional[torch.Tensor] = None

# prefix cache
prefix_lens: Optional[torch.Tensor] = None
flatten_prefix_block_tables: Optional[torch.Tensor] = None


class AscendAttnMaskBuilder:
- def __init__(self, model_runner: ModelRunner, device, use_fia):
+ def __init__(self, model_runner: ModelRunner, device, use_fia, use_mla):
"""
Initialize the AscendAttnMaskBuilder class.

@@ -76,6 +80,13 @@ def __init__(self, model_runner: ModelRunner, device, use_fia):
self.mix_mask_cache = self.generate_attn_mask(mixed_chunk_cache_len, "mix")
self.mix_seq_len_cached = self.mix_mask_cache.shape[0]

if use_mla:
# Initialize RingMla mask
ringmla_mask_len = 512
self.ringmla_mask = self.generate_attn_mask(
ringmla_mask_len, "norm", torch.bfloat16
).to(self.device)
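[Illustrative aside: not part of the diff] A minimal sketch of the kind of causal mask this plausibly builds. The exact semantics of generate_attn_mask are an assumption inferred from the mask_type="mask_type_triu" argument used in forward_extend below; the dtype and mask values are illustrative only.

import torch

def make_triu_mask(seq_len: int, dtype=torch.bfloat16) -> torch.Tensor:
    # Strictly upper-triangular positions are masked out, so each query
    # token can attend only to itself and earlier tokens.
    causal = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return torch.zeros(seq_len, seq_len, dtype=dtype).masked_fill(causal, float("-inf"))

ringmla_mask = make_triu_mask(512)  # mirrors ringmla_mask_len = 512 above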

@staticmethod
def generate_mask_flag(max_seq_len):
"""
@@ -216,6 +227,7 @@ def __init__(self, model_runner: ModelRunner):
if self.use_mla:
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.q_head_dim = (
self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
)
@@ -229,14 +241,16 @@ def __init__(self, model_runner: ModelRunner):
model_runner.server_args.speculative_num_draft_tokens
)
self.ascend_attn_mask_builder = AscendAttnMaskBuilder(
- model_runner, self.device, self.use_fia
+ model_runner, self.device, self.use_fia, self.use_mla
)
self.mask, self.fia_mask, self.mtp_mask, self.mix_mask = (
self.ascend_attn_mask_builder.mask,
self.ascend_attn_mask_builder.fia_mask,
self.ascend_attn_mask_builder.mtp_mask,
self.ascend_attn_mask_builder.mix_mask_cache,
)
if self.use_mla:
self.ringmla_mask = self.ascend_attn_mask_builder.ringmla_mask

def get_verify_buffers_to_fill_after_draft(self):
"""
@@ -279,6 +293,33 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

if forward_batch.forward_mode.is_target_verify():
self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens

[Review thread]
Collaborator: check if only mla models need these
Contributor Author: Confirmed that only MLA models need this. Added the self.use_mla check accordingly.

if (
self.use_mla
and forward_batch.forward_mode.is_extend()
and sum(forward_batch.extend_prefix_lens_cpu) > 0
):
self.forward_metadata.prefix_lens = forward_batch.extend_prefix_lens.to(
"cpu"
)
seq_prefix_lens = self.forward_metadata.prefix_lens.tolist()
self.forward_metadata.flatten_prefix_block_tables = torch.empty(
0, dtype=torch.int32
).to(self.device)
for req_idx, seq_len in zip(
forward_batch.req_pool_indices.tolist(), seq_prefix_lens
):
req_indices = forward_batch.req_to_token_pool.req_to_token[req_idx]
req_prefix_block_tables = (
req_indices[:seq_len][:: self.page_size] // self.page_size
)
self.forward_metadata.flatten_prefix_block_tables = torch.cat(
(
self.forward_metadata.flatten_prefix_block_tables,
torch.flatten(req_prefix_block_tables),
)
)
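[Illustrative aside: not part of the diff] A toy walk-through of the block-table flattening above, under assumed values. With page_size=4, every page_size-th token slot of a request's prefix is sampled from req_to_token, then each slot index is converted to a page id.

import torch

page_size = 4
req_to_token = torch.tensor([8, 9, 10, 11, 20, 21, 22, 23], dtype=torch.int32)
prefix_len = 6  # the first 6 slots hold the cached prefix

# Take one slot per page of the prefix, then map slot index -> page id.
blocks = req_to_token[:prefix_len][::page_size] // page_size
print(blocks)  # tensor([2, 5], dtype=torch.int32)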

if forward_batch.forward_mode.is_mixed():
self.mix_mask = self.ascend_attn_mask_builder.update_mask(
self.forward_metadata
@@ -590,15 +631,99 @@ def forward_extend(
enable_gqa=use_gqa,
causal=causal,
)
[Review thread]
Collaborator: check if this feature supports mtp
Contributor Author: This feature can be used together with MTP. Since the KV cache in the MTP stage is relatively small, enabling prefix cache is not necessary for now.

elif sum(forward_batch.extend_prefix_lens_cpu) > 0:
q, k, v = [
data[: forward_batch.num_token_non_padded_cpu] for data in [q, k, v]
]
q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)

# 1st, compute extend tokens to get attn_output and attn_lse
num_tokens = q_nope.size(0)
attn_output = torch.zeros(
num_tokens,
layer.tp_q_head_num,
layer.v_head_dim,
dtype=q_nope.dtype,
device=q_nope.device,
)
attn_lse = torch.zeros(
layer.tp_q_head_num,
num_tokens,
dtype=torch.float32,
device=q_nope.device,
)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_rope,
k_nope=k_nope,
k_rope=k_rope,
value=v,
mask=self.ringmla_mask,
seqlen=self.forward_metadata.extend_seq_lens_cpu_int,
head_num=layer.tp_q_head_num,
kv_head_num=layer.tp_k_head_num,
pre_out=None,
prev_lse=None,
qk_scale=layer.scaling,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse,
)

# 2nd, load the history kv cache (kv_a and k_pe) and compute k_nope
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
kv_cached = torch.index_select(
k_buffer, 0, self.forward_metadata.flatten_prefix_block_tables
)
k_rope_cached = torch.index_select(
v_buffer, 0, self.forward_metadata.flatten_prefix_block_tables
).flatten(0, 1)

assert layer.kv_b_proj is not None
kv = layer.kv_b_proj(kv_cached)[0].view(
-1, layer.tp_k_head_num, self.qk_nope_head_dim + layer.v_head_dim
)
k_nope, v = kv.split([self.qk_nope_head_dim, layer.v_head_dim], dim=-1)
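[Illustrative aside: not part of the diff] A toy shape walk-through of the kv_b_proj expansion above; all sizes are invented examples, not values from any real model config.

import torch

kv_lora_rank, nope_dim, v_dim, kv_heads = 512, 128, 128, 8
toy_kv_b_proj = torch.nn.Linear(kv_lora_rank, kv_heads * (nope_dim + v_dim), bias=False)
toy_latent = torch.randn(16, kv_lora_rank)  # 16 cached prefix tokens in latent (compressed) form
toy_kv = toy_kv_b_proj(toy_latent).view(-1, kv_heads, nope_dim + v_dim)
toy_k_nope, toy_v = toy_kv.split([nope_dim, v_dim], dim=-1)  # each [16, 8, 128]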

# 3rd, compute attention over the history kv and merge it into attn_output
k_rope = k_rope_cached.expand(-1, layer.tp_k_head_num, -1)
seq_len = torch.stack(
[
self.forward_metadata.extend_seq_lens_cpu_int,
self.forward_metadata.prefix_lens,
]
)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_rope,
k_nope=k_nope,
k_rope=k_rope,
value=v,
mask=self.ringmla_mask,
seqlen=seq_len,
head_num=layer.tp_q_head_num,
kv_head_num=layer.tp_k_head_num,
pre_out=attn_output,
prev_lse=attn_lse,
qk_scale=layer.scaling,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
calc_type="calc_type_default",
output=attn_output,
softmax_lse=attn_lse,
)
attn_output = attn_output.reshape(
[-1, layer.tp_q_head_num, layer.v_head_dim]
)
else:
assert (
layer.qk_head_dim != layer.v_head_dim
), "FIA only supports qk_head_dim != v_head_dim"

# Wait for the KV transfer to complete before performing attention computation.
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

num_token_padding = q.shape[0]
q, k, v = [
data[: forward_batch.num_token_non_padded_cpu] for data in [q, k, v]
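[Illustrative aside: not part of the diff] The two npu_ring_mla calls above implement ring-style attention in two passes: the first (calc_type_first_ring, triu mask) covers the new extend tokens, and the second (calc_type_default, no mask) attends over the cached prefix while folding in the first pass's partial result through pre_out/prev_lse. Below is a minimal sketch of that log-sum-exp merge, assuming the [tokens, heads, v_dim] output and [heads, tokens] LSE layouts used above; it is my own illustration of the technique, not the ATB kernel's implementation.

import torch

def merge_attention_partials(out1, lse1, out2, lse2):
    # out*: [tokens, heads, v_dim]; lse*: [heads, tokens], the natural log of
    # the summed exponentiated attention scores for each partial.
    l1 = lse1.transpose(0, 1).unsqueeze(-1)  # -> [tokens, heads, 1]
    l2 = lse2.transpose(0, 1).unsqueeze(-1)
    m = torch.maximum(l1, l2)
    w1, w2 = torch.exp(l1 - m), torch.exp(l2 - m)
    # Each partial output is an average weighted by its softmax mass, so the
    # exact combined output is the mass-weighted mean of the two partials.
    merged = (out1 * w1 + out2 * w2) / (w1 + w2)
    merged_lse = (m + torch.log(w1 + w2)).squeeze(-1).transpose(0, 1)
    return merged, merged_lse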
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/cache_controller.py
@@ -503,7 +503,7 @@ def move_indices(self, op: CacheOperation):
elif self.mem_pool_host.layout == "page_first_direct":
return host_indices, device_indices.cpu()
elif self.io_backend == "kernel_ascend":
- return host_indices, device_indices
+ return host_indices, device_indices.cpu()
else:
raise ValueError(f"Unsupported io backend")

75 changes: 54 additions & 21 deletions python/sglang/srt/mem_cache/memory_pool_host.py
@@ -1,6 +1,7 @@
import abc
import logging
import threading
from collections import defaultdict
from functools import wraps
from typing import Optional

@@ -41,8 +42,6 @@

logger = logging.getLogger(__name__)

- SUPPORT_PIN_MEMORY = not _is_npu


def synchronized(func):
@wraps(func)
@@ -53,6 +52,45 @@ def wrapper(self, *args, **kwargs):
return wrapper


def alloc_with_host_register(
dims,
dtype: torch.dtype,
device: str,
pin_memory: bool,
) -> torch.Tensor:
"""
Allocate a tensor and register its host memory with cudaHostRegister.
cudaHostRegister is only applied when pin_memory=True.
"""
buffer = torch.empty(dims, dtype=dtype, device=device)
if pin_memory:
torch.cuda.cudart().cudaHostRegister(
buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0
)
return buffer


def alloc_with_pin_memory(
dims,
dtype: torch.dtype,
device: str,
pin_memory: bool,
) -> torch.Tensor:
"""
Allocate tensor using PyTorch's built-in pin_memory flag.
"""
buffer = torch.empty(dims, dtype=dtype, device=device, pin_memory=pin_memory)
return buffer


ALLOC_MEMORY_FUNCS = defaultdict(
lambda: alloc_with_host_register,
{
"npu": alloc_with_pin_memory,
},
)
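[Illustrative aside: not part of the diff] The defaultdict above dispatches on the device pool's device string: "npu" gets PyTorch's native pinned allocation, and every other device falls back to the cudaHostRegister path. A minimal usage sketch under assumed device names (pin_memory=False here, so both calls reduce to a plain torch.empty):

import torch

cuda_alloc = ALLOC_MEMORY_FUNCS["cuda"]  # unknown key -> alloc_with_host_register
npu_alloc = ALLOC_MEMORY_FUNCS["npu"]    # explicit entry -> alloc_with_pin_memory

buf_a = cuda_alloc((1024, 64), dtype=torch.bfloat16, device="cpu", pin_memory=False)
buf_b = npu_alloc((1024, 64), dtype=torch.bfloat16, device="cpu", pin_memory=False)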


class HostKVCache(abc.ABC):

def __init__(
@@ -68,7 +106,7 @@ def __init__(
self.device_pool = device_pool
self.page_size = page_size
self.layout = layout
- self.pin_memory = pin_memory and SUPPORT_PIN_MEMORY
+ self.pin_memory = pin_memory
self.device = device

self.dtype = device_pool.store_dtype
@@ -266,15 +304,11 @@ def init_kv_buffer(self):
raise ValueError(f"Unsupported layout: {self.layout}")
self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
self.layout_dim = self.token_stride_size * self.layer_num
- buffer = torch.empty(
- dims,
- dtype=self.dtype,
- device=self.device,
- )
+ alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device]
+ buffer = alloc_func(
+ dims, dtype=self.dtype, device=self.device, pin_memory=self.pin_memory
+ )
- if self.pin_memory:
- torch.cuda.cudart().cudaHostRegister(
- buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0
- )
return buffer

@property
@@ -675,15 +709,18 @@ def init_kv_buffer(self):
self.page_size,
1,
)
- self.k_buffer = torch.empty(
+ alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device]
+ self.k_buffer = alloc_func(
(*base_dims, self.kv_lora_rank),
dtype=self.dtype,
device=self.device,
+ pin_memory=self.pin_memory,
)
- self.v_buffer = torch.empty(
+ self.v_buffer = alloc_func(
(*base_dims, self.qk_rope_head_dim),
dtype=self.dtype,
device=self.device,
+ pin_memory=self.pin_memory,
)
# Return k_buffer to preserve original kv_buffer and data_refs init logic,
# though Ascend doesn't use these parameters.
@@ -694,15 +731,11 @@
self.kv_lora_rank + self.qk_rope_head_dim
) * self.dtype.itemsize
self.layout_dim = self.token_stride_size * self.layer_num
- buffer = torch.empty(
- dims,
- dtype=self.dtype,
- device=self.device,
- )
+ alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device]
+ buffer = alloc_func(
+ dims, dtype=self.dtype, device=self.device, pin_memory=self.pin_memory
+ )
- if self.pin_memory:
- torch.cuda.cudart().cudaHostRegister(
- buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0
- )
return buffer

def load_to_device_per_layer(