From 6d940a5df858d0ea48924ae2ba7882f87e033d36 Mon Sep 17 00:00:00 2001
From: SunnyLee219 <3294305115@qq.com>
Date: Mon, 19 Jan 2026 10:16:52 +0800
Subject: [PATCH 1/3] update causal_conv1d_update

Signed-off-by: SunnyLee219 <3294305115@qq.com>
---
 vllm_ascend/ops/triton/mamba/causal_conv1d.py | 35 ++++++++-----------
 1 file changed, 15 insertions(+), 20 deletions(-)

diff --git a/vllm_ascend/ops/triton/mamba/causal_conv1d.py b/vllm_ascend/ops/triton/mamba/causal_conv1d.py
index 9fb9465b0a5..a51e281cb8b 100644
--- a/vllm_ascend/ops/triton/mamba/causal_conv1d.py
+++ b/vllm_ascend/ops/triton/mamba/causal_conv1d.py
@@ -595,6 +595,9 @@ def causal_conv1d_update_npu(
         indices 0 and 3
     out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
     """
+    weight = weight.transpose(0, 1).contiguous()
+    print("weight's shape: ", weight.size())
+    conv_state = conv_state.transpose(1, 2).contiguous()
     if validate_data:
         assert pad_slot_id is not None
         assert x.stride(1) == 1
@@ -608,40 +611,33 @@ def causal_conv1d_update_npu(
 
     unsqueeze = query_start_loc is None and x.dim() == 2
     if unsqueeze:
         # make it (batch, dim, seqlen) with seqlen == 1
-        x = x.unsqueeze(-1)
+        x = x.unsqueeze(1)
     if query_start_loc is None:
-        batch, dim, seqlen = x.shape
+        batch, seqlen, dim = x.shape
     else:
         assert conv_state_indices is not None
         batch = conv_state_indices.size(0)
-        dim = x.size(1)
+        dim = x.size(2)
         seqlen = max_query_len
-    _, width = weight.shape
-    num_cache_lines, _, state_len_total = conv_state.size()
-
-    if validate_data:
-        assert dim == weight.size(0)
-        assert conv_state.stride(-2) == 1
-        assert state_len_total >= width - 1
-        assert num_cache_lines >= batch
-        assert weight.stride(1) == 1
+    width, _ = weight.shape
+    num_cache_lines, state_len_total,_ = conv_state.size()
 
     # overwrite-on-x strategy same as original
     out = x
-    stride_w_dim, stride_w_width = weight.stride()
+    stride_w_width, stride_w_dim = weight.stride()
 
     if query_start_loc is None:
-        stride_x_seq, stride_x_dim, stride_x_token = x.stride()
-        stride_o_seq, stride_o_dim, stride_o_token = out.stride()
+        stride_x_seq, stride_x_token,stride_x_dim = x.stride()
+        stride_o_seq, stride_o_token, stride_o_dim = out.stride()
     else:
         stride_x_token, stride_x_dim = x.stride()
         stride_x_seq = 0
         stride_o_token, stride_o_dim = out.stride()
         stride_o_seq = 0
-    stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
+    stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride(
     )
     stride_state_indices = conv_state_indices.stride(
         0) if conv_state_indices is not None else 0
@@ -657,7 +653,7 @@ def causal_conv1d_update_npu(
     #keep program count around ~[80..160]
     # vector core 40
     # TODO: use driver to get the vector core num
-    CORE_HINT = 40 
+    CORE_HINT = 40
     # channel tile: 512 when dim large (reduce tasks), else 256
     block_n = 512 if dim >= 512 else 256
     g = triton.cdiv(dim, block_n)
@@ -674,14 +670,13 @@ def causal_conv1d_update_npu(
     b_tile = 8
 
     # token chunk based on block_n (32KB UB idea); conservative
-    t_chunk = 20 if block_n == 512 else 48
+    t_chunk = 1 if block_n == 512 else 48
 
     def grid(META):
         return (
             triton.cdiv(batch, META["B_TILE"]),
             triton.cdiv(dim, META["BLOCK_N"]),
         )
-
     _causal_conv1d_update_kernel_npu_tiled[grid](
         x,
         weight,
@@ -725,5 +720,5 @@ def grid(META):
     )
 
     if unsqueeze:
-        out = out.squeeze(-1)
+        out = out.squeeze(1)
     return out.to(original_x_dtype)

From 35b85b65e887803890e6a70cdb4bd1447e951443 Mon Sep 17 00:00:00 2001
From: SunnyLee219 <3294305115@qq.com>
Date: Mon, 19 Jan 2026 10:29:45 +0800
Subject: [PATCH 2/3] update causal_conv1d_update

Signed-off-by: SunnyLee219 <3294305115@qq.com>
---
 vllm_ascend/ops/triton/mamba/causal_conv1d.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm_ascend/ops/triton/mamba/causal_conv1d.py b/vllm_ascend/ops/triton/mamba/causal_conv1d.py
index a51e281cb8b..c4839d76c6e 100644
--- a/vllm_ascend/ops/triton/mamba/causal_conv1d.py
+++ b/vllm_ascend/ops/triton/mamba/causal_conv1d.py
@@ -596,7 +596,6 @@ def causal_conv1d_update_npu(
     out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
     """
     weight = weight.transpose(0, 1).contiguous()
-    print("weight's shape: ", weight.size())
     conv_state = conv_state.transpose(1, 2).contiguous()
     if validate_data:
         assert pad_slot_id is not None

From 4ebe79ab6631fafc97b8779d05b797488891c507 Mon Sep 17 00:00:00 2001
From: SunnyLee219 <3294305115@qq.com>
Date: Wed, 21 Jan 2026 11:34:23 +0800
Subject: [PATCH 3/3] Fix dim indexing in the varlen path of
 causal_conv1d_update_npu (x is 2-D (num_tokens, dim) there)

Signed-off-by: SunnyLee219 <3294305115@qq.com>
---
 vllm_ascend/ops/triton/mamba/causal_conv1d.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_ascend/ops/triton/mamba/causal_conv1d.py b/vllm_ascend/ops/triton/mamba/causal_conv1d.py
index c4839d76c6e..84c330b5df2 100644
--- a/vllm_ascend/ops/triton/mamba/causal_conv1d.py
+++ b/vllm_ascend/ops/triton/mamba/causal_conv1d.py
@@ -617,7 +617,7 @@ def causal_conv1d_update_npu(
     else:
         assert conv_state_indices is not None
         batch = conv_state_indices.size(0)
-        dim = x.size(2)
+        dim = x.size(1)
         seqlen = max_query_len
     width, _ = weight.shape