From 96d7bdfb7cdf8f0543c077d86b6abecb315b5fc0 Mon Sep 17 00:00:00 2001 From: silencejade <13120475055@163.com> Date: Thu, 2 Apr 2026 16:09:53 +0800 Subject: [PATCH 1/2] adapt move_intermediate_cache for sglang prefix + mtp --- .../mamba/mamba_state_update_triton.py | 23 +++++++++++-------- .../sgl_kernel_npu/test_mamba_state_update.py | 19 +++++++++------ 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/python/sgl_kernel_npu/sgl_kernel_npu/mamba/mamba_state_update_triton.py b/python/sgl_kernel_npu/sgl_kernel_npu/mamba/mamba_state_update_triton.py index e089dfc82..14b30a7a6 100644 --- a/python/sgl_kernel_npu/sgl_kernel_npu/mamba/mamba_state_update_triton.py +++ b/python/sgl_kernel_npu/sgl_kernel_npu/mamba/mamba_state_update_triton.py @@ -21,7 +21,8 @@ def move_cache_dynamic_last_kernel_h_block( dst_cache_ptr, src_cache_ptr, - valid_indices_ptr, + dst_indices_ptr, + src_indices_ptr, last_steps_ptr, layer_stride, size_stride, @@ -39,7 +40,8 @@ def move_cache_dynamic_last_kernel_h_block( valid_id = tl.program_id(0) # Load actual indices - valid_idx_val = tl.load(valid_indices_ptr + valid_id) + dst_idx_val = tl.load(dst_indices_ptr + valid_id) + src_idx_val = tl.load(src_indices_ptr + valid_id) last_step_val = tl.load(last_steps_ptr + valid_id) if last_step_val < 0: return @@ -52,12 +54,12 @@ def move_cache_dynamic_last_kernel_h_block( src_base_addr = ( src_cache_ptr + tl.cast(l, tl.int64) * layer_stride - + tl.cast(valid_idx_val, tl.int64) * size_stride + + tl.cast(src_idx_val, tl.int64) * size_stride ) dst_base_addr = ( dst_cache_ptr + tl.cast(l, tl.int64) * dst_layer_stride - + tl.cast(valid_idx_val, tl.int64) * dst_size_stride + + tl.cast(dst_idx_val, tl.int64) * dst_size_stride ) src_addr = src_base_addr + tl.cast(last_step_val, tl.int64) * draft_stride @@ -84,7 +86,8 @@ def move_cache_dynamic_last_kernel_h_block( def move_intermediate_cache( ssm_states, intermediate_state_cache, - valid_tensor, + dst_indices_tensor, + src_indices_tensor, 
last_steps_tensor, h_block_size=2, ): @@ -94,7 +97,8 @@ def move_intermediate_cache( Args: ssm_states: Destination SSM states tensor intermediate_state_cache: Source intermediate state cache - valid_tensor: Valid indices tensor + dst_indices_tensor: Valid destination indices tensor + src_indices_tensor: Valid source indices tensor last_steps_tensor: Last steps tensor h_block_size: Block size for h dimension processing """ @@ -109,15 +113,16 @@ def move_intermediate_cache( dst_layer_stride, dst_size_stride = int(ssm_states.stride()[0]), int( ssm_states.stride()[1] ) - assert len(valid_tensor) == len(last_steps_tensor), "Lengths must match" + assert len(dst_indices_tensor) == len(last_steps_tensor), "Lengths must match" # Grid: one thread per valid index - grid = (len(valid_tensor),) + grid = (len(dst_indices_tensor),) move_cache_dynamic_last_kernel_h_block[grid]( dst_cache_ptr=ssm_states, src_cache_ptr=intermediate_state_cache, - valid_indices_ptr=valid_tensor, + dst_indices_ptr=dst_indices_tensor, + src_indices_ptr=src_indices_tensor, last_steps_ptr=last_steps_tensor, layer_stride=layer_stride, size_stride=size_stride, diff --git a/tests/python/sgl_kernel_npu/test_mamba_state_update.py b/tests/python/sgl_kernel_npu/test_mamba_state_update.py index fa1f48539..2cb60b9d6 100644 --- a/tests/python/sgl_kernel_npu/test_mamba_state_update.py +++ b/tests/python/sgl_kernel_npu/test_mamba_state_update.py @@ -24,6 +24,7 @@ def get_err_ratio(x, y): def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6): abs_atol = get_abs_err(ref, tri) msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}" + print(msg) error_rate = get_err_ratio(ref, tri) if abs_atol <= err_atol: return @@ -129,24 +130,28 @@ def test_move_intermediate_cache( dtype: torch.dtype, ): torch.manual_seed(42) - + # prepare input data dst_cache = torch.randn(L, S, H, V, K, device=device, dtype=dtype) dst_cache_clone = dst_cache.clone() src_cache = torch.randn(L, S, D, H, V, K, 
device=device, dtype=dtype) + # prepare input data population = range(S) valid_indices = random.sample(population, num_valid) last_step_pos = [random.randint(0, D - 1) for _ in range(num_valid)] - - valid_tensor = torch.tensor(valid_indices, device=device, dtype=torch.int32) + dst_indices_tensor = torch.tensor(valid_indices, device=device, dtype=torch.int32) + src_indices_tensor = torch.arange(dst_indices_tensor.shape[0], device=device, dtype=torch.int32) last_steps_tensor = torch.tensor(last_step_pos, device=device, dtype=torch.int32) + valid_mask = last_steps_tensor >= 0 - valid_state_indices = valid_tensor[valid_mask].to(torch.int64) + dst_state_indices = dst_indices_tensor[valid_mask].to(torch.int64) + src_state_indices = src_indices_tensor[valid_mask].to(torch.int64) valid_last_steps = last_steps_tensor[valid_mask].to(torch.int64) - dst_cache[:, valid_state_indices, :] = src_cache[ - :, valid_state_indices, valid_last_steps + # prepare output verify + dst_cache[:, dst_state_indices, :] = src_cache[ + :, src_state_indices, valid_last_steps ] - move_intermediate_cache(dst_cache_clone, src_cache, valid_tensor, last_steps_tensor) + move_intermediate_cache(dst_cache_clone, src_cache, dst_indices_tensor, src_indices_tensor, last_steps_tensor) assert_close("move_cache", dst_cache, dst_cache_clone, 1e-3) From 2d2fb6c9b764360ce6d501b1272d44ccfcb988a5 Mon Sep 17 00:00:00 2001 From: silencejade <13120475055@163.com> Date: Thu, 2 Apr 2026 16:33:27 +0800 Subject: [PATCH 2/2] fix test && add input check --- .../sgl_kernel_npu/mamba/mamba_state_update_triton.py | 3 ++- tests/python/sgl_kernel_npu/test_mamba_state_update.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sgl_kernel_npu/sgl_kernel_npu/mamba/mamba_state_update_triton.py b/python/sgl_kernel_npu/sgl_kernel_npu/mamba/mamba_state_update_triton.py index 14b30a7a6..08542cd44 100644 --- a/python/sgl_kernel_npu/sgl_kernel_npu/mamba/mamba_state_update_triton.py +++ 
b/python/sgl_kernel_npu/sgl_kernel_npu/mamba/mamba_state_update_triton.py @@ -113,7 +113,8 @@ def move_intermediate_cache( dst_layer_stride, dst_size_stride = int(ssm_states.stride()[0]), int( ssm_states.stride()[1] ) - assert len(dst_indices_tensor) == len(last_steps_tensor), "Lengths must match" + assert len(dst_indices_tensor) == len(last_steps_tensor), "Destination indices lengths must match" + assert len(src_indices_tensor) == len(last_steps_tensor), "Source indices lengths must match" # Grid: one thread per valid index grid = (len(dst_indices_tensor),) diff --git a/tests/python/sgl_kernel_npu/test_mamba_state_update.py b/tests/python/sgl_kernel_npu/test_mamba_state_update.py index 2cb60b9d6..e755f9767 100644 --- a/tests/python/sgl_kernel_npu/test_mamba_state_update.py +++ b/tests/python/sgl_kernel_npu/test_mamba_state_update.py @@ -24,7 +24,6 @@ def get_err_ratio(x, y): def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6): abs_atol = get_abs_err(ref, tri) msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}" - print(msg) error_rate = get_err_ratio(ref, tri) if abs_atol <= err_atol: return