vllm_ascend/ops/triton/mamba/causal_conv1d.py (32 changes: 13 additions & 19 deletions)
@@ -595,6 +595,8 @@ def causal_conv1d_update_npu(
         indices 0 and 3
     out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
     """
+    weight = weight.transpose(0, 1).contiguous()
+    conv_state = conv_state.transpose(1, 2).contiguous()
     if validate_data:
         assert pad_slot_id is not None
         assert x.stride(1) == 1
@@ -608,40 +610,33 @@ def causal_conv1d_update_npu(
     unsqueeze = query_start_loc is None and x.dim() == 2
     if unsqueeze:
         # make it (batch, dim, seqlen) with seqlen == 1
-        x = x.unsqueeze(-1)
+        x = x.unsqueeze(1)

     if query_start_loc is None:
-        batch, dim, seqlen = x.shape
+        batch, seqlen, dim = x.shape
     else:
         assert conv_state_indices is not None
         batch = conv_state_indices.size(0)
         dim = x.size(1)
         seqlen = max_query_len

-    _, width = weight.shape
-    num_cache_lines, _, state_len_total = conv_state.size()
-
-    if validate_data:
-        assert dim == weight.size(0)
-        assert conv_state.stride(-2) == 1
-        assert state_len_total >= width - 1
-        assert num_cache_lines >= batch
-        assert weight.stride(1) == 1
+    width, _ = weight.shape
+    num_cache_lines, state_len_total, _ = conv_state.size()

     # overwrite-on-x strategy same as original
     out = x

-    stride_w_dim, stride_w_width = weight.stride()
+    stride_w_width, stride_w_dim = weight.stride()
     if query_start_loc is None:
-        stride_x_seq, stride_x_dim, stride_x_token = x.stride()
-        stride_o_seq, stride_o_dim, stride_o_token = out.stride()
+        stride_x_seq, stride_x_token, stride_x_dim = x.stride()
+        stride_o_seq, stride_o_token, stride_o_dim = out.stride()
     else:
         stride_x_token, stride_x_dim = x.stride()
         stride_x_seq = 0
         stride_o_token, stride_o_dim = out.stride()
         stride_o_seq = 0

-    stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
+    stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride(
     )
     stride_state_indices = conv_state_indices.stride(
         0) if conv_state_indices is not None else 0
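Note on the layout change in the hunks above: the PR switches the kernel's view of the data from channel-first to channel-last, so `dim` becomes the innermost (stride-1) axis of `x`, `weight`, and `conv_state`, and every stride unpacking is reordered to match. Below is a minimal sketch of that correspondence (my illustration, not code from the PR; the tensor sizes are made up, and only the variable names mirror the diff):

# Illustrative sketch only: shows why the reordered stride unpackings
# in the diff line up with the transposed, channel-last layouts.
import torch

batch, seqlen, dim, width, state_len = 4, 1, 64, 4, 3

# Channel-last activations: dim is now the contiguous axis, matching
# `stride_x_seq, stride_x_token, stride_x_dim = x.stride()` above.
x = torch.randn(batch, seqlen, dim)
stride_x_seq, stride_x_token, stride_x_dim = x.stride()
assert stride_x_dim == 1

# Weight goes from (dim, width) to (width, dim), matching
# `weight = weight.transpose(0, 1).contiguous()` in the first hunk.
weight = torch.randn(dim, width).transpose(0, 1).contiguous()
stride_w_width, stride_w_dim = weight.stride()
assert stride_w_dim == 1

# conv_state likewise moves its dim axis innermost, matching
# `conv_state = conv_state.transpose(1, 2).contiguous()`.
conv_state = torch.randn(batch, dim, state_len).transpose(1, 2).contiguous()
num_cache_lines, state_len_total, _ = conv_state.size()
stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride()
assert stride_istate_dim == 1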
@@ -657,7 +652,7 @@ def causal_conv1d_update_npu(
     #keep program count around ~[80..160]
     # vector core 40
     # TODO: use driver to get the vector core num
-    CORE_HINT = 40
+    CORE_HINT = 40
     # channel tile: 512 when dim large (reduce tasks), else 256
     block_n = 512 if dim >= 512 else 256
     g = triton.cdiv(dim, block_n)
@@ -674,14 +669,13 @@ def causal_conv1d_update_npu(
         b_tile = 8

     # token chunk based on block_n (32KB UB idea); conservative
-    t_chunk = 20 if block_n == 512 else 48
+    t_chunk = 1 if block_n == 512 else 48

     def grid(META):
         return (
             triton.cdiv(batch, META["B_TILE"]),
             triton.cdiv(dim, META["BLOCK_N"]),
         )
-
     _causal_conv1d_update_kernel_npu_tiled[grid](
         x,
         weight,
@@ -725,5 +719,5 @@ def grid(META):
     )

     if unsqueeze:
-        out = out.squeeze(-1)
+        out = out.squeeze(1)
     return out.to(original_x_dtype)
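On the launch-grid heuristics: the program count is the product of ceil-dividing `batch` by the batch tile and `dim` by the channel tile, and the comments in the diff aim that product near the NPU's 40 vector cores (the ~[80..160] band). A small worked example of the arithmetic (mine, not from the PR; the diff elides the code that picks `b_tile` for other shapes, so the values here are only illustrative):

# Worked example of the tiled grid sizing used by the kernel launch above.
def cdiv(a: int, b: int) -> int:
    """Ceiling division, same result as triton.cdiv for positive ints."""
    return -(-a // b)

batch, dim = 512, 1024
block_n = 512 if dim >= 512 else 256   # channel tile, as in the diff
b_tile = 8                             # batch tile, as in the diff
grid = (cdiv(batch, b_tile), cdiv(dim, block_n))
print(grid, grid[0] * grid[1])  # (64, 2) 128 -> inside the targeted ~[80..160]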