diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index 93b44de927b7..65ba2d6186bd 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -208,10 +208,14 @@ def get_splitfuse_attn_mask( class AscendAttnBackend(AttentionBackend): - def __init__(self, model_runner: ModelRunner): + def __init__(self, model_runner: ModelRunner, speculative_step_id: int = 0): super().__init__() self.forward_metadata = None self.device = model_runner.device + self.speculative_step_id = speculative_step_id + self.speculative_step_offset_npu = torch.tensor( + speculative_step_id + 1, device="npu" + ) self.page_size = model_runner.page_size self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA if self.use_mla: @@ -287,6 +291,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): seq_lens_max = forward_batch.seq_lens.max() if forward_batch.forward_mode.is_target_verify(): seq_lens_max += self.speculative_num_draft_tokens + elif ( + forward_batch.forward_mode.is_decode_or_idle() + and forward_batch.spec_info is not None + ): + seq_lens_max += self.speculative_step_id + 1 self.forward_metadata.block_tables = ( forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, :seq_lens_max @@ -329,6 +338,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): if forward_batch.forward_mode.is_target_verify(): self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens + elif ( + forward_batch.forward_mode.is_decode_or_idle() + and forward_batch.spec_info is not None + ): + self.forward_metadata.seq_lens_cpu_int += self.speculative_step_id + 1 if ( self.use_mla @@ -455,6 +469,8 @@ def init_forward_metadata_replay_cuda_graph( max_len = seq_lens_cpu[:bs].max().item() if forward_mode.is_target_verify(): max_len += self.speculative_num_draft_tokens + elif forward_mode.is_decode_or_idle() and spec_info is not None: + max_len += self.speculative_step_id + 1 max_seq_pages = (max_len + self.page_size - 1) // self.page_size if self.is_hybrid_swa: @@ -476,6 +492,8 @@ def init_forward_metadata_replay_cuda_graph( if forward_mode.is_target_verify(): seq_lens = seq_lens + self.speculative_num_draft_tokens + elif forward_mode.is_decode_or_idle() and spec_info is not None: + seq_lens = seq_lens + self.speculative_step_offset_npu metadata.seq_lens[:bs].copy_(seq_lens[:bs]) self.forward_metadata = metadata @@ -1902,8 +1920,10 @@ def __init__( self.speculative_num_steps = speculative_num_steps self.attn_backends = [] - for _ in range(self.speculative_num_steps): - self.attn_backends.append(AscendAttnBackend(model_runner)) + for step_id in range(self.speculative_num_steps): + self.attn_backends.append( + AscendAttnBackend(model_runner, speculative_step_id=step_id) + ) def common_template(self, forward_batch: ForwardBatch, call_fn: int): assert forward_batch.spec_info is not None diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/eagle_draft_npu_graph_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/eagle_draft_npu_graph_runner.py index 4ffc5fdd2d57..77c5d4f2405a 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/eagle_draft_npu_graph_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/eagle_draft_npu_graph_runner.py @@ -83,22 +83,28 @@ def _get_update_attr_name(self): def _get_update_attr_type(self): return self.attr_type[AttentionArch.MLA] - def _replay_update(self, seq_lens): + def _replay_update(self, seq_lens_list): if isinstance(self.update_attr_type, torch.Tensor): - seq_lens = torch.from_numpy(np.array(seq_lens).astype(np.int32)) + seq_lens = torch.from_numpy(np.array(seq_lens_list).astype(np.int32)) self.graphs[self.bs].update( - cpu_update_input=[{self.update_attr_name: seq_lens}] + cpu_update_input=[ + {self.update_attr_name: seq_lens} for seq_lens in seq_lens_list + ] ) def _replay(self, forward_batch: ForwardBatch): self.update_attr_name = self._get_update_attr_name() self.update_attr_type = self._get_update_attr_type() if not is_deepseek_nsa(self.model_runner.model_config.hf_config): - seq_lens = forward_batch.seq_lens_cpu.tolist() + [0] * ( - self.bs - self.raw_bs + seq_lens_for_each_draft_step = [] + for speculative_step_id in range(self.speculative_num_steps - 1): + seq_lens_cpu = forward_batch.seq_lens_cpu + speculative_step_id + 1 + seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs) + seq_lens_for_each_draft_step.append(seq_lens) + thread = threading.Thread( + target=self._replay_update, args=(seq_lens_for_each_draft_step,) ) - thread = threading.Thread(target=self._replay_update, args=(seq_lens,)) thread.start() self.graphs[self.bs].replay() thread.join()