diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py
index 5536cf80c8f5..22d9021f286c 100644
--- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py
+++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py
@@ -43,9 +43,13 @@ class ForwardMetadata:
     seq_lens: Optional[torch.Tensor] = None
     actual_seq_lengths_q: Optional[torch.Tensor] = None
 
+    # prefix cache
+    prefix_lens: Optional[torch.Tensor] = None
+    flatten_prefix_block_tables: Optional[torch.Tensor] = None
+
 
 class AscendAttnMaskBuilder:
-    def __init__(self, model_runner: ModelRunner, device, use_fia):
+    def __init__(self, model_runner: ModelRunner, device, use_fia, use_mla):
         """
         Initialize the AscendAttnMaskBuilder class.
 
@@ -76,6 +80,13 @@ def __init__(self, model_runner: ModelRunner, device, use_fia):
         self.mix_mask_cache = self.generate_attn_mask(mixed_chunk_cache_len, "mix")
         self.mix_seq_len_cached = self.mix_mask_cache.shape[0]
 
+        if use_mla:
+            # Initialize the RingMLA attention mask
+            ringmla_mask_len = 512
+            self.ringmla_mask = self.generate_attn_mask(
+                ringmla_mask_len, "norm", torch.bfloat16
+            ).to(self.device)
+
     @staticmethod
     def generate_mask_flag(max_seq_len):
         """
@@ -216,6 +227,7 @@ def __init__(self, model_runner: ModelRunner):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+            self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
             self.q_head_dim = (
                 self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
             )
@@ -229,7 +241,7 @@ def __init__(self, model_runner: ModelRunner):
                 model_runner.server_args.speculative_num_draft_tokens
             )
         self.ascend_attn_mask_builder = AscendAttnMaskBuilder(
-            model_runner, self.device, self.use_fia
+            model_runner, self.device, self.use_fia, self.use_mla
         )
         self.mask, self.fia_mask, self.mtp_mask, self.mix_mask = (
             self.ascend_attn_mask_builder.mask,
@@ -237,6 +249,8 @@ def __init__(self, model_runner: ModelRunner):
             self.ascend_attn_mask_builder.mtp_mask,
             self.ascend_attn_mask_builder.mix_mask_cache,
         )
+        if self.use_mla:
+            self.ringmla_mask = self.ascend_attn_mask_builder.ringmla_mask
 
     def get_verify_buffers_to_fill_after_draft(self):
         """
@@ -279,6 +293,33 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
 
         if forward_batch.forward_mode.is_target_verify():
             self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens
+
+        if (
+            self.use_mla
+            and forward_batch.forward_mode.is_extend()
+            and sum(forward_batch.extend_prefix_lens_cpu) > 0
+        ):
+            self.forward_metadata.prefix_lens = forward_batch.extend_prefix_lens.to(
+                "cpu"
+            )
+            seq_prefix_lens = self.forward_metadata.prefix_lens.tolist()
+            self.forward_metadata.flatten_prefix_block_tables = torch.empty(
+                0, dtype=torch.int32
+            ).to(self.device)
+            for req_idx, seq_len in zip(
+                forward_batch.req_pool_indices.tolist(), seq_prefix_lens
+            ):
+                req_indices = forward_batch.req_to_token_pool.req_to_token[req_idx]
+                req_prefix_block_tables = (
+                    req_indices[:seq_len][:: self.page_size] // self.page_size
+                )
+                self.forward_metadata.flatten_prefix_block_tables = torch.cat(
+                    (
+                        self.forward_metadata.flatten_prefix_block_tables,
+                        torch.flatten(req_prefix_block_tables),
+                    )
+                )
+
         if forward_batch.forward_mode.is_mixed():
             self.mix_mask = self.ascend_attn_mask_builder.update_mask(
                 self.forward_metadata
@@ -590,15 +631,99 @@ def forward_extend(
                 enable_gqa=use_gqa,
                 causal=causal,
             )
+        elif sum(forward_batch.extend_prefix_lens_cpu) > 0:
+            q, k, v = [
+                data[: forward_batch.num_token_non_padded_cpu] for data in [q, k, v]
+            ]
+            q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
+            k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
+
+            # 1st, attend over the extend tokens to produce attn_output and attn_lse
+            num_tokens = q_nope.size(0)
+            attn_output = torch.zeros(
+                num_tokens,
+                layer.tp_q_head_num,
+                layer.v_head_dim,
+                dtype=q_nope.dtype,
+                device=q_nope.device,
+            )
+            attn_lse = torch.zeros(
+                layer.tp_q_head_num,
+                num_tokens,
+                dtype=torch.float32,
+                device=q_nope.device,
+            )
+            torch_npu.atb.npu_ring_mla(
+                q_nope=q_nope,
+                q_rope=q_rope,
+                k_nope=k_nope,
+                k_rope=k_rope,
+                value=v,
+                mask=self.ringmla_mask,
+                seqlen=self.forward_metadata.extend_seq_lens_cpu_int,
+                head_num=layer.tp_q_head_num,
+                kv_head_num=layer.tp_k_head_num,
+                pre_out=None,
+                prev_lse=None,
+                qk_scale=layer.scaling,
+                kernel_type="kernel_type_high_precision",
+                mask_type="mask_type_triu",
+                calc_type="calc_type_first_ring",
+                output=attn_output,
+                softmax_lse=attn_lse,
+            )
+
+            # 2nd, load the cached history KV (kv_a and k_pe) and recompute k_nope
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            kv_cached = torch.index_select(
+                k_buffer, 0, self.forward_metadata.flatten_prefix_block_tables
+            )
+            k_rope_cached = torch.index_select(
+                v_buffer, 0, self.forward_metadata.flatten_prefix_block_tables
+            ).flatten(0, 1)
+
+            assert layer.kv_b_proj is not None
+            kv = layer.kv_b_proj(kv_cached)[0].view(
+                -1, layer.tp_k_head_num, self.qk_nope_head_dim + layer.v_head_dim
+            )
+            k_nope, v = kv.split([self.qk_nope_head_dim, layer.v_head_dim], dim=-1)
+
+            # 3rd, attend over the history KV and merge into attn_output via the LSE
+            k_rope = k_rope_cached.expand(-1, layer.tp_k_head_num, -1)
+            seq_len = torch.stack(
+                [
+                    self.forward_metadata.extend_seq_lens_cpu_int,
+                    self.forward_metadata.prefix_lens,
+                ]
+            )
+            torch_npu.atb.npu_ring_mla(
+                q_nope=q_nope,
+                q_rope=q_rope,
+                k_nope=k_nope,
+                k_rope=k_rope,
+                value=v,
+                mask=self.ringmla_mask,
+                seqlen=seq_len,
+                head_num=layer.tp_q_head_num,
+                kv_head_num=layer.tp_k_head_num,
+                pre_out=attn_output,
+                prev_lse=attn_lse,
+                qk_scale=layer.scaling,
+                kernel_type="kernel_type_high_precision",
+                mask_type="no_mask",
+                calc_type="calc_type_default",
+                output=attn_output,
+                softmax_lse=attn_lse,
+            )
+            attn_output = attn_output.reshape(
+                [-1, layer.tp_q_head_num, layer.v_head_dim]
+            )
         else:
             assert (
                 layer.qk_head_dim != layer.v_head_dim
             ), "FIA only supports qk_head_dim != v_head_dim"
-            # Wait for the KV transfer to complete before performing attention computation.
-            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
-
             num_token_padding = q.shape[0]
             q, k, v = [
                 data[: forward_batch.num_token_non_padded_cpu] for data in [q, k, v]
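Note on the two npu_ring_mla calls above: the first pass (calc_type="calc_type_first_ring") attends only over the freshly extended tokens and emits a partial attn_output together with its softmax log-sum-exp (attn_lse); the second pass attends over the cached prefix and folds the first result back in through pre_out/prev_lse. A minimal sketch of that merge rule, assuming partial outputs shaped [heads, tokens, dim] and LSEs shaped [heads, tokens] (illustrative names, not the ATB kernel API):

    import torch

    def merge_attn_partials(out_a, lse_a, out_b, lse_b):
        # Combine two partial attention results computed over disjoint KV
        # chunks. lse_* hold the log-sum-exp of each chunk's attention logits.
        lse = torch.logaddexp(lse_a, lse_b)         # combined normalizer
        w_a = torch.exp(lse_a - lse).unsqueeze(-1)  # rescale weight for chunk A
        w_b = torch.exp(lse_b - lse).unsqueeze(-1)  # rescale weight for chunk B
        return w_a * out_a + w_b * out_b, lse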
diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py
index 29b5a17c3f5c..cfb069fc97b6 100644
--- a/python/sglang/srt/managers/cache_controller.py
+++ b/python/sglang/srt/managers/cache_controller.py
@@ -503,7 +503,7 @@ def move_indices(self, op: CacheOperation):
         elif self.mem_pool_host.layout == "page_first_direct":
             return host_indices, device_indices.cpu()
         elif self.io_backend == "kernel_ascend":
-            return host_indices, device_indices
+            return host_indices, device_indices.cpu()
         else:
             raise ValueError(f"Unsupported io backend")
 
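The kernel_ascend branch now matches page_first_direct: the transfer path indexes with host-resident tensors, so device_indices is converted to CPU once here instead of on every downstream use. A hedged sketch of the invariant the consumer relies on (the pool tensors and function name are illustrative, not the real transfer kernel):

    import torch

    def copy_pages_to_device(host_pool, device_pool, host_indices, device_indices):
        # Both index tensors are expected on the CPU by the time the copy is
        # issued; the single .cpu() in move_indices() establishes this.
        assert host_indices.device.type == "cpu"
        assert device_indices.device.type == "cpu"
        device_pool[device_indices.to(device_pool.device)] = host_pool[
            host_indices
        ].to(device_pool.device, non_blocking=True)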
diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py
index d5df87c93e58..08ff838ee06e 100644
--- a/python/sglang/srt/mem_cache/memory_pool_host.py
+++ b/python/sglang/srt/mem_cache/memory_pool_host.py
@@ -1,6 +1,7 @@
 import abc
 import logging
 import threading
+from collections import defaultdict
 from functools import wraps
 from typing import Optional
 
@@ -41,8 +42,6 @@
 
 logger = logging.getLogger(__name__)
 
-SUPPORT_PIN_MEMORY = not _is_npu
-
 
 def synchronized(func):
     @wraps(func)
@@ -53,6 +52,45 @@ def wrapper(self, *args, **kwargs):
     return wrapper
 
 
+def alloc_with_host_register(
+    dims,
+    dtype: torch.dtype,
+    device: str,
+    pin_memory: bool,
+) -> torch.Tensor:
+    """
+    Allocate a tensor and register its host memory with cudaHostRegister.
+    Registration is only performed when pin_memory=True.
+    """
+    buffer = torch.empty(dims, dtype=dtype, device=device)
+    if pin_memory:
+        torch.cuda.cudart().cudaHostRegister(
+            buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0
+        )
+    return buffer
+
+
+def alloc_with_pin_memory(
+    dims,
+    dtype: torch.dtype,
+    device: str,
+    pin_memory: bool,
+) -> torch.Tensor:
+    """
+    Allocate a tensor using PyTorch's built-in pin_memory flag.
+    """
+    buffer = torch.empty(dims, dtype=dtype, device=device, pin_memory=pin_memory)
+    return buffer
+
+
+ALLOC_MEMORY_FUNCS = defaultdict(
+    lambda: alloc_with_host_register,
+    {
+        "npu": alloc_with_pin_memory,
+    },
+)
+
+
 class HostKVCache(abc.ABC):
 
     def __init__(
@@ -68,7 +106,7 @@ def __init__(
         self.device_pool = device_pool
         self.page_size = page_size
         self.layout = layout
-        self.pin_memory = pin_memory and SUPPORT_PIN_MEMORY
+        self.pin_memory = pin_memory
         self.device = device
         self.dtype = device_pool.store_dtype
@@ -266,15 +304,11 @@ def init_kv_buffer(self):
             raise ValueError(f"Unsupported layout: {self.layout}")
         self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
         self.layout_dim = self.token_stride_size * self.layer_num
-        buffer = torch.empty(
-            dims,
-            dtype=self.dtype,
-            device=self.device,
+
+        alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device]
+        buffer = alloc_func(
+            dims, dtype=self.dtype, device=self.device, pin_memory=self.pin_memory
         )
-        if self.pin_memory:
-            torch.cuda.cudart().cudaHostRegister(
-                buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0
-            )
         return buffer
 
     @property
@@ -675,15 +709,18 @@ def init_kv_buffer(self):
                 self.page_size,
                 1,
             )
-            self.k_buffer = torch.empty(
+            alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device]
+            self.k_buffer = alloc_func(
                 (*base_dims, self.kv_lora_rank),
                 dtype=self.dtype,
                 device=self.device,
+                pin_memory=self.pin_memory,
             )
-            self.v_buffer = torch.empty(
+            self.v_buffer = alloc_func(
                 (*base_dims, self.qk_rope_head_dim),
                 dtype=self.dtype,
                 device=self.device,
+                pin_memory=self.pin_memory,
            )
             # Return k_buffer to preserve original kv_buffer and data_refs init logic,
             # though Ascend doesn't use these parameters.
@@ -694,15 +731,11 @@ def init_kv_buffer(self):
             self.kv_lora_rank + self.qk_rope_head_dim
         ) * self.dtype.itemsize
         self.layout_dim = self.token_stride_size * self.layer_num
-        buffer = torch.empty(
-            dims,
-            dtype=self.dtype,
-            device=self.device,
+
+        alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device]
+        buffer = alloc_func(
+            dims, dtype=self.dtype, device=self.device, pin_memory=self.pin_memory
        )
-        if self.pin_memory:
-            torch.cuda.cudart().cudaHostRegister(
-                buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0
-            )
         return buffer
 
     def load_to_device_per_layer(
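Usage note on ALLOC_MEMORY_FUNCS: it is a defaultdict keyed by the device pool's device string, so every device other than "npu" falls back to the cudaHostRegister path, while NPU builds use torch's native pin_memory allocation (host registration via torch.cuda.cudart() is a CUDA-only API). A small illustration, assuming the two helpers defined in the hunk above; the "cuda" key is only an example, since any unknown key resolves the same way:

    import torch
    from collections import defaultdict

    ALLOC_MEMORY_FUNCS = defaultdict(
        lambda: alloc_with_host_register,  # default for unknown devices
        {"npu": alloc_with_pin_memory},    # NPU: torch-native pinning
    )

    assert ALLOC_MEMORY_FUNCS["cuda"] is alloc_with_host_register
    assert ALLOC_MEMORY_FUNCS["npu"] is alloc_with_pin_memory
    # Host-side buffer for an NPU pool, unpinned in this toy example:
    buf = ALLOC_MEMORY_FUNCS["npu"](
        (4, 8), dtype=torch.float16, device="cpu", pin_memory=False
    )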
+ """ + buffer = torch.empty(dims, dtype=dtype, device=device, pin_memory=pin_memory) + return buffer + + +ALLOC_MEMORY_FUNCS = defaultdict( + lambda: alloc_with_host_register, + { + "npu": alloc_with_pin_memory, + }, +) + + class HostKVCache(abc.ABC): def __init__( @@ -68,7 +106,7 @@ def __init__( self.device_pool = device_pool self.page_size = page_size self.layout = layout - self.pin_memory = pin_memory and SUPPORT_PIN_MEMORY + self.pin_memory = pin_memory self.device = device self.dtype = device_pool.store_dtype @@ -266,15 +304,11 @@ def init_kv_buffer(self): raise ValueError(f"Unsupported layout: {self.layout}") self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize self.layout_dim = self.token_stride_size * self.layer_num - buffer = torch.empty( - dims, - dtype=self.dtype, - device=self.device, + + alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device] + buffer = alloc_func( + dims, dtype=self.dtype, device=self.device, pin_memory=self.pin_memory ) - if self.pin_memory: - torch.cuda.cudart().cudaHostRegister( - buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0 - ) return buffer @property @@ -675,15 +709,18 @@ def init_kv_buffer(self): self.page_size, 1, ) - self.k_buffer = torch.empty( + alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device] + self.k_buffer = alloc_func( (*base_dims, self.kv_lora_rank), dtype=self.dtype, device=self.device, + pin_memory=self.pin_memory, ) - self.v_buffer = torch.empty( + self.v_buffer = alloc_func( (*base_dims, self.qk_rope_head_dim), dtype=self.dtype, device=self.device, + pin_memory=self.pin_memory, ) # Return k_buffer to preserve original kv_buffer and data_refs init logic, # though Ascend doesn't use these parameters. @@ -694,15 +731,11 @@ def init_kv_buffer(self): self.kv_lora_rank + self.qk_rope_head_dim ) * self.dtype.itemsize self.layout_dim = self.token_stride_size * self.layer_num - buffer = torch.empty( - dims, - dtype=self.dtype, - device=self.device, + + alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device] + buffer = alloc_func( + dims, dtype=self.dtype, device=self.device, pin_memory=self.pin_memory ) - if self.pin_memory: - torch.cuda.cudart().cudaHostRegister( - buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0 - ) return buffer def load_to_device_per_layer( diff --git a/test/srt/ascend/test_ascend_hicache_mla.py b/test/srt/ascend/test_ascend_hicache_mla.py new file mode 100644 index 000000000000..5e7c711e868d --- /dev/null +++ b/test/srt/ascend/test_ascend_hicache_mla.py @@ -0,0 +1,102 @@ +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, + run_bench_offline_throughput, +) + +TEST_MODEL_MATRIX = { + "/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": { + "accuracy": 0.34, + "latency": 1000, + "output_throughput": 6, + }, +} + + +class TestAscendMlaHicache(CustomTestCase): + + @classmethod + def setUpClass(cls): + cls.models = TEST_MODEL_MATRIX.keys() + cls.base_url = DEFAULT_URL_FOR_TEST + cls.url = urlparse(DEFAULT_URL_FOR_TEST) + cls.common_args = [ + "--trust-remote-code", + "--mem-fraction-static", + 0.8, + "--attention-backend", + "ascend", + "--quantization", + "modelslim", + "--tp-size", + 4, + "--enable-hierarchical-cache", + "--hicache-ratio", + 
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index cd6a95add1a8..3a36836cf85e 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -375,6 +375,7 @@
     ],
     "per-commit-4-npu-a2": [
         TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
+        TestFile("ascend/test_ascend_hicache_mla.py", 400),
         TestFile("ascend/test_ascend_tp4_bf16.py", 400),
     ],
     "per-commit-16-npu-a3": [