Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,14 @@ def get_splitfuse_attn_mask(

class AscendAttnBackend(AttentionBackend):

def __init__(self, model_runner: ModelRunner):
def __init__(self, model_runner: ModelRunner, speculative_step_id: int = 0):
super().__init__()
self.forward_metadata = None
self.device = model_runner.device
self.speculative_step_id = speculative_step_id
self.speculative_step_offset_npu = torch.tensor(
speculative_step_id + 1, device="npu"
)
self.page_size = model_runner.page_size
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
if self.use_mla:
Expand Down Expand Up @@ -287,6 +291,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
seq_lens_max = forward_batch.seq_lens.max()
if forward_batch.forward_mode.is_target_verify():
seq_lens_max += self.speculative_num_draft_tokens
elif (
forward_batch.forward_mode.is_decode_or_idle()
and forward_batch.spec_info is not None
):
seq_lens_max += self.speculative_step_id + 1
self.forward_metadata.block_tables = (
forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :seq_lens_max
Expand Down Expand Up @@ -329,6 +338,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

if forward_batch.forward_mode.is_target_verify():
self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens
elif (
forward_batch.forward_mode.is_decode_or_idle()
and forward_batch.spec_info is not None
):
self.forward_metadata.seq_lens_cpu_int += self.speculative_step_id + 1

if (
self.use_mla
Expand Down Expand Up @@ -455,6 +469,8 @@ def init_forward_metadata_replay_cuda_graph(
max_len = seq_lens_cpu[:bs].max().item()
if forward_mode.is_target_verify():
max_len += self.speculative_num_draft_tokens
elif forward_mode.is_decode_or_idle() and spec_info is not None:
max_len += self.speculative_step_id + 1
max_seq_pages = (max_len + self.page_size - 1) // self.page_size

if self.is_hybrid_swa:
Expand All @@ -476,6 +492,8 @@ def init_forward_metadata_replay_cuda_graph(

if forward_mode.is_target_verify():
seq_lens = seq_lens + self.speculative_num_draft_tokens
elif forward_mode.is_decode_or_idle() and spec_info is not None:
seq_lens = seq_lens + self.speculative_step_offset_npu
metadata.seq_lens[:bs].copy_(seq_lens[:bs])

self.forward_metadata = metadata
Expand Down Expand Up @@ -1902,8 +1920,10 @@ def __init__(
self.speculative_num_steps = speculative_num_steps

self.attn_backends = []
for _ in range(self.speculative_num_steps):
self.attn_backends.append(AscendAttnBackend(model_runner))
for step_id in range(self.speculative_num_steps):
self.attn_backends.append(
AscendAttnBackend(model_runner, speculative_step_id=step_id)
)

def common_template(self, forward_batch: ForwardBatch, call_fn: int):
assert forward_batch.spec_info is not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,28 @@ def _get_update_attr_name(self):
def _get_update_attr_type(self):
return self.attr_type[AttentionArch.MLA]

def _replay_update(self, seq_lens):
def _replay_update(self, seq_lens_list):
if isinstance(self.update_attr_type, torch.Tensor):
seq_lens = torch.from_numpy(np.array(seq_lens).astype(np.int32))
seq_lens = torch.from_numpy(np.array(seq_lens_list).astype(np.int32))

self.graphs[self.bs].update(
cpu_update_input=[{self.update_attr_name: seq_lens}]
cpu_update_input=[
{self.update_attr_name: seq_lens} for seq_lens in seq_lens_list
]
)

def _replay(self, forward_batch: ForwardBatch):
self.update_attr_name = self._get_update_attr_name()
self.update_attr_type = self._get_update_attr_type()
if not is_deepseek_nsa(self.model_runner.model_config.hf_config):
seq_lens = forward_batch.seq_lens_cpu.tolist() + [0] * (
self.bs - self.raw_bs
seq_lens_for_each_draft_step = []
for speculative_step_id in range(self.speculative_num_steps - 1):
seq_lens_cpu = forward_batch.seq_lens_cpu + speculative_step_id + 1
seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs)
seq_lens_for_each_draft_step.append(seq_lens)
thread = threading.Thread(
target=self._replay_update, args=(seq_lens_for_each_draft_step,)
)
thread = threading.Thread(target=self._replay_update, args=(seq_lens,))
thread.start()
self.graphs[self.bs].replay()
thread.join()
Expand Down
Loading