Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sgl_kernel_npu.mamba.causal_conv1d import (
causal_conv1d_fn_npu,
causal_conv1d_update_npu,
causal_conv1d_update_v2,
)

from sglang.srt.hardware_backend.npu.attention.ascend_hybrid_linear_attn_backend import (
Expand Down Expand Up @@ -224,9 +225,7 @@ def forward_extend(
else:
has_initial_states = forward_batch.extend_prefix_lens > 0
if is_target_verify:
draft_token_num = forward_batch.spec_info.draft_token_num
num_token_padding = mixed_qkv.shape[0]
batch_size = cache_indices.shape[0]
if (
not self.graph_mode
and forward_batch.num_token_non_padded_cpu != num_token_padding
Expand All @@ -236,23 +235,24 @@ def forward_extend(
b = b[: forward_batch.num_token_non_padded_cpu]
seq_len = forward_batch.num_token_non_padded_cpu

mixed_qkv_reshaped = mixed_qkv.view(batch_size, draft_token_num, -1)
num_accept_tokens = torch.full(
batch_size = cache_indices.shape[0]
draft_token_num = forward_batch.spec_info.draft_token_num
num_accepted_tokens = torch.full(
(batch_size,),
draft_token_num,
dtype=torch.int32,
device=mixed_qkv.device,
)
mixed_qkv = torch.ops.npu.causal_conv1d_update(
mixed_qkv_reshaped,
layer.conv_weights.transpose(0, 1).contiguous(),
conv_states,
cache_indices,
layer.bias,
num_accept_tokens,
None,
layer.activation == "silu",
self.pad_slot_id,
mixed_qkv = causal_conv1d_update_v2(
x=mixed_qkv.view(batch_size, draft_token_num, -1).contiguous(),
conv_state=conv_states.contiguous(),
Comment thread
iridiumine marked this conversation as resolved.
weight=layer.conv_weights.transpose(0, 1).contiguous(),
bias=layer.bias,
activation=layer.activation,
conv_state_indices=cache_indices,
num_accepted_tokens=num_accepted_tokens,
pad_slot_id=-1,
Comment thread
iridiumine marked this conversation as resolved.
validate_data=False,
).view(seq_len, -1)
Comment thread
iridiumine marked this conversation as resolved.
else:
mixed_qkv = mixed_qkv.transpose(0, 1)
Expand Down
Loading