diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index cea097cb053..8d07d993308 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -32,7 +32,7 @@
         mha_batch_prefill_func,
         paged_attention_ragged,
     )
-    from aiter.mla import mla_decode_fwd
+    from aiter.mla import mla_decode_fwd, mla_prefill_fwd
 except ImportError:
     print(
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
     )
@@ -52,10 +52,8 @@ class ForwardMetadata:
     kv_indices: torch.Tensor
     qo_indptr: torch.Tensor
     kv_last_page_len: torch.Tensor
-    max_extend_len: int
-    max_prefix_extend_len: int
     max_q_len: int
-    max_kv_len: int
+    max_kv_len: Optional[int]
 
 
 global_workspace_buffer = None
@@ -71,10 +69,17 @@ def __init__(
         kv_indptr_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
+        # Lazy import to avoid the initialization of cuda context
+        from sglang.srt.layers.attention.triton_ops.extend_attention import (
+            extend_attention_fwd,
+        )
+
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
 
         self.device = model_runner.device
         self.is_multimodal = model_runner.model_config.is_multimodal
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -157,13 +162,13 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         spec_info = forward_batch.spec_info
         qo_indptr = None
         kv_last_page_len = None
-        max_extend_len = None
+        max_q_len = None
 
         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
-                kv_indices = torch.zeros(
+                kv_indices = torch.empty(
                     forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                 )
                 create_flashinfer_kv_indices_triton[(bs,)](
@@ -183,39 +188,35 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 qo_indptr = self.qo_indptr_[: bs + 1]
                 qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
                 kv_last_page_len = self.kv_last_page_len[:bs]
-                max_extend_len = 1
+                max_q_len = 1
 
             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-                max_extend_len,
-                None,
-                None,
+                max_q_len,
                 None,
             )
         elif forward_batch.forward_mode.is_draft_extend():
             if self.use_mla:
-                prefix_lens = forward_batch.extend_prefix_lens
-                self.mla_indices_updater_prefill.update(
-                    forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
-                    forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
-                    spec_info=None,
+                kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                    spec_info.generate_attn_arg_prefill(
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.seq_lens_sum,
+                        self.req_to_token,
+                    )
                 )
                 self.forward_metadata = ForwardMetadata(
-                    self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
-                    self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                )
             else:
                 self.indices_updater_prefill.update(
@@ -231,30 +232,47 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
         elif forward_batch.forward_mode.is_target_verify():
             if self.use_mla:
-                prefix_lens = forward_batch.extend_prefix_lens
-                self.mla_indices_updater_prefill.update(
+                draft_num = spec_info.draft_token_num
+                kv_lens = forward_batch.seq_lens + draft_num
+                kv_lens_sum = forward_batch.seq_lens_sum + draft_num * bs
+                device = forward_batch.seq_lens.device
+
+                qo_indptr = torch.arange(
+                    0,
+                    (1 + bs) * draft_num,
+                    step=draft_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.empty(
+                    kv_lens_sum,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
                     forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
-                    forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
-                    spec_info=None,
+                    kv_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
                 )
                 self.forward_metadata = ForwardMetadata(
-                    self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
-                    self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    draft_num,
                     None,
                 )
             else:
@@ -271,8 +289,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
@@ -283,25 +299,26 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 extend_no_prefix = False
             else:
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-
             if self.use_mla:
                 self.mla_indices_updater_prefill.update(
                     forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
+                    forward_batch.extend_prefix_lens,
+                    sum(forward_batch.extend_prefix_lens_cpu),
                     forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                     spec_info=None,
                 )
+                self.mla_indices_updater_prefill.kv_indptr += (
+                    self.mla_indices_updater_prefill.qo_indptr
+                )
                 self.forward_metadata = ForwardMetadata(
                     self.mla_indices_updater_prefill.kv_indptr,
                     self.mla_indices_updater_prefill.kv_indices,
                     self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
-                    None,
+                    self.kv_last_page_len[:bs],
+                    self.mla_indices_updater_prefill.max_q_len,
+                    self.mla_indices_updater_prefill.max_kv_len,
                 )
             else:
                 self.indices_updater_prefill.update(
@@ -317,8 +334,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
@@ -359,7 +374,7 @@ def init_forward_metadata_capture_cuda_graph(
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
             kv_last_page_len = None
-            max_extend_len = None
+            max_q_len = None
 
             if spec_info is None:
                 kv_indptr = self.kv_indptr
@@ -383,17 +398,15 @@ def init_forward_metadata_capture_cuda_graph(
                 qo_indptr[1 : bs + 1] = torch.cumsum(
                     self.cuda_graph_kv_last_page_len[:bs], dim=0
                 )
-                max_extend_len = 1
                 kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+                max_q_len = 1
 
             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-                max_extend_len,
-                None,
-                None,
+                max_q_len,
                 None,
             )
@@ -419,18 +432,15 @@ def init_forward_metadata_capture_cuda_graph(
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
-
-                max_extend_len = self.num_draft_tokens
-                kv_last_page_len = None
+                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+                max_q_len = self.num_draft_tokens
                 self.forward_metadata = ForwardMetadata(
                     kv_indptr,
                     kv_indices,
                     qo_indptr,
                     kv_last_page_len,
-                    max_extend_len,
-                    None,
-                    None,
+                    max_q_len,
                     None,
                 )
             else:
@@ -448,12 +458,41 @@ def init_forward_metadata_capture_cuda_graph(
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
-
+        elif forward_mode.is_draft_extend():
+            num_tokens_per_bs = self.speculative_num_steps + 1
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                step=num_tokens_per_bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+            max_q_len = num_tokens_per_bs
+            self.forward_metadata = ForwardMetadata(
+                kv_indptr,
+                kv_indices,
+                qo_indptr,
+                kv_last_page_len,
+                max_q_len,
+                None,
+            )
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")
@@ -488,13 +527,44 @@ def init_forward_metadata_replay_cuda_graph(
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
 
         elif forward_mode.is_target_verify():
-            self.indices_updater_prefill.update(
-                req_pool_indices[:bs],
-                seq_lens[:bs],
-                seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
-                spec_info=spec_info,
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_lens = seq_lens + self.num_draft_tokens
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                kv_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+        elif forward_mode.is_draft_extend():
+            seq_lens = seq_lens[:bs]
+            accept_lens = spec_info.accept_length[:bs]
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
             )
         else:
             raise ValueError("Invalid forward mode")
@@ -530,11 +600,10 @@ def forward_extend(
             )
 
         if self.use_mla:
-            max_extend_len = self.forward_metadata.max_extend_len
-            max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
+            max_q_len = self.forward_metadata.max_q_len
+            max_kv_len = self.forward_metadata.max_kv_len
             kv_indptr = self.forward_metadata.kv_indptr
             kv_indices = self.forward_metadata.kv_indices
-            kv_last_page_lens = self.forward_metadata.kv_last_page_len
             qo_indptr = self.forward_metadata.qo_indptr
             K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
             V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
@@ -552,8 +621,8 @@ def forward_extend(
                     v,
                     qo_indptr,
                     qo_indptr,
-                    max_extend_len,
-                    max_extend_len,
+                    max_q_len,
+                    max_q_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
@@ -599,12 +668,71 @@ def forward_extend(
                     v,
                     qo_indptr,
                     kv_indptr,
-                    max_extend_len,
-                    max_prefix_extend_len,
+                    max_q_len,
+                    max_kv_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
                 return o
+            elif forward_batch.forward_mode.is_target_verify():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                mla_decode_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+            elif forward_batch.forward_mode.is_draft_extend():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                causal = True
+                sliding_window_size = -1
+                kv_indptr = self.forward_metadata.kv_indptr
+                kv_indices = self.forward_metadata.kv_indices
+                mla_prefill_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+                # self.extend_attention_fwd(
+                #     q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                #     k.contiguous(),
+                #     v.contiguous(),
+                #     o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                #     forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                #     forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                #     self.forward_metadata.qo_indptr,
+                #     kv_indptr,
+                #     kv_indices,
+                #     None,
+                #     causal,
+                #     None,
+                #     self.forward_metadata.max_q_len,
+                #     layer.scaling,
+                #     layer.logit_cap,
+                #     sliding_window_size,
+                # )
+                # return o
+            else:
+                raise ValueError(
+                    f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}"
+                )
         else:
             k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
             )
@@ -662,7 +790,7 @@ def forward_decode(
                 self.forward_metadata.kv_indptr,
                 self.forward_metadata.kv_indices,
                 self.forward_metadata.kv_last_page_len,
-                self.forward_metadata.max_extend_len,
+                self.forward_metadata.max_q_len,
                 layer.scaling,
                 layer.logit_cap,
             )
@@ -816,16 +944,17 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         self.kv_indices = None
         self.qo_indptr = None
         self.kv_last_page_len = None
-        self.max_extend_len = 0
-        self.max_prefix_extend_len = 0
+        self.max_q_len = 0
+        self.max_kv_len = 0
 
     def update(
         self,
        req_pool_indices: torch.Tensor,
-        prefix_lens: torch.Tensor,
-        prefix_lens_sum: int,
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -834,33 +963,30 @@ def update(
     def update_single_wrapper(
         self,
         req_pool_indices: torch.Tensor,
-        prefix_lens: torch.Tensor,
-        prefix_lens_sum: int,
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
-
-        paged_kernel_lens = prefix_lens
-        paged_kernel_lens_sum = prefix_lens_sum
-
         bs = len(req_pool_indices)
         kv_indptr = self.attn_backend.kv_indptr
         if spec_info is None:
             # Normal extend
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                paged_kernel_lens_sum,
+                kv_lens_sum,
                 dtype=torch.int32,
                 device=req_pool_indices.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
-                paged_kernel_lens,
+                kv_lens,
                 kv_indptr,
                 None,
                 kv_indices,
@@ -870,16 +996,12 @@ def update_single_wrapper(
             qo_indptr = self.attn_backend.qo_indptr
             qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
-
-            max_extend_len = torch.max(extend_lens).item()
-            max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
-            kv_indptr += qo_indptr
         else:
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
-                    paged_kernel_lens,
-                    paged_kernel_lens_sum,
+                    kv_lens,
+                    kv_lens_sum,
                     self.req_to_token,
                 )
             )
@@ -887,5 +1009,146 @@ def update_single_wrapper(
         self.kv_indptr = kv_indptr
         self.kv_indices = kv_indices
         self.qo_indptr = qo_indptr
-        self.max_extend_len = max_extend_len
-        self.max_prefix_extend_len = max_prefix_extend_len
+        self.max_q_len = max_q_len
+        self.max_kv_len = max_kv_len
+
+
+class AiterMultiStepDraftBackend:
+    """
+    Wrap multiple aiter attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size * self.topk
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                AiterAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.device = model_runner.device
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size
+        assert self.page_size == 1, "Page size must be 1"
+
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+            self.page_size,
+        )
+
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.empty(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device=self.device,
+        )
+
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index dca2cbfb799..16a2230f928 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1722,6 +1722,7 @@ def get_model_worker_batch(
             or attention_backend_str == "cutlass_mla"
             or attention_backend_str == "ascend"
             or attention_backend_str == "trtllm_mha"
+            or attention_backend_str == "aiter"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 376cd029c60..8da0549e920 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -226,6 +226,22 @@ def init_attention_backend(self):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
+        elif self.server_args.attention_backend == "aiter":
+            from sglang.srt.layers.attention.aiter_backend import (
+                AiterAttnBackend,
+                AiterMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = AiterMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = AiterAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
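Reviewer note (illustration only, not part of the patch): on the new target-verify path in `init_forward_metadata`, each request contributes exactly `draft_token_num` query rows, while its paged KV range covers the committed prefix plus those draft tokens. A minimal, standalone PyTorch sketch of the resulting `qo_indptr` / `kv_indptr` layout, using made-up batch sizes and sequence lengths:

```python
# Hypothetical sizes, chosen only to illustrate the indptr layout built for
# target-verify in the patch above. Requires only PyTorch.
import torch

bs, draft_num = 3, 4                                   # batch size, draft tokens per request
seq_lens = torch.tensor([7, 5, 9], dtype=torch.int32)  # committed KV length per request

# Each request owns exactly `draft_num` query tokens in the flattened q tensor.
qo_indptr = torch.arange(0, (1 + bs) * draft_num, step=draft_num, dtype=torch.int32)

# The KV range seen by verification is the committed prefix plus the draft tokens.
kv_lens = seq_lens + draft_num
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)

print(qo_indptr.tolist())  # [0, 4, 8, 12]
print(kv_indptr.tolist())  # [0, 11, 20, 33]
```

These are the offsets that `mla_decode_fwd` consumes on the `is_target_verify()` branch of `forward_extend`, with `max_q_len` set to `draft_num`.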