From 29fd0b31d139c0c1326b53c474f58811a2d4883e Mon Sep 17 00:00:00 2001 From: ispobock Date: Sun, 25 May 2025 17:46:16 +0000 Subject: [PATCH 1/7] add eagle draft extend cuda graph Co-authored-by: Sehoon Kim --- .../eagle_draft_extend_cuda_graph_runner.py | 250 ++++++++++++++++++ python/sglang/srt/speculative/eagle_utils.py | 45 ++++ python/sglang/srt/speculative/eagle_worker.py | 37 ++- 3 files changed, 329 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py new file mode 100644 index 00000000000..5aff19a3443 --- /dev/null +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -0,0 +1,250 @@ +from __future__ import annotations + +import bisect +from typing import TYPE_CHECKING, Callable + +import torch + +from sglang.srt.model_executor.cuda_graph_runner import ( + CudaGraphRunner, + LogitsProcessorOutput, + get_batch_sizes_to_capture, + get_global_graph_memory_pool, + set_global_graph_memory_pool, + set_torch_compile_config, +) +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) +from sglang.srt.speculative.eagle_utils import EagleDraftInput + +if TYPE_CHECKING: + from sglang.srt.speculative.eagle_worker import EAGLEWorker + + +class EAGLEDraftExtendCudaGraphRunner: + def __init__(self, eagle_worker: EAGLEWorker): + # Parse args + self.eagle_worker = eagle_worker + self.model_runner = model_runner = eagle_worker.model_runner + self.graphs = {} + self.output_buffers = {} + self.enable_torch_compile = model_runner.server_args.enable_torch_compile + self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + self.tp_size = self.model_runner.tp_size + self.dp_size = model_runner.server_args.dp_size + self.speculative_num_steps = model_runner.server_args.speculative_num_steps + self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) + self.padded_static_len = -1 + + # Attention backend + self.num_tokens_per_bs = self.speculative_num_steps + 1 + self.max_bs = max(self.capture_bs) + self.max_num_token = self.max_bs * self.num_tokens_per_bs + + self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state( + self.max_num_token + ) + self.seq_len_fill_value = ( + self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value() + ) + self.seq_lens_cpu = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 + ) + + if self.enable_torch_compile: + set_torch_compile_config() + + # Graph inputs + with torch.device("cuda"): + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) + self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64) + self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) + + if self.eagle_worker.speculative_algorithm.is_eagle3(): + self.hidden_states = torch.zeros( + ( + self.max_num_token, + self.model_runner.model_config.hidden_size * 3, + ), + dtype=self.model_runner.dtype, + ) + else: + self.hidden_states = torch.zeros( + (self.max_num_token, self.model_runner.model_config.hidden_size), + dtype=self.model_runner.dtype, + ) + + self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32) + self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32) + self.accept_length = 
torch.ones((self.max_bs,), dtype=torch.int32) + + # Capture + try: + self.capture() + except RuntimeError as e: + raise Exception( + f"Capture CUDA graph failed: {e}\n" + "Possible solutions:\n" + "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" + "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n" + "3. disable torch compile by not using --enable-torch-compile\n" + "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n" + "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" + ) + + def can_run(self, forward_batch: ForwardBatch): + batch_size = forward_batch.seq_lens.numel() + + is_bs_supported = ( + batch_size in self.graphs + if self.disable_padding + else batch_size <= self.max_bs + ) + + return is_bs_supported + + def capture(self): + CudaGraphRunner.capture(self) + + def capture_one_batch_size(self, bs: int, forward: Callable): + graph = torch.cuda.CUDAGraph() + stream = self.stream + num_tokens = bs * self.num_tokens_per_bs + + # Graph inputs + input_ids = self.input_ids[:num_tokens] + req_pool_indices = self.req_pool_indices[:bs] + seq_lens = self.seq_lens[:bs] + extend_seq_lens = self.extend_seq_lens[:bs] + accept_length = self.accept_length[:bs] + out_cache_loc = self.out_cache_loc[:num_tokens] + positions = self.positions[:num_tokens] + hidden_states = self.hidden_states[:num_tokens] + + spec_info = EagleDraftInput( + hidden_states=hidden_states, + accept_length=accept_length, + ) + spec_info.positions = None + + # Forward batch + forward_batch = ForwardBatch( + forward_mode=ForwardMode.DRAFT_EXTEND, + batch_size=bs, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens.sum(), + return_logprob=False, + positions=positions, + spec_algorithm=self.model_runner.spec_algorithm, + spec_info=spec_info, + capture_hidden_mode=CaptureHiddenMode.LAST, + attn_backend=self.eagle_worker.draft_extend_attn_backend, + extend_seq_lens=extend_seq_lens, + padded_static_len=self.padded_static_len, + ) + + self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_capture_cuda_graph( + bs=bs, + num_tokens=num_tokens, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DRAFT_EXTEND, + spec_info=spec_info, + ) + + # Run and capture + def run_once(): + # Backup two fields, which will be modified in-place in `draft_forward`. 
+ output_cache_loc_backup = forward_batch.out_cache_loc + hidden_states_backup = forward_batch.spec_info.hidden_states + + ret = self.eagle_worker.draft_model_runner.model.forward( + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + ) + + forward_batch.out_cache_loc = output_cache_loc_backup + forward_batch.spec_info.hidden_states = hidden_states_backup + return ret + + for _ in range(2): + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + run_once() + + with torch.cuda.graph( + graph, pool=get_global_graph_memory_pool(), stream=stream + ): + out = run_once() + + set_global_graph_memory_pool(graph.pool()) + return graph, out + + def replay(self, forward_batch: ForwardBatch): + assert forward_batch.out_cache_loc is not None + # batch_size and num_seqs can be different in case there are finished examples + # in the batch, which will not be counted as num_seqs + raw_bs = forward_batch.seq_lens.numel() + num_tokens = forward_batch.input_ids.numel() + assert raw_bs * self.num_tokens_per_bs == num_tokens + + index = bisect.bisect_left(self.capture_bs, raw_bs) + bs = self.capture_bs[index] + if bs != raw_bs: + self.accept_length.fill_(1) + self.out_cache_loc.zero_() + + # Common inputs + self.input_ids[:num_tokens].copy_(forward_batch.input_ids) + self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens[:raw_bs]) + self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens[:raw_bs]) + self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc) + self.positions[:num_tokens].copy_(forward_batch.positions) + self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states) + self.accept_length[:raw_bs].copy_( + forward_batch.spec_info.accept_length[:raw_bs] + ) + self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + + if forward_batch.seq_lens_cpu is not None: + if bs != raw_bs: + self.seq_lens_cpu.fill_(1) + self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) + + forward_batch.spec_info.positions = None + if bs != raw_bs: + forward_batch.spec_info.accept_length = self.accept_length[:bs] + + self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph( + bs=bs, + req_pool_indices=self.req_pool_indices, + seq_lens=self.seq_lens, + seq_lens_sum=sum(forward_batch.seq_lens), + encoder_lens=None, + forward_mode=ForwardMode.DRAFT_EXTEND, + spec_info=forward_batch.spec_info, + seq_lens_cpu=self.seq_lens_cpu, + ) + + # Replay + self.graphs[bs].replay() + out = self.output_buffers[bs] + if bs != raw_bs: + forward_batch.spec_info.accept_length = self.accept_length[:raw_bs] + out = LogitsProcessorOutput( + next_token_logits=out.next_token_logits[:raw_bs], + hidden_states=out.hidden_states[:raw_bs], + ) + return out diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index eb1b3b44f2f..2140d1ac163 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -84,6 +84,7 @@ def prepare_extend_after_decode( self, batch: ScheduleBatch, speculative_num_steps: int, + pad_input: bool = False, ): assert len(self.verified_id) == len(batch.out_cache_loc) accept_length_cpu = batch.spec_info.accept_length_cpu @@ -111,6 +112,50 @@ def prepare_extend_after_decode( batch.input_ids = self.verified_id self.verified_id = new_verified_id + if pad_input: + batch_size = sum(not req.finished() for req in batch.reqs) + # Total constant input length after padding + static_len = speculative_num_steps + 1 + # Total size after 
padding + padded_input_size = batch_size * static_len + + padded_len = padded_input_size - batch.input_ids.shape[0] + if padded_len > 0: + new_input_ids = torch.nn.functional.pad( + batch.input_ids, (0, padded_len), value=0 + ) + position_padding = torch.arange( + padded_len, device=self.positions.device + ) + new_positions = torch.cat([self.positions, position_padding]) + + # need dummy hidden states for the padded positions + hidden_states_dim = self.hidden_states.shape[-1] + new_hidden_states = torch.cat( + [ + self.hidden_states, + torch.zeros( + (padded_len, hidden_states_dim), + dtype=self.hidden_states.dtype, + device=self.hidden_states.device, + ), + ], + dim=0, + ) + + # allocate KV cache location for the padded tokens + padded_cache_loc = torch.zeros( + padded_len, + dtype=batch.out_cache_loc.dtype, + device=batch.out_cache_loc.device, + ) + new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc]) + + batch.input_ids = new_input_ids + self.hidden_states = new_hidden_states + self.positions = new_positions + batch.out_cache_loc = new_out_cache_loc + def generate_attn_arg_prefill( self, req_pool_indices: torch.Tensor, diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 647fafaadb1..86a8df534d8 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -26,6 +26,9 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( EAGLEDraftCudaGraphRunner, ) +from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import ( + EAGLEDraftExtendCudaGraphRunner, +) from sglang.srt.speculative.eagle_utils import ( EagleDraftInput, EagleVerifyInput, @@ -189,6 +192,7 @@ def init_attention_backend(self): self.has_prefill_wrapper_verify = False elif self.server_args.attention_backend == "fa3": from sglang.srt.layers.attention.flashattention_backend import ( + FlashAttentionBackend, FlashAttentionMultiStepBackend, ) @@ -197,7 +201,10 @@ def init_attention_backend(self): self.topk, self.speculative_num_steps, ) - self.draft_extend_attn_backend = None + self.draft_extend_attn_backend = FlashAttentionBackend( + self.draft_model_runner, + skip_prefill=False, + ) self.padded_static_len = self.speculative_num_steps + 1 self.has_prefill_wrapper_verify = False elif self.server_args.attention_backend == "flashmla": @@ -242,7 +249,18 @@ def init_cuda_graphs(self): # Capture extend if self.draft_extend_attn_backend: - raise NotImplementedError() + tic = time.perf_counter() + before_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( + f"Capture draft extend cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" + ) + self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner( + self + ) + after_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( + f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB." 
+ ) @property def draft_model_runner(self): @@ -656,6 +674,7 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch): batch.spec_info.prepare_extend_after_decode( batch, self.speculative_num_steps, + pad_input=self.cuda_graph_runner_for_draft_extend is not None, ) batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST batch.return_logprob = False @@ -665,7 +684,19 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch): ) # Run - logits_output, _ = self.draft_model_runner.forward(forward_batch) + can_cuda_graph = ( + self.cuda_graph_runner_for_draft_extend + and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch) + ) + if can_cuda_graph: + logits_output = self.cuda_graph_runner_for_draft_extend.replay( + forward_batch + ) + else: + self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch) + logits_output = self.draft_model_runner.model.forward( + forward_batch.input_ids, forward_batch.positions, forward_batch + ) self._detect_nan_if_needed(logits_output) self.capture_for_decode(logits_output, forward_batch.spec_info) From 50a96f5a76f384721541a216f22e256d2048c7f9 Mon Sep 17 00:00:00 2001 From: ispobock Date: Sun, 25 May 2025 17:46:43 +0000 Subject: [PATCH 2/7] update for fa3 backend --- .../attention/flashattention_backend.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 9b47509b224..10378602722 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -1268,6 +1268,29 @@ def init_cuda_graph_state(self, max_bs: int): ), } + self.draft_extend_metadata = { + "cache_seqlens": torch.zeros( + max_bs, dtype=torch.int32, device=self.device + ), + "cu_seqlens_q": torch.zeros( + max_bs + 1, + dtype=torch.int32, + device=self.device, + ), + "cu_seqlens_k": torch.zeros( + max_bs + 1, dtype=torch.int32, device=self.device + ), + "page_table": torch.zeros( + max_bs, + (self.max_context_len + self.page_size - 1) // self.page_size, + dtype=torch.int32, + device=self.device, + ), + "strided_indices": torch.arange( + 0, self.max_context_len, self.page_size, device=self.device + ), + } + if self.topk > 1: self.target_verify_metadata_topk_normal = { "cache_seqlens": torch.zeros( @@ -1508,6 +1531,32 @@ def init_forward_metadata_capture_cuda_graph( self.target_verify_metadata_topk_normal[bs] = metadata self.target_verify_metadata_topk_expand[bs] = metadata_expand + elif forward_mode.is_draft_extend(): + metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][ + :bs + ] + metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32)) + + num_tokens_per_bs = num_tokens // bs + metadata.max_seq_len_q = num_tokens_per_bs + metadata.max_seq_len_k = seq_lens.max().item() + + metadata.cu_seqlens_q = torch.arange( + 0, + bs * num_tokens_per_bs + 1, + num_tokens_per_bs, + dtype=torch.int32, + device=device, + ) + + metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][ + : (bs + 1) + ] + metadata.page_table = self.draft_extend_metadata["page_table"][ + req_pool_indices, : + ] + + self.draft_extend_metadata[bs] = metadata if encoder_lens is not None: encoder_bs = encoder_lens.numel() @@ -1732,6 +1781,23 @@ def init_forward_metadata_replay_cuda_graph( metadata_expand.max_seq_len_k = ( metadata_expand.cache_seqlens_int32.max().item() ) + elif forward_mode.is_draft_extend(): + metadata = 
self.draft_extend_metadata[bs] + metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32)) + + metadata.max_seq_len_k = seq_lens_cpu.max().item() + metadata.cu_seqlens_k[1:].copy_( + torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32) + ) + max_seq_pages = ( + metadata.max_seq_len_k + self.page_size - 1 + ) // self.page_size + page_indices = self.req_to_token[ + req_pool_indices[:, None], + self.draft_extend_metadata["strided_indices"][:max_seq_pages], + ] + page_indices //= self.page_size + metadata.page_table[:, :max_seq_pages].copy_(page_indices) if encoder_lens is not None: # Only support encoder size 1 for now From 19aeb04b758833fa884ea8631c26aecd081bb57b Mon Sep 17 00:00:00 2001 From: ispobock Date: Mon, 26 May 2025 03:45:38 +0000 Subject: [PATCH 3/7] fix lint --- docs/references/performance_analysis_and_optimization.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/references/performance_analysis_and_optimization.rst b/docs/references/performance_analysis_and_optimization.rst index 1d70fb51d5d..76db62df7ad 100644 --- a/docs/references/performance_analysis_and_optimization.rst +++ b/docs/references/performance_analysis_and_optimization.rst @@ -4,4 +4,4 @@ Performance Analysis & Optimization :maxdepth: 1 benchmark_and_profiling.md - accuracy_evaluation.md \ No newline at end of file + accuracy_evaluation.md From 9055a49e89328fc0b26bb2972cf0320bb4250016 Mon Sep 17 00:00:00 2001 From: ispobock Date: Tue, 27 May 2025 03:14:24 +0000 Subject: [PATCH 4/7] fix performance issue --- .../eagle_draft_extend_cuda_graph_runner.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index 5aff19a3443..3fd42737599 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -196,26 +196,25 @@ def replay(self, forward_batch: ForwardBatch): assert forward_batch.out_cache_loc is not None # batch_size and num_seqs can be different in case there are finished examples # in the batch, which will not be counted as num_seqs - raw_bs = forward_batch.seq_lens.numel() - num_tokens = forward_batch.input_ids.numel() + raw_bs = forward_batch.batch_size + num_tokens = forward_batch.input_ids.shape[0] assert raw_bs * self.num_tokens_per_bs == num_tokens index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] if bs != raw_bs: + self.seq_lens.fill_(1) self.accept_length.fill_(1) self.out_cache_loc.zero_() # Common inputs self.input_ids[:num_tokens].copy_(forward_batch.input_ids) - self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens[:raw_bs]) - self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens[:raw_bs]) + self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) + self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens) self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc) self.positions[:num_tokens].copy_(forward_batch.positions) self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states) - self.accept_length[:raw_bs].copy_( - forward_batch.spec_info.accept_length[:raw_bs] - ) + self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length) self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) if forward_batch.seq_lens_cpu is not None: @@ -231,7 +230,7 @@ def replay(self, forward_batch: ForwardBatch): bs=bs, 
req_pool_indices=self.req_pool_indices, seq_lens=self.seq_lens, - seq_lens_sum=sum(forward_batch.seq_lens), + seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs), encoder_lens=None, forward_mode=ForwardMode.DRAFT_EXTEND, spec_info=forward_batch.spec_info, From e6108022962224fae8b0ff69961e5a3013024880 Mon Sep 17 00:00:00 2001 From: ispobock Date: Tue, 27 May 2025 06:18:38 +0000 Subject: [PATCH 5/7] adjust reserve memory for cuda graph --- python/sglang/srt/server_args.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9dba0bb976b..8175094f305 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -262,10 +262,15 @@ def __post_init__(self): self.mem_fraction_static = 0.88 if gpu_mem is not None and gpu_mem > 96 * 1024: mem_fraction = self.mem_fraction_static + # 15 GB + additional 3GB for cuda graph + reserve_mem = 1024 * 18 + # need reserve more memory for spec cuda graph + if self.speculative_algorithm is not None: + reserve_mem = 1024 * 20 self.mem_fraction_static = min( mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem, - (gpu_mem - 1024 * 18) - / gpu_mem, # 15 GB + additional 3GB for cuda graph + (gpu_mem - reserve_mem) + / gpu_mem, ) # Set chunked prefill size, which depends on the gpu memory capacity From 6707ea9863b518e9300fb984c666f54b32e15bca Mon Sep 17 00:00:00 2001 From: ispobock Date: Tue, 27 May 2025 09:14:21 +0000 Subject: [PATCH 6/7] fix acc rate --- .../sglang/srt/layers/attention/flashattention_backend.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 10378602722..08a62c0dda8 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -1789,6 +1789,12 @@ def init_forward_metadata_replay_cuda_graph( metadata.cu_seqlens_k[1:].copy_( torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32) ) + accept_length = spec_info.accept_length[:bs] + metadata.max_seq_len_q = accept_length.max().item() + metadata.cu_seqlens_q[1:].copy_( + torch.cumsum(accept_length, dim=0, dtype=torch.int32) + ) + max_seq_pages = ( metadata.max_seq_len_k + self.page_size - 1 ) // self.page_size From cd4017555f1a06d34142754604b0b27f8baaaf96 Mon Sep 17 00:00:00 2001 From: ispobock Date: Tue, 27 May 2025 09:22:53 +0000 Subject: [PATCH 7/7] fix lint --- python/sglang/srt/server_args.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8175094f305..a337a5a8c29 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -269,8 +269,7 @@ def __post_init__(self): reserve_mem = 1024 * 20 self.mem_fraction_static = min( mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem, - (gpu_mem - reserve_mem) - / gpu_mem, + (gpu_mem - reserve_mem) / gpu_mem, ) # Set chunked prefill size, which depends on the gpu memory capacity
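
Two notes on patch 1, with small illustrative sketches that are not part of the series. First, `replay()` maps a live batch onto the nearest captured graph: the runner keeps the `capture_bs` list sorted (as the `bisect_left` lookup requires), pads the batch up to the smallest captured size that fits, and slices the outputs back down afterwards. A minimal sketch of that lookup, with made-up batch sizes (the real list comes from `get_batch_sizes_to_capture`):

import bisect

# A live batch of 5 requests replays the graph captured for batch size 8; the 3 extra
# slots are dummy entries whose logits/hidden states are dropped when the output is
# sliced back to raw_bs, as in EAGLEDraftExtendCudaGraphRunner.replay().
capture_bs = [1, 2, 4, 8, 16, 32]   # illustrative values only
raw_bs = 5
bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
print(bs)  # 8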
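
Second, the captured graph is only reusable if every request contributes a fixed number of tokens, which is why `prepare_extend_after_decode` gains a `pad_input` flag: inputs are padded so the token count equals `batch_size * (speculative_num_steps + 1)`. A minimal sketch of that shape math, assuming 1-D `input_ids` as in the patch (the helper name is hypothetical):

import torch

def pad_to_static_len(input_ids: torch.Tensor, batch_size: int, speculative_num_steps: int) -> torch.Tensor:
    static_len = speculative_num_steps + 1        # fixed tokens per unfinished request
    padded_size = batch_size * static_len         # total token count the graph expects
    pad_len = padded_size - input_ids.shape[0]
    if pad_len > 0:
        # Zero padding; the patch also appends matching dummy positions, hidden states,
        # and KV-cache locations so all graph inputs share the static shape.
        input_ids = torch.nn.functional.pad(input_ids, (0, pad_len), value=0)
    return input_ids

ids = torch.arange(10)                            # e.g. 10 accepted tokens in total
print(pad_to_static_len(ids, batch_size=4, speculative_num_steps=3).shape)  # torch.Size([16])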
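
On patches 2 and 6: the FA3 backend gains draft-extend metadata, and patch 6 fixes the acceptance rate by rebuilding `cu_seqlens_q` from the per-request accept lengths on every replay instead of assuming a uniform query length per request. A small illustration of the cumulative-length arrays that update produces (tensor values are made up):

import torch

accept_length = torch.tensor([3, 1, 4], dtype=torch.int32)    # query tokens per request
cache_seqlens = torch.tensor([17, 9, 25], dtype=torch.int32)  # KV length per request

cu_seqlens_q = torch.zeros(accept_length.numel() + 1, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum(accept_length, dim=0, dtype=torch.int32)
cu_seqlens_k = torch.zeros(cache_seqlens.numel() + 1, dtype=torch.int32)
cu_seqlens_k[1:] = torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)

print(cu_seqlens_q.tolist())  # [0, 3, 4, 8]   -> max_seq_len_q = 4
print(cu_seqlens_k.tolist())  # [0, 17, 26, 51]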
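
Patch 5's change to `server_args.py` is small but affects capacity planning on large GPUs: with speculative decoding enabled, 20 GB rather than 18 GB is kept out of the static memory pool so the additional draft-extend CUDA graphs have room. A sketch of the resulting fraction, assuming `gpu_mem` is in MB as the `96 * 1024` threshold suggests (the standalone function is for illustration only):

def mem_fraction_static(gpu_mem: float, base_fraction: float, speculative: bool) -> float:
    reserve_mem = 1024 * (20 if speculative else 18)   # MB reserved for CUDA graphs etc.
    return min(
        base_fraction + 48 * 1024 * (1 - base_fraction) / gpu_mem,
        (gpu_mem - reserve_mem) / gpu_mem,
    )

# Example with ~141 GB of GPU memory and the default base fraction of 0.88
print(round(mem_fraction_static(141 * 1024, 0.88, speculative=True), 4))   # 0.8582
print(round(mem_fraction_static(141 * 1024, 0.88, speculative=False), 4))  # 0.8723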