Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
def move_cache_dynamic_last_kernel_h_block(
dst_cache_ptr,
src_cache_ptr,
valid_indices_ptr,
dst_indices_ptr,
src_indices_ptr,
last_steps_ptr,
layer_stride,
size_stride,
Expand All @@ -39,7 +40,8 @@ def move_cache_dynamic_last_kernel_h_block(
valid_id = tl.program_id(0)

# Load actual indices
valid_idx_val = tl.load(valid_indices_ptr + valid_id)
dst_idx_val = tl.load(dst_indices_ptr + valid_id)
src_idx_val = tl.load(src_indices_ptr + valid_id)
last_step_val = tl.load(last_steps_ptr + valid_id)
if last_step_val < 0:
return
Expand All @@ -52,12 +54,12 @@ def move_cache_dynamic_last_kernel_h_block(
src_base_addr = (
src_cache_ptr
+ tl.cast(l, tl.int64) * layer_stride
+ tl.cast(valid_idx_val, tl.int64) * size_stride
+ tl.cast(src_idx_val, tl.int64) * size_stride
)
dst_base_addr = (
dst_cache_ptr
+ tl.cast(l, tl.int64) * dst_layer_stride
+ tl.cast(valid_idx_val, tl.int64) * dst_size_stride
+ tl.cast(dst_idx_val, tl.int64) * dst_size_stride
)
src_addr = src_base_addr + tl.cast(last_step_val, tl.int64) * draft_stride

Expand All @@ -84,7 +86,8 @@ def move_cache_dynamic_last_kernel_h_block(
def move_intermediate_cache(
ssm_states,
intermediate_state_cache,
valid_tensor,
dst_indices_tensor,
src_indices_tensor,
last_steps_tensor,
h_block_size=2,
):
Expand All @@ -94,7 +97,8 @@ def move_intermediate_cache(
Args:
ssm_states: Destination SSM states tensor
intermediate_state_cache: Source intermediate state cache
valid_tensor: Valid indices tensor
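dst_indices_tensor: Slot indices into ssm_states (dim 1) to overwrite; same length as last_steps_tensor
src_indices_tensor: Slot indices into intermediate_state_cache to read from; same length as last_steps_tensor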
last_steps_tensor: Last steps tensor
h_block_size: Block size for h dimension processing
"""
Expand All @@ -109,15 +113,17 @@ def move_intermediate_cache(
dst_layer_stride, dst_size_stride = int(ssm_states.stride()[0]), int(
ssm_states.stride()[1]
)
assert len(valid_tensor) == len(last_steps_tensor), "Lengths must match"
assert len(dst_indices_tensor) == len(last_steps_tensor), "Destination indices lengths must match"
assert len(src_indices_tensor) == len(last_steps_tensor), "Source indices lengths must match"

# Grid: one program instance (tl.program_id(0)) per index entry
grid = (len(valid_tensor),)
grid = (len(dst_indices_tensor),)

move_cache_dynamic_last_kernel_h_block[grid](
dst_cache_ptr=ssm_states,
src_cache_ptr=intermediate_state_cache,
valid_indices_ptr=valid_tensor,
dst_indices_ptr=dst_indices_tensor,
src_indices_ptr=src_indices_tensor,
last_steps_ptr=last_steps_tensor,
layer_stride=layer_stride,
size_stride=size_stride,
Expand Down
18 changes: 11 additions & 7 deletions tests/python/sgl_kernel_npu/test_mamba_state_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,24 +129,28 @@ def test_move_intermediate_cache(
dtype: torch.dtype,
):
torch.manual_seed(42)

# prepare input data
dst_cache = torch.randn(L, S, H, V, K, device=device, dtype=dtype)
dst_cache_clone = dst_cache.clone()
src_cache = torch.randn(L, S, D, H, V, K, device=device, dtype=dtype)

# prepare input data
population = range(S)
valid_indices = random.sample(population, num_valid)
last_step_pos = [random.randint(0, D - 1) for _ in range(num_valid)]

valid_tensor = torch.tensor(valid_indices, device=device, dtype=torch.int32)
dst_indices_tensor = torch.tensor(valid_indices, device=device, dtype=torch.int32)
src_indices_tensor = torch.arange(dst_indices_tensor.shape[0], device=device, dtype=torch.int32)
last_steps_tensor = torch.tensor(last_step_pos, device=device, dtype=torch.int32)

valid_mask = last_steps_tensor >= 0
valid_state_indices = valid_tensor[valid_mask].to(torch.int64)
dst_state_indices = dst_indices_tensor[valid_mask].to(torch.int64)
src_state_indices = src_indices_tensor[valid_mask].to(torch.int64)
valid_last_steps = last_steps_tensor[valid_mask].to(torch.int64)
dst_cache[:, valid_state_indices, :] = src_cache[
:, valid_state_indices, valid_last_steps
# build the expected output used for verification
dst_cache[:, dst_state_indices, :] = src_cache[
:, src_state_indices, valid_last_steps
]

move_intermediate_cache(dst_cache_clone, src_cache, valid_tensor, last_steps_tensor)
move_intermediate_cache(dst_cache_clone, src_cache, dst_indices_tensor, src_indices_tensor, last_steps_tensor)

assert_close("move_cache", dst_cache, dst_cache_clone, 1e-3)
Loading