diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index f1ea1de9d2e2..9ff800a9f65d 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -21,8 +21,9 @@
 import logging
 import os
 from contextlib import contextmanager
+from dataclasses import dataclass
 from functools import partial
-from typing import TYPE_CHECKING, Callable, Optional, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
 
 import torch
 import tqdm
@@ -58,9 +59,10 @@
     ForwardBatch,
     ForwardMode,
     PPProxyTensors,
+    compute_local_num_token_non_padded,
     enable_num_token_non_padded,
 )
-from sglang.srt.model_executor.input_buffers import GraphInputBuffers
+from sglang.srt.model_executor.input_buffers import ForwardInputBuffers
 from sglang.srt.multiplex.pdmux_context import get_current_stream_idx, get_stream_groups
 from sglang.srt.utils import (
     empty_context,
@@ -90,6 +92,200 @@
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+
+@dataclass
+class DecodeInputBuffers(ForwardInputBuffers):
+
+    input_ids: torch.Tensor
+    input_embeds: torch.Tensor
+    req_pool_indices: torch.Tensor
+    seq_lens: torch.Tensor
+    seq_lens_cpu: torch.Tensor
+    out_cache_loc: torch.Tensor
+    positions: torch.Tensor
+    mrope_positions: torch.Tensor
+    num_token_non_padded: torch.Tensor
+    custom_mask: torch.Tensor
+    next_token_logits_buffer: torch.Tensor
+    mamba_track_indices: Optional[torch.Tensor]
+    mamba_track_mask: Optional[torch.Tensor]
+    global_num_tokens_gpu: torch.Tensor
+    global_num_tokens_for_logprob_gpu: torch.Tensor
+    encoder_lens: Optional[torch.Tensor]
+    pp_proxy_tensors: Optional[Dict[str, torch.Tensor]]
+
+    @classmethod
+    def create(
+        cls,
+        *,
+        device: torch.device,
+        max_bs: int,
+        max_num_token: int,
+        hidden_size: int,
+        vocab_size: int,
+        dtype: torch.dtype,
+        dp_size: int,
+        pp_size: int,
+        is_encoder_decoder: bool,
+        require_mlp_tp_gather: bool,
+        seq_len_fill_value: int,
+        encoder_len_fill_value: int,
+        num_tokens_per_bs: int,
+        cache_loc_dtype: torch.dtype,
+        enable_mamba_track: bool,
+    ) -> "DecodeInputBuffers":
+        with torch.device(device):
+            input_ids = torch.zeros((max_num_token,), dtype=torch.int64)
+            input_embeds = torch.zeros((max_num_token, hidden_size), dtype=dtype)
+            req_pool_indices = torch.zeros((max_bs,), dtype=torch.int32)
+            seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
+            out_cache_loc = torch.zeros((max_num_token,), dtype=cache_loc_dtype)
+            positions = torch.zeros((max_num_token,), dtype=torch.int64)
+            mrope_positions = torch.zeros((3, max_num_token), dtype=torch.int64)
+            num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
+            custom_mask = torch.ones(
+                (max_bs * seq_len_fill_value + max_num_token) * num_tokens_per_bs,
+                dtype=torch.bool,
+            )
+            next_token_logits_buffer = torch.zeros(
+                (max_num_token, vocab_size),
+                dtype=torch.float,
+            )
+            mamba_track_indices = (
+                torch.zeros((max_bs,), dtype=torch.int64)
+                if enable_mamba_track
+                else None
+            )
+            mamba_track_mask = (
+                torch.zeros((max_bs,), dtype=torch.bool) if enable_mamba_track else None
+            )
+
+            if pp_size > 1:
+                pp_proxy_tensors = {
+                    "hidden_states": torch.zeros((max_bs, hidden_size), dtype=dtype),
+                    "residual": torch.zeros((max_bs, hidden_size), dtype=dtype),
+                }
+            else:
+                pp_proxy_tensors = None
+
+            if is_encoder_decoder:
+                encoder_lens = torch.full(
+                    (max_bs,), encoder_len_fill_value, dtype=torch.int32
+                )
+            else:
+                encoder_lens = None
+
+            if require_mlp_tp_gather:
+                global_num_tokens_gpu = torch.zeros((dp_size,), dtype=torch.int32)
+                global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (dp_size,), dtype=torch.int32
+                )
+            else:
+                global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                global_num_tokens_for_logprob_gpu = torch.zeros((1,), dtype=torch.int32)
+
+        # Keep seq_lens_cpu as a true CPU tensor, like the old implementation.
+        seq_lens_cpu = torch.full(
+            (max_bs,),
+            seq_len_fill_value,
+            dtype=torch.int32,
+            device="cpu",
+        )
+
+        return cls(
+            input_ids=input_ids,
+            input_embeds=input_embeds,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
+            out_cache_loc=out_cache_loc,
+            positions=positions,
+            mrope_positions=mrope_positions,
+            num_token_non_padded=num_token_non_padded,
+            custom_mask=custom_mask,
+            next_token_logits_buffer=next_token_logits_buffer,
+            mamba_track_indices=mamba_track_indices,
+            mamba_track_mask=mamba_track_mask,
+            encoder_lens=encoder_lens,
+            global_num_tokens_gpu=global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob_gpu,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+    def populate_from_forward_batch(
+        self,
+        *,
+        forward_batch: ForwardBatch,
+        raw_bs: int,
+        raw_num_token: int,
+        bs: int,
+        seq_len_fill_value: int,
+        require_gathered_buffer: bool,
+        num_tokens_per_bs: int,
+        nsa_enable_prefill_cp: bool,
+        enable_num_token_non_padded_flag: bool,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ):
+        if bs != raw_bs:
+            self.seq_lens.fill_(seq_len_fill_value)
+            self.out_cache_loc.zero_()
+            if self.mamba_track_indices is not None:
+                self.mamba_track_indices.zero_()
+            if self.mamba_track_mask is not None:
+                self.mamba_track_mask.fill_(False)
+
+        # Common inputs
+        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
+        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
+        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
+        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
+        self.positions[:raw_num_token].copy_(forward_batch.positions)
+
+        if (
+            self.mamba_track_indices is not None
+            and forward_batch.mamba_track_indices is not None
+        ):
+            self.mamba_track_indices[:raw_bs].copy_(forward_batch.mamba_track_indices)
+        if (
+            self.mamba_track_mask is not None
+            and forward_batch.mamba_track_mask is not None
+        ):
+            self.mamba_track_mask[:raw_bs].copy_(forward_batch.mamba_track_mask)
+
+        if forward_batch.seq_lens_cpu is not None:
+            if bs != raw_bs:
+                self.seq_lens_cpu.fill_(seq_len_fill_value)
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+
+        if self.encoder_lens is not None and forward_batch.encoder_lens is not None:
+            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
+
+        if forward_batch.mrope_positions is not None:
+            self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
+
+        if require_gathered_buffer:
+            self.global_num_tokens_gpu.fill_(bs * num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs * num_tokens_per_bs)
+
+        if enable_num_token_non_padded_flag:
+            if require_gathered_buffer and not nsa_enable_prefill_cp:
+                num_tokens_per_dp = bs * num_tokens_per_bs
+                local = compute_local_num_token_non_padded(
+                    global_num_token_non_padded=forward_batch.num_token_non_padded,
+                    num_tokens_per_dp=num_tokens_per_dp,
+                )
+                self.num_token_non_padded.copy_(local)
+            else:
+                self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
+
+        # Pipeline-parallel proxy tensors.
+        if pp_proxy_tensors is not None and self.pp_proxy_tensors is not None:
+            for key, buf in self.pp_proxy_tensors.items():
+                src = pp_proxy_tensors.tensors[key]
+                dim = src.shape[0]
+                buf[:dim].copy_(src)
+
+
 # Detect whether the current forward pass is in capture mode
 is_capture_mode = False
@@ -337,7 +533,7 @@ def __init__(self, model_runner: ModelRunner):
         if self.require_gathered_buffer:
             assert self.require_mlp_tp_gather or self.require_attn_tp_gather
 
-        self.buffers: GraphInputBuffers = GraphInputBuffers.create(
+        self.buffers: DecodeInputBuffers = DecodeInputBuffers.create(
             device=self.device,
             max_bs=self.max_bs,
             max_num_token=self.max_num_token,
@@ -354,6 +550,7 @@ def __init__(self, model_runner: ModelRunner):
             cache_loc_dtype=self._cache_loc_dtype(),
             enable_mamba_track=enable_mamba_track,
         )
+        self.buffers.share_buffers()
 
         self.tbo_plugin = TboCudaGraphRunnerPlugin()
 
@@ -556,7 +753,7 @@ def _create_device_graph(self):
     def capture_one_batch_size(
         self, bs: int, forward: Callable, stream_idx: Optional[int] = None
    ):
-        buffers: GraphInputBuffers = self.buffers
+        buffers: DecodeInputBuffers = self.buffers
         graph = self._create_device_graph()
         stream = self.stream
         num_tokens = bs * self.num_tokens_per_bs
@@ -798,7 +995,7 @@ def replay_prepare(
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
 
-        seq_lens_cpu = buffers.populate_from_forward_batch(
+        buffers.populate_from_forward_batch(
            forward_batch=forward_batch,
             raw_bs=raw_bs,
             raw_num_token=raw_num_token,
@@ -835,7 +1032,7 @@ def replay_prepare(
             buffers.encoder_lens[:bs] if self.is_encoder_decoder else None,
             self.capture_forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=seq_lens_cpu,
+            seq_lens_cpu=buffers.seq_lens_cpu[:bs],
         )
 
         # Store fields
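Editor's note, not part of the patch: the `share_buffers()` call introduced above hands each runner back a view into a process-wide tensor pool. The aliasing it relies on is plain `torch.Tensor.as_strided`; a minimal, self-contained illustration of that primitive:

```python
import torch

big = torch.zeros(8, dtype=torch.int64)  # the pooled allocation
small = big.as_strided((4,), (1,))       # a later, smaller request: same storage

small.fill_(7)  # writes through either alias land in the same memory
print(big)      # tensor([7, 7, 7, 7, 0, 0, 0, 0])
```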
diff --git a/python/sglang/srt/model_executor/input_buffers.py b/python/sglang/srt/model_executor/input_buffers.py
index f4468a70c634..240ccef3c95c 100644
--- a/python/sglang/srt/model_executor/input_buffers.py
+++ b/python/sglang/srt/model_executor/input_buffers.py
@@ -1,208 +1,55 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
-from typing import Dict, Optional
+from dataclasses import dataclass, fields
+from typing import Dict
 
 import torch
 
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    PPProxyTensors,
-    compute_local_num_token_non_padded,
-)
+_forward_input_buffer_pool: Dict[str, torch.Tensor] = {}
 
 
 @dataclass
-class GraphInputBuffers:
-    input_ids: torch.Tensor
-    input_embeds: torch.Tensor
-    req_pool_indices: torch.Tensor
-    seq_lens: torch.Tensor
-    seq_lens_cpu: torch.Tensor
-    out_cache_loc: torch.Tensor
-    positions: torch.Tensor
-    mrope_positions: torch.Tensor
-    num_token_non_padded: torch.Tensor
-    custom_mask: torch.Tensor
-    next_token_logits_buffer: torch.Tensor
-    mamba_track_indices: Optional[torch.Tensor]
-    mamba_track_mask: Optional[torch.Tensor]
-    global_num_tokens_gpu: torch.Tensor
-    global_num_tokens_for_logprob_gpu: torch.Tensor
-    encoder_lens: Optional[torch.Tensor]
-    pp_proxy_tensors: Optional[Dict[str, torch.Tensor]]
-
-    @classmethod
-    def create(
-        cls,
-        *,
-        device: torch.device,
-        max_bs: int,
-        max_num_token: int,
-        hidden_size: int,
-        vocab_size: int,
-        dtype: torch.dtype,
-        dp_size: int,
-        pp_size: int,
-        is_encoder_decoder: bool,
-        require_mlp_tp_gather: bool,
-        seq_len_fill_value: int,
-        encoder_len_fill_value: int,
-        num_tokens_per_bs: int,
-        cache_loc_dtype: torch.dtype,
-        enable_mamba_track: bool,
-    ) -> "GraphInputBuffers":
-        with torch.device(device):
-            input_ids = torch.zeros((max_num_token,), dtype=torch.int64)
-            input_embeds = torch.zeros((max_num_token, hidden_size), dtype=dtype)
-            req_pool_indices = torch.zeros((max_bs,), dtype=torch.int32)
-            seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
-            out_cache_loc = torch.zeros((max_num_token,), dtype=cache_loc_dtype)
-            positions = torch.zeros((max_num_token,), dtype=torch.int64)
-            mrope_positions = torch.zeros((3, max_num_token), dtype=torch.int64)
-            num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
-            custom_mask = torch.ones(
-                (max_bs * seq_len_fill_value + max_num_token) * num_tokens_per_bs,
-                dtype=torch.bool,
-            )
-            next_token_logits_buffer = torch.zeros(
-                (max_num_token, vocab_size),
-                dtype=torch.float,
-            )
-            mamba_track_indices = (
-                torch.zeros((max_bs,), dtype=torch.int64)
-                if enable_mamba_track
-                else None
-            )
-            mamba_track_mask = (
-                torch.zeros((max_bs,), dtype=torch.bool) if enable_mamba_track else None
-            )
-
-            if pp_size > 1:
-                pp_proxy_tensors = {
-                    "hidden_states": torch.zeros((max_bs, hidden_size), dtype=dtype),
-                    "residual": torch.zeros((max_bs, hidden_size), dtype=dtype),
-                }
-            else:
-                pp_proxy_tensors = None
-
-            if is_encoder_decoder:
-                encoder_lens = torch.full(
-                    (max_bs,), encoder_len_fill_value, dtype=torch.int32
-                )
-            else:
-                encoder_lens = None
-
-            if require_mlp_tp_gather:
-                global_num_tokens_gpu = torch.zeros((dp_size,), dtype=torch.int32)
-                global_num_tokens_for_logprob_gpu = torch.zeros(
-                    (dp_size,), dtype=torch.int32
-                )
+class ForwardInputBuffers:
+
+    def _share_one_buffer(self, name: str, new_buffer: torch.Tensor) -> torch.Tensor:
+
+        buffer_size = new_buffer.size()
+        buffer_stride = new_buffer.stride()
+
+        old_buffer = _forward_input_buffer_pool.get(name, None)
+        if old_buffer is not None:
+            assert (
+                new_buffer.dtype == old_buffer.dtype
+            ), f"Buffer {name} has different dtype than before."
+            assert (
+                new_buffer.device == old_buffer.device
+            ), f"Buffer {name} has different device than before."
+            if old_buffer.numel() > new_buffer.numel():
+                new_buffer = old_buffer
+
+        _forward_input_buffer_pool[name] = new_buffer
+        return new_buffer.as_strided(buffer_size, buffer_stride)
+
+    def share_buffers(self):
+
+        for f in fields(self):
+            name = f.name
+            buffer = getattr(self, name)
+
+            if buffer is None:
+                continue
+            elif isinstance(buffer, dict):
+                for sub_name, sub_buffer in buffer.items():
+                    assert isinstance(
+                        sub_buffer, torch.Tensor
+                    ), f"Field {name}.{sub_name} is expected to be a torch.Tensor, but got {type(sub_buffer)}."
+                    new_buffer = self._share_one_buffer(
+                        f"{name}.{sub_name}", sub_buffer
+                    )
+                    buffer[sub_name] = new_buffer
             else:
-                global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
-                global_num_tokens_for_logprob_gpu = torch.zeros((1,), dtype=torch.int32)
-
-        # Keep seq_lens_cpu as a true CPU tensor, like the old implementation.
-        seq_lens_cpu = torch.full(
-            (max_bs,),
-            seq_len_fill_value,
-            dtype=torch.int32,
-            device="cpu",
-        )
-
-        return cls(
-            input_ids=input_ids,
-            input_embeds=input_embeds,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            seq_lens_cpu=seq_lens_cpu,
-            out_cache_loc=out_cache_loc,
-            positions=positions,
-            mrope_positions=mrope_positions,
-            num_token_non_padded=num_token_non_padded,
-            custom_mask=custom_mask,
-            next_token_logits_buffer=next_token_logits_buffer,
-            mamba_track_indices=mamba_track_indices,
-            mamba_track_mask=mamba_track_mask,
-            encoder_lens=encoder_lens,
-            global_num_tokens_gpu=global_num_tokens_gpu,
-            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob_gpu,
-            pp_proxy_tensors=pp_proxy_tensors,
-        )
-
-    def populate_from_forward_batch(
-        self,
-        *,
-        forward_batch: ForwardBatch,
-        raw_bs: int,
-        raw_num_token: int,
-        bs: int,
-        seq_len_fill_value: int,
-        require_gathered_buffer: bool,
-        num_tokens_per_bs: int,
-        nsa_enable_prefill_cp: bool,
-        enable_num_token_non_padded_flag: bool,
-        pp_proxy_tensors: Optional[PPProxyTensors] = None,
-    ) -> Optional[torch.Tensor]:
-        if bs != raw_bs:
-            self.seq_lens.fill_(seq_len_fill_value)
-            self.out_cache_loc.zero_()
-            if self.mamba_track_indices is not None:
-                self.mamba_track_indices.zero_()
-            if self.mamba_track_mask is not None:
-                self.mamba_track_mask.fill_(False)
-
-        # Common inputs
-        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
-        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
-        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
-        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
-        self.positions[:raw_num_token].copy_(forward_batch.positions)
-
-        if (
-            self.mamba_track_indices is not None
-            and forward_batch.mamba_track_indices is not None
-        ):
-            self.mamba_track_indices[:raw_bs].copy_(forward_batch.mamba_track_indices)
-        if (
-            self.mamba_track_mask is not None
-            and forward_batch.mamba_track_mask is not None
-        ):
-            self.mamba_track_mask[:raw_bs].copy_(forward_batch.mamba_track_mask)
-
-        seq_lens_cpu: Optional[torch.Tensor] = None
-        if forward_batch.seq_lens_cpu is not None:
-            if bs != raw_bs:
-                self.seq_lens_cpu.fill_(seq_len_fill_value)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
-            seq_lens_cpu = self.seq_lens_cpu[:bs]
-
-        if self.encoder_lens is not None and forward_batch.encoder_lens is not None:
-            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
-
-        if forward_batch.mrope_positions is not None:
-            self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
-
-        if require_gathered_buffer:
-            self.global_num_tokens_gpu.fill_(bs * num_tokens_per_bs)
-            self.global_num_tokens_for_logprob_gpu.fill_(bs * num_tokens_per_bs)
-
-        if enable_num_token_non_padded_flag:
-            if require_gathered_buffer and not nsa_enable_prefill_cp:
-                num_tokens_per_dp = bs * num_tokens_per_bs
-                local = compute_local_num_token_non_padded(
-                    global_num_token_non_padded=forward_batch.num_token_non_padded,
-                    num_tokens_per_dp=num_tokens_per_dp,
-                )
-                self.num_token_non_padded.copy_(local)
-            else:
-                self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
-
-        # Pipeline-parallel proxy tensors.
-        if pp_proxy_tensors is not None and self.pp_proxy_tensors is not None:
-            for key, buf in self.pp_proxy_tensors.items():
-                src = pp_proxy_tensors.tensors[key]
-                dim = src.shape[0]
-                buf[:dim].copy_(src)
-
-        return seq_lens_cpu
+                assert isinstance(
+                    buffer, torch.Tensor
+                ), f"Field {name} is expected to be a torch.Tensor or a dict of torch.Tensor, but got {type(buffer)}."
+                new_buffer = self._share_one_buffer(name, buffer)
+                setattr(self, name, new_buffer)
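Editor's note, not part of the patch: a hedged usage sketch of the `ForwardInputBuffers` pool defined above. The pool is module-global and keyed by bare field name (or `field.subkey` for dict fields), so the toy field names below would collide with real runner fields if run in the same process; the `_ToyBuffers` class and CPU tensors are illustrative only:

```python
from dataclasses import dataclass
from typing import Dict

import torch

from sglang.srt.model_executor.input_buffers import ForwardInputBuffers


@dataclass
class _ToyBuffers(ForwardInputBuffers):
    positions: torch.Tensor
    pp_proxy_tensors: Dict[str, torch.Tensor]


first = _ToyBuffers(
    positions=torch.zeros((64,), dtype=torch.int64),
    pp_proxy_tensors={"hidden_states": torch.zeros((16, 8))},
)
first.share_buffers()  # registers "positions" and "pp_proxy_tensors.hidden_states"

second = _ToyBuffers(
    positions=torch.zeros((32,), dtype=torch.int64),
    pp_proxy_tensors={"hidden_states": torch.zeros((8, 8))},
)
second.share_buffers()  # smaller requests become views into the pooled tensors

assert second.positions.data_ptr() == first.positions.data_ptr()
assert second.positions.shape == (32,)  # each registrant keeps its own size
assert (
    second.pp_proxy_tensors["hidden_states"].data_ptr()
    == first.pp_proxy_tensors["hidden_states"].data_ptr()
)
```

This is how the decode, prefill, and speculative runners below end up reusing one allocation per same-named input instead of each holding a private max-size copy.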
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 094b1d317eeb..32208394a820 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -117,6 +117,7 @@
 from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import (
     CudaGraphRunner,
+    DecodeInputBuffers,
     set_torch_compile_config,
 )
 from sglang.srt.model_executor.forward_batch_info import (
@@ -126,7 +127,6 @@
     PPProxyTensors,
 )
 from sglang.srt.model_executor.hook_manager import register_forward_hooks
-from sglang.srt.model_executor.input_buffers import GraphInputBuffers
 from sglang.srt.model_executor.model_runner_kv_cache_mixin import (
     ModelRunnerKVCacheMixin,
 )
@@ -1902,7 +1902,7 @@ def _dummy_run(self, batch_size: int):
         if require_gathered_buffer(self.server_args):
             assert require_mlp_tp_gather_ or require_attn_tp_gather(self.server_args)
 
-        buffers: GraphInputBuffers = GraphInputBuffers.create(
+        buffers: DecodeInputBuffers = DecodeInputBuffers.create(
             device=self.device,
             max_bs=batch_size,
             max_num_token=num_tokens,
diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
index a8ea3af57463..feccc5c5608b 100644
--- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
@@ -19,7 +19,8 @@
 import gc
 import logging
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Union
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 import tqdm
@@ -55,6 +56,7 @@
     ForwardMode,
     PPProxyTensors,
 )
+from sglang.srt.model_executor.input_buffers import ForwardInputBuffers
 from sglang.srt.utils import get_available_gpu_memory, is_npu, log_info_on_rank0
 
 logger = logging.getLogger(__name__)
@@ -63,6 +65,19 @@
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 
+@dataclass
+class PrefillInputBuffers(ForwardInputBuffers):
+    input_ids: torch.Tensor
+    out_cache_loc: torch.Tensor
+    out_cache_loc_swa: Optional[torch.Tensor]
+    mamba_track_indices: Optional[torch.Tensor]
+    mamba_track_mask: Optional[torch.Tensor]
+    mamba_track_seqlens: Optional[torch.Tensor]
+    positions: torch.Tensor
+    input_embeds: Optional[torch.Tensor]
+    mrope_positions: Optional[torch.Tensor]
+
+
 @contextmanager
 def freeze_gc(enable_cudagraph_gc: bool):
     """
@@ -189,31 +204,31 @@ def __init__(self, model_runner: ModelRunner):
 
         # Graph inputs
         with torch.device(self.device):
-            self.input_ids = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
-            self.out_cache_loc = torch.zeros(
+            input_ids = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
+            out_cache_loc = torch.zeros(
                 (self.max_num_tokens,), dtype=self._cache_loc_dtype()
             )
-            self.out_cache_loc_swa = (
+            out_cache_loc_swa = (
                 torch.zeros((self.max_num_tokens,), dtype=torch.int64)
                 if model_runner.is_hybrid_swa
                 else None
            )
-            self.mamba_track_indices = (
+            mamba_track_indices = (
                 torch.zeros((self.max_bs,), dtype=torch.int64)
                 if self.mamba_track_enabled
                 else None
            )
-            self.mamba_track_mask = (
+            mamba_track_mask = (
                 torch.zeros((self.max_bs,), dtype=torch.bool)
                 if self.mamba_track_enabled
                 else None
            )
-            self.mamba_track_seqlens = (
+            mamba_track_seqlens = (
                 torch.zeros((self.max_bs,), dtype=torch.int32)
                 if self.mamba_track_enabled
                 else None
            )
-            self.positions = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
+            positions = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
 
         self.tbo_plugin = TboCudaGraphRunnerPlugin()
 
@@ -223,13 +238,29 @@ def __init__(self, model_runner: ModelRunner):
             # 1. In multimodal, we only compile and capture the language model part.
             # 2. The embedder is outside of the graph, but cuda graph requires the input embeds to have a fixed memory address.
             # 3. Input embeds is a pre-allocated buffer. In model.forward, we copy the embed output to this buffer.
-            self.input_embeds = torch.zeros(
+            input_embeds = torch.zeros(
                 (self.max_num_tokens, self.model_runner.model_config.hidden_size),
                 dtype=self.model_runner.dtype,
             )
-            self.mrope_positions = torch.zeros(
+            mrope_positions = torch.zeros(
                 (3, self.max_num_tokens), dtype=torch.int64
             )
+        else:
+            input_embeds = None
+            mrope_positions = None
+
+        self.buffers = PrefillInputBuffers(
+            input_ids=input_ids,
+            out_cache_loc=out_cache_loc,
+            out_cache_loc_swa=out_cache_loc_swa,
+            mamba_track_indices=mamba_track_indices,
+            mamba_track_mask=mamba_track_mask,
+            mamba_track_seqlens=mamba_track_seqlens,
+            positions=positions,
+            input_embeds=input_embeds,
+            mrope_positions=mrope_positions,
+        )
+        self.buffers.share_buffers()
 
         self.attention_layers = self.model_runner.attention_layers
         self.moe_layers = self.model_runner.moe_layers
@@ -285,29 +316,32 @@ def __init__(self, model_runner: ModelRunner):
 
     def warmup_torch_compile(self, num_tokens: int):
         """Warmup the model with a simple forward pass before CUDA graph capture."""
-        input_ids = self.input_ids[:num_tokens]
-        input_embeds = self.input_embeds[:num_tokens] if self.is_multimodal else None
-        positions = self.positions[:num_tokens]
+        buffers = self.buffers
+        input_ids = buffers.input_ids[:num_tokens]
+        input_embeds = buffers.input_embeds[:num_tokens] if self.is_multimodal else None
+        positions = buffers.positions[:num_tokens]
         mrope_positions = (
-            self.mrope_positions[:, :num_tokens] if self.is_multimodal else None
+            buffers.mrope_positions[:, :num_tokens] if self.is_multimodal else None
         )
-        out_cache_loc = self.out_cache_loc[:num_tokens]
+        out_cache_loc = buffers.out_cache_loc[:num_tokens]
         out_cache_loc_swa = (
-            self.out_cache_loc_swa[:num_tokens]
-            if self.out_cache_loc_swa is not None
+            buffers.out_cache_loc_swa[:num_tokens]
+            if buffers.out_cache_loc_swa is not None
             else None
         )
         mamba_track_indices = (
-            self.mamba_track_indices[:1]
-            if self.mamba_track_indices is not None
+            buffers.mamba_track_indices[:1]
+            if buffers.mamba_track_indices is not None
             else None
         )
         mamba_track_mask = (
-            self.mamba_track_mask[:1] if self.mamba_track_mask is not None else None
+            buffers.mamba_track_mask[:1]
+            if buffers.mamba_track_mask is not None
+            else None
         )
         mamba_track_seqlens = (
-            self.mamba_track_seqlens[:1]
-            if self.mamba_track_seqlens is not None
+            buffers.mamba_track_seqlens[:1]
+            if buffers.mamba_track_seqlens is not None
             else None
         )
         with torch.device(self.device):
@@ -422,34 +456,37 @@ def capture(self) -> None:
             self.capture_one_batch_size(num_tokens)
 
     def capture_one_batch_size(self, num_tokens: int):
+        buffers = self.buffers
         bs = 1
 
         # Graph inputs
-        input_ids = self.input_ids[:num_tokens]
-        input_embeds = self.input_embeds[:num_tokens] if self.is_multimodal else None
+        input_ids = buffers.input_ids[:num_tokens]
+        input_embeds = buffers.input_embeds[:num_tokens] if self.is_multimodal else None
 
-        out_cache_loc = self.out_cache_loc[:num_tokens]
+        out_cache_loc = buffers.out_cache_loc[:num_tokens]
         out_cache_loc_swa = (
-            self.out_cache_loc_swa[:num_tokens]
-            if self.out_cache_loc_swa is not None
+            buffers.out_cache_loc_swa[:num_tokens]
+            if buffers.out_cache_loc_swa is not None
             else None
         )
         mamba_track_indices = (
-            self.mamba_track_indices[:bs]
-            if self.mamba_track_indices is not None
+            buffers.mamba_track_indices[:bs]
+            if buffers.mamba_track_indices is not None
             else None
         )
         mamba_track_mask = (
-            self.mamba_track_mask[:bs] if self.mamba_track_mask is not None else None
+            buffers.mamba_track_mask[:bs]
+            if buffers.mamba_track_mask is not None
+            else None
        )
         mamba_track_seqlens = (
-            self.mamba_track_seqlens[:bs]
-            if self.mamba_track_seqlens is not None
+            buffers.mamba_track_seqlens[:bs]
+            if buffers.mamba_track_seqlens is not None
             else None
         )
-        positions = self.positions[:num_tokens]
+        positions = buffers.positions[:num_tokens]
         mrope_positions = (
-            self.mrope_positions[:, :num_tokens] if self.is_multimodal else None
+            buffers.mrope_positions[:, :num_tokens] if self.is_multimodal else None
         )
 
         global_dp_buffer_len = None
@@ -553,82 +590,85 @@ def replay_prepare(
         forward_batch: ForwardBatch,
         **kwargs,
     ):
+        buffers = self.buffers
         num_tokens = len(forward_batch.input_ids)
         index = bisect.bisect_left(self.capture_num_tokens, num_tokens)
         static_num_tokens = self.capture_num_tokens[index]
         self.raw_num_tokens = num_tokens
         if static_num_tokens != num_tokens:
-            self.out_cache_loc.zero_()
-            if self.out_cache_loc_swa is not None:
-                self.out_cache_loc_swa.zero_()
-            self.input_ids[num_tokens:static_num_tokens].zero_()
-            self.positions[num_tokens:static_num_tokens].zero_()
+            buffers.out_cache_loc.zero_()
+            if buffers.out_cache_loc_swa is not None:
+                buffers.out_cache_loc_swa.zero_()
+            buffers.input_ids[num_tokens:static_num_tokens].zero_()
+            buffers.positions[num_tokens:static_num_tokens].zero_()
             if self.is_multimodal:
-                self.input_embeds[:, num_tokens:static_num_tokens].zero_()
+                buffers.input_embeds[:, num_tokens:static_num_tokens].zero_()
             if forward_batch.mrope_positions is not None:
-                self.mrope_positions[:, num_tokens:static_num_tokens].zero_()
+                buffers.mrope_positions[:, num_tokens:static_num_tokens].zero_()
 
         bs = forward_batch.batch_size
-        self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
-        self.positions[:num_tokens].copy_(forward_batch.positions)
-        self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
-        if self.out_cache_loc_swa is not None:
-            self.out_cache_loc_swa[: self.raw_num_tokens].copy_(
+        buffers.input_ids[:num_tokens].copy_(forward_batch.input_ids)
+        buffers.positions[:num_tokens].copy_(forward_batch.positions)
+        buffers.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
+        if buffers.out_cache_loc_swa is not None:
+            buffers.out_cache_loc_swa[: self.raw_num_tokens].copy_(
                 self.model_runner.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
                     forward_batch.out_cache_loc
                 )
             )
         if (
-            self.mamba_track_indices is not None
+            buffers.mamba_track_indices is not None
             and forward_batch.mamba_track_indices is not None
         ):
-            self.mamba_track_indices[:bs].copy_(forward_batch.mamba_track_indices)
+            buffers.mamba_track_indices[:bs].copy_(forward_batch.mamba_track_indices)
         if (
-            self.mamba_track_mask is not None
+            buffers.mamba_track_mask is not None
             and forward_batch.mamba_track_mask is not None
         ):
-            self.mamba_track_mask[:bs].copy_(forward_batch.mamba_track_mask)
+            buffers.mamba_track_mask[:bs].copy_(forward_batch.mamba_track_mask)
         if (
-            self.mamba_track_seqlens is not None
+            buffers.mamba_track_seqlens is not None
             and forward_batch.mamba_track_seqlens is not None
         ):
-            self.mamba_track_seqlens[:bs].copy_(forward_batch.mamba_track_seqlens)
+            buffers.mamba_track_seqlens[:bs].copy_(forward_batch.mamba_track_seqlens)
 
-        input_ids = self.input_ids[:static_num_tokens]
-        positions = self.positions[:static_num_tokens]
-        out_cache_loc = self.out_cache_loc[:static_num_tokens]
+        input_ids = buffers.input_ids[:static_num_tokens]
+        positions = buffers.positions[:static_num_tokens]
+        out_cache_loc = buffers.out_cache_loc[:static_num_tokens]
         out_cache_loc_swa = (
-            self.out_cache_loc_swa[:static_num_tokens]
+            buffers.out_cache_loc_swa[:static_num_tokens]
             if forward_batch.out_cache_loc_swa is not None
             else None
         )
         mamba_track_indices = (
-            self.mamba_track_indices[:bs]
-            if self.mamba_track_indices is not None
+            buffers.mamba_track_indices[:bs]
+            if buffers.mamba_track_indices is not None
             else None
         )
         mamba_track_mask = (
-            self.mamba_track_mask[:bs] if self.mamba_track_mask is not None else None
+            buffers.mamba_track_mask[:bs]
+            if buffers.mamba_track_mask is not None
+            else None
        )
         mamba_track_seqlens = (
-            self.mamba_track_seqlens[:bs]
-            if self.mamba_track_seqlens is not None
+            buffers.mamba_track_seqlens[:bs]
+            if buffers.mamba_track_seqlens is not None
             else None
         )
         if forward_batch.mrope_positions is not None:
-            self.mrope_positions[:, :num_tokens].copy_(forward_batch.mrope_positions)
+            buffers.mrope_positions[:, :num_tokens].copy_(forward_batch.mrope_positions)
 
-        input_ids = self.input_ids[:static_num_tokens]
+        input_ids = buffers.input_ids[:static_num_tokens]
         input_embeds = (
-            self.input_embeds[:static_num_tokens] if self.is_multimodal else None
+            buffers.input_embeds[:static_num_tokens] if self.is_multimodal else None
         )
         mrope_positions = (
-            self.mrope_positions[:, :static_num_tokens]
+            buffers.mrope_positions[:, :static_num_tokens]
             if forward_batch.mrope_positions is not None
             else None
         )
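Editor's note, not part of the patch: `share_buffers()` skips fields whose value is `None`, which is how `PrefillInputBuffers` above can cover both multimodal and text-only configurations with one dataclass (e.g. `input_embeds=None` in the text-only case). A minimal sketch with hypothetical field names:

```python
from dataclasses import dataclass
from typing import Optional

import torch

from sglang.srt.model_executor.input_buffers import ForwardInputBuffers


@dataclass
class _TextOnlyToy(ForwardInputBuffers):
    positions: torch.Tensor
    input_embeds: Optional[torch.Tensor]


toy = _TextOnlyToy(positions=torch.zeros((16,), dtype=torch.int64), input_embeds=None)
toy.share_buffers()  # "input_embeds" is None, so it is never registered in the pool
```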
diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
index 5fe45086ca4a..a80c38ffa701 100644
--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 import bisect
-from typing import TYPE_CHECKING, Callable
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Callable, Optional
 
 import torch
 
@@ -22,6 +23,7 @@
     ForwardBatch,
     ForwardMode,
 )
+from sglang.srt.model_executor.input_buffers import ForwardInputBuffers
 from sglang.srt.speculative.eagle_info import EagleDraftInput
 from sglang.srt.utils import (
     require_attn_tp_gather,
@@ -34,6 +36,23 @@
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
 
 
+@dataclass
+class EagleDraftInputBuffers(ForwardInputBuffers):
+    input_ids: torch.Tensor
+    req_pool_indices: torch.Tensor
+    out_cache_loc: torch.Tensor
+    positions: torch.Tensor
+    mrope_positions: torch.Tensor
+    seq_lens: torch.Tensor
+    seq_lens_cpu: torch.Tensor
+    extend_seq_lens: torch.Tensor
+    topk_p: torch.Tensor
+    topk_index: torch.Tensor
+    hidden_states: torch.Tensor
+    global_num_tokens_gpu: Optional[torch.Tensor]
+    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor]
+
+
 class EAGLEDraftCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):
         # Parse args
@@ -75,7 +94,7 @@ def __init__(self, eagle_worker: EAGLEWorker):
         self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
             0
         ].get_cuda_graph_seq_len_fill_value()
-        self.seq_lens_cpu = torch.full(
+        seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
         self.extend_seq_lens_cpu = [self.seq_len_fill_value] * self.max_bs
@@ -85,44 +104,59 @@ def __init__(self, eagle_worker: EAGLEWorker):
 
         # Graph inputs
         with torch.device(model_runner.device):
-            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
-            self.out_cache_loc = torch.zeros(
+            input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
+            out_cache_loc = torch.zeros(
                 (self.max_num_token * self.speculative_num_steps,),
                 dtype=self._cache_loc_dtype(),
             )
-            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.mrope_positions = torch.zeros(
-                (3, self.max_num_token), dtype=torch.int64
-            )
-            self.seq_lens = torch.full(
+            positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            mrope_positions = torch.zeros((3, self.max_num_token), dtype=torch.int64)
+            seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
-            self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
-            self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
-            self.hidden_states = torch.zeros(
+            extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
+            topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
+            topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
+            hidden_states = torch.zeros(
                 (self.max_bs, self.model_runner.model_config.hidden_size),
                 dtype=self.model_runner.dtype,
             )
 
             if self.require_gathered_buffer:
                 if self.require_mlp_tp_gather:
-                    self.global_num_tokens_gpu = torch.zeros(
+                    global_num_tokens_gpu = torch.zeros(
                         (self.dp_size,), dtype=torch.int32
                     )
-                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    global_num_tokens_for_logprob_gpu = torch.zeros(
                         (self.dp_size,), dtype=torch.int32
                     )
                 else:
                     assert self.require_attn_tp_gather
-                    self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
-                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                    global_num_tokens_for_logprob_gpu = torch.zeros(
                         (1,), dtype=torch.int32
                     )
             else:
-                self.global_num_tokens_gpu = None
-                self.global_num_tokens_for_logprob_gpu = None
+                global_num_tokens_gpu = None
+                global_num_tokens_for_logprob_gpu = None
+
+        self.buffers = EagleDraftInputBuffers(
+            input_ids=input_ids,
+            req_pool_indices=req_pool_indices,
+            out_cache_loc=out_cache_loc,
+            positions=positions,
+            mrope_positions=mrope_positions,
+            seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
+            extend_seq_lens=extend_seq_lens,
+            topk_p=topk_p,
+            topk_index=topk_index,
+            hidden_states=hidden_states,
+            global_num_tokens_gpu=global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob_gpu,
+        )
+        self.buffers.share_buffers()
 
         # Capture
         try:
@@ -181,59 +215,60 @@ def capture(self):
     def capture_one_batch_size(
         self, num_seqs: int, forward: Callable, stream_idx: int = 0
     ):
+        buffers = self.buffers
         graph = self._create_graph()
         stream = self.stream
         num_tokens = num_seqs * self.num_tokens_per_bs
 
         # Graph inputs
-        req_pool_indices = self.req_pool_indices[:num_seqs]
-        seq_lens = self.seq_lens[:num_seqs]
-        seq_lens_cpu = self.seq_lens_cpu[:num_seqs]
-        extend_seq_lens = self.extend_seq_lens[:num_seqs]
+        req_pool_indices = buffers.req_pool_indices[:num_seqs]
+        seq_lens = buffers.seq_lens[:num_seqs]
+        seq_lens_cpu = buffers.seq_lens_cpu[:num_seqs]
+        extend_seq_lens = buffers.extend_seq_lens[:num_seqs]
         extend_seq_lens_cpu = self.extend_seq_lens_cpu[:num_seqs]
-        out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
-        positions = self.positions[:num_tokens]
-        mrope_positions = self.mrope_positions[:, :num_tokens]
-        hidden_states = self.hidden_states[:num_seqs]
-        topk_p = self.topk_p[:num_seqs]
-        topk_index = self.topk_index[:num_seqs]
+        out_cache_loc = buffers.out_cache_loc[: num_tokens * self.speculative_num_steps]
+        positions = buffers.positions[:num_tokens]
+        mrope_positions = buffers.mrope_positions[:, :num_tokens]
+        hidden_states = buffers.hidden_states[:num_seqs]
+        topk_p = buffers.topk_p[:num_seqs]
+        topk_index = buffers.topk_index[:num_seqs]
 
         if self.require_mlp_tp_gather:
-            self.global_num_tokens_gpu.copy_(
+            buffers.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [num_tokens] * self.dp_size,
                     dtype=torch.int32,
-                    device=self.input_ids.device,
+                    device=buffers.input_ids.device,
                 )
             )
-            self.global_num_tokens_for_logprob_gpu.copy_(
+            buffers.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
                     [num_tokens] * self.dp_size,
                     dtype=torch.int32,
-                    device=self.input_ids.device,
+                    device=buffers.input_ids.device,
                 )
             )
-            global_num_tokens = self.global_num_tokens_gpu
+            global_num_tokens = buffers.global_num_tokens_gpu
             global_dp_buffer_len = num_tokens * self.dp_size
-            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+            global_num_tokens_for_logprob = buffers.global_num_tokens_for_logprob_gpu
         elif self.require_attn_tp_gather:
-            self.global_num_tokens_gpu.copy_(
+            buffers.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [num_tokens],
                     dtype=torch.int32,
-                    device=self.input_ids.device,
+                    device=buffers.input_ids.device,
                 )
             )
-            self.global_num_tokens_for_logprob_gpu.copy_(
+            buffers.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
                     [num_tokens],
                     dtype=torch.int32,
-                    device=self.input_ids.device,
+                    device=buffers.input_ids.device,
                 )
             )
-            global_num_tokens = self.global_num_tokens_gpu
+            global_num_tokens = buffers.global_num_tokens_gpu
             global_dp_buffer_len = num_tokens
-            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+            global_num_tokens_for_logprob = buffers.global_num_tokens_for_logprob_gpu
         else:
             global_num_tokens = None
             global_dp_buffer_len = None
@@ -319,6 +354,7 @@ def _postprocess_output_to_raw_bs(self, out, raw_bs):
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         self.deepep_adapter.replay()
+        buffers = self.buffers
 
         raw_bs = forward_batch.batch_size
         raw_num_token = raw_bs * self.num_tokens_per_bs
@@ -338,40 +374,40 @@ def replay(self, forward_batch: ForwardBatch):
         bs = self.capture_bs[index]
 
         if bs != raw_bs:
-            self.seq_lens.fill_(self.seq_len_fill_value)
-            self.out_cache_loc.zero_()
-            self.positions.zero_()
+            buffers.seq_lens.fill_(self.seq_len_fill_value)
+            buffers.out_cache_loc.zero_()
+            buffers.positions.zero_()
 
         num_tokens = bs * self.num_tokens_per_bs
 
         # Common inputs
-        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
-        self.out_cache_loc[: raw_num_token * self.speculative_num_steps].copy_(
+        buffers.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
+        buffers.out_cache_loc[: raw_num_token * self.speculative_num_steps].copy_(
             forward_batch.out_cache_loc
         )
-        self.positions[:raw_num_token].copy_(forward_batch.positions)
-        self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
-        self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
-        self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
-        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
+        buffers.positions[:raw_num_token].copy_(forward_batch.positions)
+        buffers.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
+        buffers.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
+        buffers.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
+        buffers.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
 
         # TODO(ch-wan): support num_token_non_padded
         if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
-            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
+            buffers.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            buffers.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
 
         # Attention backend
         if bs != raw_bs:
             forward_batch.batch_size = bs
-            forward_batch.seq_lens = self.seq_lens[:bs]
-            forward_batch.req_pool_indices = self.req_pool_indices[:bs]
-            forward_batch.positions = self.positions[:num_tokens]
+            forward_batch.seq_lens = buffers.seq_lens[:bs]
+            forward_batch.req_pool_indices = buffers.req_pool_indices[:bs]
+            forward_batch.positions = buffers.positions[:num_tokens]
 
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
-            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
+                buffers.seq_lens_cpu.fill_(self.seq_len_fill_value)
+            buffers.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+            forward_batch.seq_lens_cpu = buffers.seq_lens_cpu[:bs]
 
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
@@ -387,10 +423,10 @@ def replay(self, forward_batch: ForwardBatch):
         if bs != raw_bs:
             out = self._postprocess_output_to_raw_bs(out, raw_bs)
             forward_batch.batch_size = raw_bs
-            forward_batch.positions = self.positions[:raw_num_token]
-            forward_batch.seq_lens = self.seq_lens[:raw_bs]
-            forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
+            forward_batch.positions = buffers.positions[:raw_num_token]
+            forward_batch.seq_lens = buffers.seq_lens[:raw_bs]
+            forward_batch.req_pool_indices = buffers.req_pool_indices[:raw_bs]
             if forward_batch.seq_lens_cpu is not None:
-                forward_batch.seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
+                forward_batch.seq_lens_cpu = buffers.seq_lens_cpu[:raw_bs]
 
         return out
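Editor's note, not part of the patch: the pool asserts dtype and device compatibility across registrants rather than silently aliasing mismatched buffers, so a runner that re-registers a same-named field with a different dtype fails loudly. A self-contained sketch (hypothetical `scratch` field):

```python
from dataclasses import dataclass

import torch

from sglang.srt.model_executor.input_buffers import ForwardInputBuffers


@dataclass
class _ScratchToy(ForwardInputBuffers):
    scratch: torch.Tensor


_ScratchToy(scratch=torch.zeros((4,), dtype=torch.int32)).share_buffers()
try:
    _ScratchToy(scratch=torch.zeros((4,), dtype=torch.int64)).share_buffers()
except AssertionError as exc:
    print(exc)  # Buffer scratch has different dtype than before.
```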
diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
index e1afdd84b547..5ebdb8042565 100644
--- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 import bisect
-from typing import TYPE_CHECKING, Callable
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Callable, Optional
 
 import torch
 
@@ -23,6 +24,7 @@
     ForwardBatch,
     ForwardMode,
 )
+from sglang.srt.model_executor.input_buffers import ForwardInputBuffers
 from sglang.srt.speculative.eagle_info import EagleDraftInput
 from sglang.srt.speculative.spec_utils import fast_topk
 from sglang.srt.utils import (
@@ -36,6 +38,23 @@
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
 
 
+@dataclass
+class EagleDraftExtendInputBuffers(ForwardInputBuffers):
+    input_ids: torch.Tensor
+    req_pool_indices: torch.Tensor
+    out_cache_loc: torch.Tensor
+    positions: torch.Tensor
+    mrope_positions: torch.Tensor
+    hidden_states: torch.Tensor
+    seq_lens: torch.Tensor
+    seq_lens_cpu: torch.Tensor
+    extend_seq_lens: torch.Tensor
+    accept_length: torch.Tensor
+    next_token_logits_buffer: torch.Tensor
+    global_num_tokens_gpu: Optional[torch.Tensor]
+    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor]
+
+
 class EAGLEDraftExtendCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):
         # Parse args
@@ -80,7 +99,7 @@ def __init__(self, eagle_worker: EAGLEWorker):
         self.seq_len_fill_value = (
             self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-        self.seq_lens_cpu = torch.full(
+        seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
         self.extend_seq_lens_cpu = [self.num_tokens_per_bs] * self.max_bs
@@ -90,21 +109,19 @@ def __init__(self, eagle_worker: EAGLEWorker):
 
         # Graph inputs
         with torch.device(model_runner.device):
-            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
-            self.out_cache_loc = torch.ones(
+            input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
+            out_cache_loc = torch.ones(
                 (self.max_num_token,), dtype=self._cache_loc_dtype()
             )
-            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.mrope_positions = torch.zeros(
-                (3, self.max_num_token), dtype=torch.int64
-            )
+            positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            mrope_positions = torch.zeros((3, self.max_num_token), dtype=torch.int64)
 
             if (
                 self.eagle_worker.speculative_algorithm.is_eagle3()
                 and self.eagle_worker.eagle_use_aux_hidden_state
             ):
-                self.hidden_states = torch.zeros(
+                hidden_states = torch.zeros(
                     (
                         self.max_num_token,
                         (
@@ -120,40 +137,40 @@ def __init__(self, eagle_worker: EAGLEWorker):
                     dtype=self.model_runner.dtype,
                 )
             else:
-                self.hidden_states = torch.zeros(
+                hidden_states = torch.zeros(
                     (self.max_num_token, self.model_runner.model_config.hidden_size),
                     dtype=self.model_runner.dtype,
                 )
             self.seq_len_fill_value = (
                 self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
             )
-            self.seq_lens = torch.full(
+            seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.extend_seq_lens = torch.full(
+            extend_seq_lens = torch.full(
                 (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
             )
-            self.accept_length = torch.full(
+            accept_length = torch.full(
                 (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
             )
 
             if self.require_gathered_buffer:
                 if self.require_mlp_tp_gather:
-                    self.global_num_tokens_gpu = torch.zeros(
+                    global_num_tokens_gpu = torch.zeros(
                         (self.dp_size,), dtype=torch.int32
                     )
-                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    global_num_tokens_for_logprob_gpu = torch.zeros(
                         (self.dp_size,), dtype=torch.int32
                     )
                 else:
                     assert self.require_attn_tp_gather
-                    self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
-                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                    global_num_tokens_for_logprob_gpu = torch.zeros(
                         (1,), dtype=torch.int32
                     )
             else:
-                self.global_num_tokens_gpu = None
-                self.global_num_tokens_for_logprob_gpu = None
+                global_num_tokens_gpu = None
+                global_num_tokens_for_logprob_gpu = None
 
         if hasattr(
             self.model_runner.model_config.hf_config, "draft_vocab_size"
@@ -166,7 +183,7 @@ def __init__(self, eagle_worker: EAGLEWorker):
         else:
             vocab_size = self.model_runner.model_config.vocab_size
 
-        self.next_token_logits_buffer = torch.zeros(
+        next_token_logits_buffer = torch.zeros(
             (
                 (
                     self.max_bs * self.num_tokens_per_bs
@@ -178,6 +195,23 @@ def __init__(self, eagle_worker: EAGLEWorker):
             dtype=torch.float,
         )
 
+        self.buffers = EagleDraftExtendInputBuffers(
+            input_ids=input_ids,
+            req_pool_indices=req_pool_indices,
+            out_cache_loc=out_cache_loc,
+            positions=positions,
+            mrope_positions=mrope_positions,
+            hidden_states=hidden_states,
+            seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
+            extend_seq_lens=extend_seq_lens,
+            accept_length=accept_length,
+            next_token_logits_buffer=next_token_logits_buffer,
+            global_num_tokens_gpu=global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob_gpu,
+        )
+        self.buffers.share_buffers()
+
         # Capture
         try:
             with model_capture_mode():
@@ -233,23 +267,24 @@ def capture(self):
         CudaGraphRunner.capture(self)
 
     def capture_one_batch_size(self, bs: int, forward: Callable, stream_idx: int = 0):
+        buffers = self.buffers
         graph = self._create_graph()
         stream = self.stream
         num_tokens = bs * self.num_tokens_per_bs
 
         # Graph inputs
-        input_ids = self.input_ids[:num_tokens]
-        req_pool_indices = self.req_pool_indices[:bs]
-        seq_lens = self.seq_lens[:bs]
-        seq_lens_cpu = self.seq_lens_cpu[:bs]
-        extend_seq_lens = self.extend_seq_lens[:bs]
+        input_ids = buffers.input_ids[:num_tokens]
+        req_pool_indices = buffers.req_pool_indices[:bs]
+        seq_lens = buffers.seq_lens[:bs]
+        seq_lens_cpu = buffers.seq_lens_cpu[:bs]
+        extend_seq_lens = buffers.extend_seq_lens[:bs]
         extend_seq_lens_cpu = self.extend_seq_lens_cpu[:bs]
-        out_cache_loc = self.out_cache_loc[:num_tokens]
-        positions = self.positions[:num_tokens]
-        mrope_positions = self.mrope_positions[:, :num_tokens]
-        hidden_states = self.hidden_states[:num_tokens]
-        accept_length = self.accept_length[:bs]
-        next_token_logits_buffer = self.next_token_logits_buffer[
+        out_cache_loc = buffers.out_cache_loc[:num_tokens]
+        positions = buffers.positions[:num_tokens]
+        mrope_positions = buffers.mrope_positions[:, :num_tokens]
+        hidden_states = buffers.hidden_states[:num_tokens]
+        accept_length = buffers.accept_length[:bs]
+        next_token_logits_buffer = buffers.next_token_logits_buffer[
             : bs if self.forward_mode == ForwardMode.DRAFT_EXTEND else num_tokens
         ]
@@ -260,34 +295,34 @@ def capture_one_batch_size(self, bs: int, forward: Callable, stream_idx: int = 0):
         )
 
         if self.require_mlp_tp_gather:
-            self.global_num_tokens_gpu.copy_(
+            buffers.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [num_tokens] * self.dp_size,
                     dtype=torch.int32,
-                    device=self.input_ids.device,
+                    device=buffers.input_ids.device,
                 )
             )
-            self.global_num_tokens_for_logprob_gpu.copy_(
+            buffers.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
                     [num_tokens_for_logprob] * self.dp_size,
                     dtype=torch.int32,
-                    device=self.input_ids.device,
+                    device=buffers.input_ids.device,
                 )
             )
             global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
-            self.global_num_tokens_gpu.copy_(
+            buffers.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [num_tokens],
                     dtype=torch.int32,
-                    device=self.input_ids.device,
+                    device=buffers.input_ids.device,
                 )
             )
-            self.global_num_tokens_for_logprob_gpu.copy_(
+            buffers.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
                     [num_tokens_for_logprob],
                     dtype=torch.int32,
-                    device=self.input_ids.device,
+                    device=buffers.input_ids.device,
                 )
             )
             global_dp_buffer_len = num_tokens
@@ -320,8 +355,8 @@ def capture_one_batch_size(self, bs: int, forward: Callable, stream_idx: int = 0):
             return_logprob=False,
             positions=positions,
             mrope_positions=mrope_positions,
-            global_num_tokens_gpu=self.global_num_tokens_gpu,
-            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+            global_num_tokens_gpu=buffers.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=buffers.global_num_tokens_for_logprob_gpu,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
             global_dp_buffer_len=global_dp_buffer_len,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -380,6 +415,7 @@ def run_once():
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         self.deepep_adapter.replay()
+        buffers = self.buffers
 
         # batch_size and num_seqs can be different in case there are finished examples
         # in the batch, which will not be counted as num_seqs
@@ -398,45 +434,47 @@ def replay(self, forward_batch: ForwardBatch):
         bs = self.capture_bs[index]
 
         if bs * self.num_tokens_per_bs != num_tokens:
-            self.seq_lens.fill_(self.seq_len_fill_value)
-            self.out_cache_loc.zero_()
-            self.positions.zero_()
-            self.accept_length.fill_(self.num_tokens_per_bs)
-            self.extend_seq_lens.fill_(self.num_tokens_per_bs)
+            buffers.seq_lens.fill_(self.seq_len_fill_value)
+            buffers.out_cache_loc.zero_()
+            buffers.positions.zero_()
+            buffers.accept_length.fill_(self.num_tokens_per_bs)
+            buffers.extend_seq_lens.fill_(self.num_tokens_per_bs)
 
         # Common inputs
-        self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
-        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
+        buffers.input_ids[:num_tokens].copy_(forward_batch.input_ids)
+        buffers.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         if forward_batch.extend_seq_lens is not None:
-            self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
+            buffers.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         else:
-            self.extend_seq_lens[:raw_bs].fill_(self.num_tokens_per_bs)
-        self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
-        self.positions[:num_tokens].copy_(forward_batch.positions)
+            buffers.extend_seq_lens[:raw_bs].fill_(self.num_tokens_per_bs)
+        buffers.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
+        buffers.positions[:num_tokens].copy_(forward_batch.positions)
         if (
             forward_batch.spec_info.hidden_states.shape[1]
-            == self.hidden_states.shape[1]
+            == buffers.hidden_states.shape[1]
         ):
-            self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
+            buffers.hidden_states[:num_tokens].copy_(
+                forward_batch.spec_info.hidden_states
+            )
         if forward_batch.spec_info.accept_length is not None:
-            self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
-        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
+            buffers.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
+        buffers.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
 
         # TODO(ch-wan): support num_token_non_padded
         if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            buffers.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
             # V1: pruned_states = bs; V2: pruned_states = num_tokens
             if self.forward_mode.is_draft_extend_v2():
-                self.global_num_tokens_for_logprob_gpu.fill_(
+                buffers.global_num_tokens_for_logprob_gpu.fill_(
                     bs * self.num_tokens_per_bs
                 )
             else:
-                self.global_num_tokens_for_logprob_gpu.fill_(bs)
+                buffers.global_num_tokens_for_logprob_gpu.fill_(bs)
 
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+                buffers.seq_lens_cpu.fill_(self.seq_len_fill_value)
+            buffers.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if forward_batch.extend_seq_lens_cpu is not None:
             self.extend_seq_lens_cpu[:raw_bs] = forward_batch.extend_seq_lens_cpu
@@ -449,22 +487,22 @@ def replay(self, forward_batch: ForwardBatch):
         forward_batch.spec_info.extend_seq_lens_cpu = list(
             self.extend_seq_lens_cpu[:bs]
         )
-        forward_batch.spec_info.extend_seq_lens_tensor = self.extend_seq_lens[:bs]
+        forward_batch.spec_info.extend_seq_lens_tensor = buffers.extend_seq_lens[:bs]
 
         if bs != raw_bs:
-            forward_batch.spec_info.positions = self.positions[:num_tokens]
-            forward_batch.spec_info.accept_length = self.accept_length[:bs]
+            forward_batch.spec_info.positions = buffers.positions[:num_tokens]
+            forward_batch.spec_info.accept_length = buffers.accept_length[:bs]
 
         self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
             bs=bs,
-            req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
+            req_pool_indices=buffers.req_pool_indices,
+            seq_lens=buffers.seq_lens,
             seq_lens_sum=forward_batch.seq_lens_sum
             + (bs - raw_bs) * self.seq_len_fill_value,
             encoder_lens=None,
             forward_mode=self.forward_mode,
             spec_info=forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu,
+            seq_lens_cpu=buffers.seq_lens_cpu,
         )
 
         # Replay
@@ -477,7 +515,7 @@ def replay(self, forward_batch: ForwardBatch):
             # DRAFT_EXTEND_V2: all tokens calculations whether accepted or not.
             unpadding_bs = num_tokens
         elif bs != raw_bs:
-            forward_batch.spec_info.accept_length = self.accept_length[:raw_bs]
+            forward_batch.spec_info.accept_length = buffers.accept_length[:raw_bs]
             unpadding_bs = raw_bs
         else:
            unpadding_bs = None
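Editor's note, not part of the patch: the multi-layer runner below does not allocate fresh per-step input tensors; it carves disjoint windows out of parent buffers by offset slicing, as its `# Sliced from shared parent buffers` fields indicate. A minimal sketch of that pattern (sizes and names illustrative):

```python
import torch

num_steps, max_num_token = 3, 4096  # illustrative sizes
parent = torch.zeros((num_steps * max_num_token,), dtype=torch.int64)

# Each step's runner gets a disjoint window of the parent allocation.
step_views = [
    parent[i * max_num_token : (i + 1) * max_num_token] for i in range(num_steps)
]
assert step_views[1].data_ptr() != step_views[0].data_ptr()  # non-overlapping
```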
unpadding_bs = num_tokens elif bs != raw_bs: - forward_batch.spec_info.accept_length = self.accept_length[:raw_bs] + forward_batch.spec_info.accept_length = buffers.accept_length[:raw_bs] unpadding_bs = raw_bs else: unpadding_bs = None diff --git a/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py index 94fb58b1de72..cb2974cff9e5 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py @@ -17,7 +17,8 @@ import bisect import logging import time -from typing import TYPE_CHECKING, Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, List, Optional import torch @@ -39,6 +40,7 @@ ForwardBatch, ForwardMode, ) +from sglang.srt.model_executor.input_buffers import ForwardInputBuffers from sglang.srt.speculative.eagle_info import EagleDraftInput from sglang.srt.speculative.multi_layer_eagle_utils import assign_new_state_triton from sglang.srt.speculative.spec_utils import fast_topk @@ -59,6 +61,28 @@ logger = logging.getLogger(__name__) +@dataclass +class MultiLayerEagleDraftExtendInputBuffers(ForwardInputBuffers): + # Sliced from shared parent buffers + input_ids: torch.Tensor + out_cache_loc: torch.Tensor + swa_out_cache_loc: torch.Tensor + positions: torch.Tensor + # Shared from parent + seq_lens: torch.Tensor + seq_lens_cpu: torch.Tensor + req_pool_indices: torch.Tensor + accept_length: torch.Tensor + # Per-step buffers + extend_seq_lens: torch.Tensor + extend_start_loc: torch.Tensor + mrope_positions: torch.Tensor + hidden_states: torch.Tensor + next_token_logits_buffer: torch.Tensor + global_num_tokens_gpu: Optional[torch.Tensor] + global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] + + class MultiLayerEagleDraftExtendCudaGraphRunner: def __init__(self, eagle_worker: MultiLayerEagleDraftWorker, step: int): # Parse args @@ -109,7 +133,7 @@ def init_buffers_and_capture( next_cuda_graph_runner, ): self.next_cuda_graph_runner = next_cuda_graph_runner - self.seq_lens_cpu = cuda_graph_buffers["seq_lens_cpu"] + seq_lens_cpu = cuda_graph_buffers["seq_lens_cpu"] self.extend_seq_lens_cpu = [self.num_tokens_per_bs] * self.max_bs if self.enable_torch_compile: @@ -119,62 +143,60 @@ def init_buffers_and_capture( with torch.device(self.model_runner.device): # sliced buffers # slice according to max_num_token - self.input_ids = cuda_graph_buffers["input_ids"][ + input_ids = cuda_graph_buffers["input_ids"][ offset : offset + self.max_num_token ] - self.out_cache_loc = cuda_graph_buffers["out_cache_loc"][ + out_cache_loc = cuda_graph_buffers["out_cache_loc"][ offset : offset + self.max_num_token ] - self.swa_out_cache_loc = cuda_graph_buffers["swa_out_cache_loc"][ + swa_out_cache_loc = cuda_graph_buffers["swa_out_cache_loc"][ offset : offset + self.max_num_token ] - self.positions = cuda_graph_buffers["positions"][ + positions = cuda_graph_buffers["positions"][ offset : offset + self.max_num_token ] # shared states - self.seq_lens = cuda_graph_buffers["seq_lens"] - self.req_pool_indices = cuda_graph_buffers["req_pool_indices"] - self.accept_length = cuda_graph_buffers["accept_length"] + seq_lens = cuda_graph_buffers["seq_lens"] + req_pool_indices = cuda_graph_buffers["req_pool_indices"] + accept_length = cuda_graph_buffers["accept_length"] - self.extend_seq_lens = torch.full( + extend_seq_lens = torch.full( (self.max_bs,), 
                 self.num_tokens_per_bs,
                 dtype=torch.int32,
             )
-            self.extend_start_loc = torch.arange(
+            extend_start_loc = torch.arange(
                 0,
                 self.max_bs * self.num_tokens_per_bs,
                 step=self.num_tokens_per_bs,
                 dtype=torch.int32,
             )
-            self.mrope_positions = torch.zeros(
-                (3, self.max_num_token), dtype=torch.int64
-            )
+            mrope_positions = torch.zeros((3, self.max_num_token), dtype=torch.int64)

-            self.hidden_states = torch.zeros(
+            hidden_states = torch.zeros(
                 (self.max_num_token, self.model_runner.model_config.hidden_size),
                 dtype=self.model_runner.dtype,
             )

             if self.require_gathered_buffer:
                 if self.require_mlp_tp_gather:
-                    self.global_num_tokens_gpu = torch.zeros(
+                    global_num_tokens_gpu = torch.zeros(
                         (self.dp_size,), dtype=torch.int32
                     )
-                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    global_num_tokens_for_logprob_gpu = torch.zeros(
                         (self.dp_size,), dtype=torch.int32
                     )
                 else:
                     assert self.require_attn_tp_gather
-                    self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
-                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                    global_num_tokens_for_logprob_gpu = torch.zeros(
                         (1,), dtype=torch.int32
                     )
             else:
-                self.global_num_tokens_gpu = None
-                self.global_num_tokens_for_logprob_gpu = None
+                global_num_tokens_gpu = None
+                global_num_tokens_for_logprob_gpu = None

             if hasattr(
                 self.model_runner.model_config.hf_config, "draft_vocab_size"
@@ -187,7 +209,7 @@ def init_buffers_and_capture(
             else:
                 vocab_size = self.model_runner.model_config.vocab_size

-            self.next_token_logits_buffer = torch.zeros(
+            next_token_logits_buffer = torch.zeros(
                 (
                     (
                         self.max_bs * self.num_tokens_per_bs
@@ -199,6 +221,25 @@ def init_buffers_and_capture(
                 dtype=torch.float,
             )

+            self.buffers = MultiLayerEagleDraftExtendInputBuffers(
+                input_ids=input_ids,
+                out_cache_loc=out_cache_loc,
+                swa_out_cache_loc=swa_out_cache_loc,
+                positions=positions,
+                seq_lens=seq_lens,
+                seq_lens_cpu=seq_lens_cpu,
+                req_pool_indices=req_pool_indices,
+                accept_length=accept_length,
+                extend_seq_lens=extend_seq_lens,
+                extend_start_loc=extend_start_loc,
+                mrope_positions=mrope_positions,
+                hidden_states=hidden_states,
+                next_token_logits_buffer=next_token_logits_buffer,
+                global_num_tokens_gpu=global_num_tokens_gpu,
+                global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob_gpu,
+            )
+            self.buffers.share_buffers()
+
         # Capture
         try:
             with model_capture_mode():
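In `get_forward_batch` below, every slice follows one rule: token-indexed buffers are cut to `num_tokens = bs * num_tokens_per_bs`, while request-indexed buffers are cut to `bs`. A standalone illustration of that arithmetic (plain PyTorch, no sglang imports; the sizes are hypothetical):

import torch

# Hypothetical sizes for illustration only.
max_bs, num_tokens_per_bs = 8, 4
input_ids = torch.zeros(max_bs * num_tokens_per_bs, dtype=torch.int64)
seq_lens = torch.zeros(max_bs, dtype=torch.int32)

bs = 3  # current (padded) batch size
num_tokens = bs * num_tokens_per_bs  # 12 draft tokens in total

input_ids_view = input_ids[:num_tokens]  # token-indexed -> [:num_tokens]
seq_lens_view = seq_lens[:bs]            # request-indexed -> [:bs]
assert input_ids_view.numel() == 12 and seq_lens_view.numel() == 3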
@@ -250,54 +291,55 @@ def capture(self):
         CudaGraphRunner.capture(self)

     def get_forward_batch(self, bs: int) -> ForwardBatch:
+        buffers = self.buffers
         num_tokens = bs * self.num_tokens_per_bs

         # Graph inputs
-        input_ids = self.input_ids[:num_tokens]
-        req_pool_indices = self.req_pool_indices[:bs]
-        seq_lens = self.seq_lens[:bs]
-        seq_lens_cpu = self.seq_lens_cpu[:bs]
-        extend_seq_lens = self.extend_seq_lens[:bs]
+        input_ids = buffers.input_ids[:num_tokens]
+        req_pool_indices = buffers.req_pool_indices[:bs]
+        seq_lens = buffers.seq_lens[:bs]
+        seq_lens_cpu = buffers.seq_lens_cpu[:bs]
+        extend_seq_lens = buffers.extend_seq_lens[:bs]
         extend_seq_lens_cpu = self.extend_seq_lens_cpu[:bs]
-        extend_start_loc = self.extend_start_loc[:bs]
-        accept_length = self.accept_length[:bs]
-        out_cache_loc = self.out_cache_loc[:num_tokens]
-        positions = self.positions[:num_tokens]
-        mrope_positions = self.mrope_positions[:, :num_tokens]
-        hidden_states = self.hidden_states[:num_tokens]
-        next_token_logits_buffer = self.next_token_logits_buffer[
+        extend_start_loc = buffers.extend_start_loc[:bs]
+        accept_length = buffers.accept_length[:bs]
+        out_cache_loc = buffers.out_cache_loc[:num_tokens]
+        positions = buffers.positions[:num_tokens]
+        mrope_positions = buffers.mrope_positions[:, :num_tokens]
+        hidden_states = buffers.hidden_states[:num_tokens]
+        next_token_logits_buffer = buffers.next_token_logits_buffer[
             : bs if self.forward_mode == ForwardMode.DRAFT_EXTEND else num_tokens
         ]

         if self.require_mlp_tp_gather:
-            self.global_num_tokens_gpu.copy_(
+            buffers.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [num_tokens] * self.dp_size,
                     dtype=torch.int32,
-                    device=self.input_ids.device,
+                    device=buffers.input_ids.device,
                 )
             )
-            self.global_num_tokens_for_logprob_gpu.copy_(
+            buffers.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
                     [num_tokens] * self.dp_size,
                     dtype=torch.int32,
-                    device=self.input_ids.device,
+                    device=buffers.input_ids.device,
                 )
             )
             global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
-            self.global_num_tokens_gpu.copy_(
+            buffers.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [num_tokens],
                     dtype=torch.int32,
-                    device=self.input_ids.device,
+                    device=buffers.input_ids.device,
                 )
             )
-            self.global_num_tokens_for_logprob_gpu.copy_(
+            buffers.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
                     [bs],
                     dtype=torch.int32,
-                    device=self.input_ids.device,
+                    device=buffers.input_ids.device,
                 )
             )
             global_dp_buffer_len = num_tokens
@@ -326,8 +368,8 @@ def get_forward_batch(self, bs: int) -> ForwardBatch:
             return_logprob=False,
             positions=positions,
             mrope_positions=mrope_positions,
-            global_num_tokens_gpu=self.global_num_tokens_gpu,
-            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+            global_num_tokens_gpu=buffers.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=buffers.global_num_tokens_for_logprob_gpu,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
             global_dp_buffer_len=global_dp_buffer_len,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -346,6 +388,7 @@ def get_forward_batch(self, bs: int) -> ForwardBatch:
         return forward_batch

     def capture_one_batch_size(self, bs: int, forward: Callable, stream_idx: int = 0):
+        buffers = self.buffers
         graph = self._create_graph()
         stream = self.stream

@@ -390,7 +433,7 @@ def run_once():
             select_index = (
                 torch.arange(bs, device=self.model_runner.device)
                 * (self.speculative_num_draft_tokens + self.step)
-                + self.accept_length[:bs]
+                + buffers.accept_length[:bs]
                 - 1
                 + self.step
             )
@@ -399,24 +442,25 @@ def run_once():
             ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1)

             if self.next_cuda_graph_runner is not None:
+                next_buffers = self.next_cuda_graph_runner.buffers
                 padding_lens = (
-                    self.speculative_num_draft_tokens - self.accept_length[:bs]
+                    self.speculative_num_draft_tokens - buffers.accept_length[:bs]
                 )
                 assign_new_state_triton(
                     ret.topk_index,
-                    self.input_ids,
-                    self.positions,
-                    self.hidden_states,
-                    self.out_cache_loc,
-                    self.extend_seq_lens,
-                    self.extend_start_loc,
-                    self.next_cuda_graph_runner.input_ids,
-                    self.next_cuda_graph_runner.positions,
-                    self.next_cuda_graph_runner.hidden_states,
-                    self.next_cuda_graph_runner.out_cache_loc,
-                    self.next_cuda_graph_runner.extend_seq_lens,
-                    self.next_cuda_graph_runner.extend_start_loc,
-                    self.next_cuda_graph_runner.seq_lens,
+                    buffers.input_ids,
+                    buffers.positions,
+                    buffers.hidden_states,
+                    buffers.out_cache_loc,
+                    buffers.extend_seq_lens,
+                    buffers.extend_start_loc,
+                    next_buffers.input_ids,
+                    next_buffers.positions,
+                    next_buffers.hidden_states,
+                    next_buffers.out_cache_loc,
+                    next_buffers.extend_seq_lens,
+                    next_buffers.extend_start_loc,
+                    next_buffers.seq_lens,
                     padding_lens,
                    forward_batch.batch_size,
                    self.step,
@@ -424,9 +468,9 @@ def run_once():
                    forward_batch.req_to_token_pool.req_to_token,
                    self.eagle_worker.req_to_hidden_states_pool,
                )
-                self.next_cuda_graph_runner.swa_out_cache_loc.copy_(
+                next_buffers.swa_out_cache_loc.copy_(
                    self.model_runner.token_to_kv_pool.translate_loc_from_full_to_swa(
-                        self.next_cuda_graph_runner.out_cache_loc
+                        next_buffers.out_cache_loc
                    )
                )

@@ -446,27 +490,30 @@ def run_once():
     def init_replay_state(
         self, forward_batch: ForwardBatch, bs: int, raw_bs: int, num_tokens: int
     ):
+        buffers = self.buffers
         # Common inputs
-        self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
-        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
+        buffers.input_ids[:num_tokens].copy_(forward_batch.input_ids)
+        buffers.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         if forward_batch.extend_seq_lens is not None:
-            self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
-        self.extend_start_loc[:raw_bs].copy_(forward_batch.extend_start_loc)
-        self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
-        self.positions[:num_tokens].copy_(forward_batch.positions)
+            buffers.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
+        buffers.extend_start_loc[:raw_bs].copy_(forward_batch.extend_start_loc)
+        buffers.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
+        buffers.positions[:num_tokens].copy_(forward_batch.positions)
         if (
             forward_batch.spec_info.hidden_states.shape[1]
-            == self.hidden_states.shape[1]
+            == buffers.hidden_states.shape[1]
         ):
-            self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
+            buffers.hidden_states[:num_tokens].copy_(
+                forward_batch.spec_info.hidden_states
+            )
         if forward_batch.spec_info.accept_length is not None:
-            self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
-        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
+            buffers.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
+        buffers.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)

         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+                buffers.seq_lens_cpu.fill_(self.seq_len_fill_value)
+            buffers.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

         if forward_batch.extend_seq_lens_cpu is not None:
             self.extend_seq_lens_cpu[:raw_bs] = forward_batch.extend_seq_lens_cpu

@@ -474,6 +521,7 @@ def init_replay_state(
     def replay(self, forward_batch: ForwardBatch, init_state: bool = True):
         assert forward_batch.out_cache_loc is not None
         self.deepep_adapter.replay()
+        buffers = self.buffers

         # batch_size and num_seqs can be different in case there are finished examples
         # in the batch, which will not be counted as num_seqs
@@ -492,28 +540,28 @@ def replay(self, forward_batch: ForwardBatch, init_state: bool = True):
             self.init_replay_state(forward_batch, bs, raw_bs, num_tokens)

         if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
-            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
+            buffers.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            buffers.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)

-        forward_batch.spec_info.hidden_states = self.hidden_states[:num_tokens]
-        forward_batch.spec_info.accept_length = self.accept_length[:bs]
+        forward_batch.spec_info.hidden_states = buffers.hidden_states[:num_tokens]
+        forward_batch.spec_info.accept_length = buffers.accept_length[:bs]
         forward_batch.spec_info.num_tokens_per_req = self.num_tokens_per_bs
         forward_batch.spec_info.num_tokens_for_logprob_per_req = 1
-        forward_batch.spec_info.positions = self.positions[:num_tokens]
-        forward_batch.spec_info.extend_seq_lens_tensor = self.extend_seq_lens[:bs]
+        forward_batch.spec_info.positions = buffers.positions[:num_tokens]
+        forward_batch.spec_info.extend_seq_lens_tensor = buffers.extend_seq_lens[:bs]

         self.eagle_worker.draft_extend_attn_backend_list[
             self.step
         ].init_forward_metadata_replay_cuda_graph(
             bs=bs,
-            req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
+            req_pool_indices=buffers.req_pool_indices,
+            seq_lens=buffers.seq_lens,
             seq_lens_sum=forward_batch.seq_lens_sum
             + (bs - raw_bs) * self.seq_len_fill_value,
             encoder_lens=None,
             forward_mode=self.forward_mode,
             spec_info=forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu,
+            seq_lens_cpu=buffers.seq_lens_cpu,
         )

         # Replay
@@ -526,7 +574,7 @@ def replay(self, forward_batch: ForwardBatch, init_state: bool = True):
             # DRAFT_EXTEND_V2: all tokens calculations whether accepted or not.
             unpadding_bs = num_tokens
         elif bs != raw_bs:
-            forward_batch.spec_info.accept_length = self.accept_length[:raw_bs]
+            forward_batch.spec_info.accept_length = buffers.accept_length[:raw_bs]
             unpadding_bs = raw_bs
         else:
             unpadding_bs = None
@@ -565,8 +613,8 @@ def _init_and_capture(self):
             self.runners = [None] * self.speculative_num_steps
             return

-        self.runners = []
-        buffer_len_list = []
+        self.runners: List[Optional[MultiLayerEagleDraftExtendCudaGraphRunner]] = []
+        buffer_len_list: List[int] = []

         # 1. Capture loop
         for step in range(self.speculative_num_steps):
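Throughout `init_replay_state` and `replay` above, new batch data is written with `copy_`/`fill_` rather than by rebinding attributes, because a captured CUDA graph replays fixed device addresses. A tiny standalone demonstration of that rule (CPU tensors, illustrative names):

import torch

# Buffer conceptually captured by a graph; its storage must stay fixed.
input_ids_buf = torch.zeros(16, dtype=torch.int64)

new_ids = torch.tensor([5, 7, 9], dtype=torch.int64)
input_ids_buf[: new_ids.numel()].copy_(new_ids)  # in-place: storage preserved
# input_ids_buf = new_ids  # would rebind the name; a captured graph would
#                          # keep reading the old, stale buffer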
diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py
index 7f660b9f085c..7ef46c5dfe4b 100644
--- a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py
+++ b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py
@@ -498,13 +498,13 @@ def _draft_extend_for_decode(
             self.cuda_graph_runner_for_draft_extend.get_last_runner()
         )
         assign_hidden_states_pool_triton(
-            last_cuda_graph_runner.hidden_states,
-            last_cuda_graph_runner.req_pool_indices,
+            last_cuda_graph_runner.buffers.hidden_states,
+            last_cuda_graph_runner.buffers.req_pool_indices,
             self.req_to_hidden_states_pool,
             self.speculative_num_steps - 1,
             forward_batch.batch_size,
-            last_cuda_graph_runner.extend_seq_lens,
-            last_cuda_graph_runner.extend_start_loc,
+            last_cuda_graph_runner.buffers.extend_seq_lens,
+            last_cuda_graph_runner.buffers.extend_start_loc,
         )

         # Reorganize the spec info for the next batch
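The worker change above is the external face of the refactor: call sites now reach tensors through `runner.buffers.<field>` instead of `runner.<field>`. A minimal stand-in (not the sglang classes) showing the access pattern before and after:

from dataclasses import dataclass

import torch


@dataclass
class Buffers:  # stand-in for MultiLayerEagleDraftExtendInputBuffers
    hidden_states: torch.Tensor
    req_pool_indices: torch.Tensor


class Runner:  # stand-in for the CUDA graph runner
    def __init__(self) -> None:
        self.buffers = Buffers(
            hidden_states=torch.zeros(4, 8),
            req_pool_indices=torch.zeros(4, dtype=torch.int32),
        )


runner = Runner()
hidden = runner.buffers.hidden_states  # after: one attribute hop through .buffers
# before the refactor this was: runner.hidden_states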