Support EAGLE draft extend CUDA graph #6606
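This PR adds a new 250-line file implementing `EAGLEDraftExtendCudaGraphRunner`, a CUDA graph runner dedicated to the EAGLE draft model's draft-extend pass (`ForwardMode.DRAFT_EXTEND`). The sketch below is only an illustration of how such a runner is typically driven; `worker`, `forward_batch`, and the eager fallback are placeholders, not code from this PR.

```python
# Hypothetical usage sketch; not part of the diff below.
runner = EAGLEDraftExtendCudaGraphRunner(worker)   # graphs are captured once, inside __init__
if runner.can_run(forward_batch):                  # batch shape supported (possibly via padding)?
    logits_output = runner.replay(forward_batch)   # replay the pre-captured graph
else:
    logits_output = run_eager_draft_extend(forward_batch)  # placeholder for the non-graph path
```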
@@ -0,0 +1,250 @@

```python
from __future__ import annotations

import bisect
from typing import TYPE_CHECKING, Callable

import torch

from sglang.srt.model_executor.cuda_graph_runner import (
    CudaGraphRunner,
    LogitsProcessorOutput,
    get_batch_sizes_to_capture,
    get_global_graph_memory_pool,
    set_global_graph_memory_pool,
    set_torch_compile_config,
)
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_worker import EAGLEWorker


class EAGLEDraftExtendCudaGraphRunner:
```
Member: BTW, when enabling the extended CUDA graph, we should also adjust

Member: This will be OOM.
```python
    def __init__(self, eagle_worker: EAGLEWorker):
        # Parse args
        self.eagle_worker = eagle_worker
        self.model_runner = model_runner = eagle_worker.model_runner
        self.graphs = {}
        self.output_buffers = {}
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.tp_size = self.model_runner.tp_size
        self.dp_size = model_runner.server_args.dp_size
        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        self.padded_static_len = -1

        # Attention backend
        self.num_tokens_per_bs = self.speculative_num_steps + 1
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs

        self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
            self.max_num_token
        )
        self.seq_len_fill_value = (
            self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
        )
        self.seq_lens_cpu = torch.full(
            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
        )

        if self.enable_torch_compile:
            set_torch_compile_config()

        # Graph inputs
        with torch.device("cuda"):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
            self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)

            if self.eagle_worker.speculative_algorithm.is_eagle3():
                self.hidden_states = torch.zeros(
                    (
                        self.max_num_token,
                        self.model_runner.model_config.hidden_size * 3,
                    ),
                    dtype=self.model_runner.dtype,
                )
            else:
                self.hidden_states = torch.zeros(
                    (self.max_num_token, self.model_runner.model_config.hidden_size),
                    dtype=self.model_runner.dtype,
                )

            self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
            self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
            self.accept_length = torch.ones((self.max_bs,), dtype=torch.int32)

        # Capture
        try:
            self.capture()
        except RuntimeError as e:
            raise Exception(
                f"Capture CUDA graph failed: {e}\n"
                "Possible solutions:\n"
                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
                "3. disable torch compile by not using --enable-torch-compile\n"
                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
            )
```
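The static buffer sizes in `__init__` follow directly from the speculative-decoding settings: each request in a draft-extend batch contributes up to `speculative_num_steps + 1` tokens. A worked example with made-up numbers (the real values come from `server_args` and `get_batch_sizes_to_capture`):

```python
# Illustrative numbers only.
speculative_num_steps = 4
max_bs = 32                                    # largest captured batch size
num_tokens_per_bs = speculative_num_steps + 1  # at most 5 tokens per request
max_num_token = max_bs * num_tokens_per_bs     # 160: size of input_ids / positions / out_cache_loc
```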
```python
    def can_run(self, forward_batch: ForwardBatch):
        batch_size = forward_batch.seq_lens.numel()

        is_bs_supported = (
            batch_size in self.graphs
            if self.disable_padding
            else batch_size <= self.max_bs
        )

        return is_bs_supported

    def capture(self):
        CudaGraphRunner.capture(self)
```
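Unless `--disable-cuda-graph-padding` is set, `can_run` accepts any batch size up to the largest captured one; `replay` then rounds the batch up to the nearest captured size with `bisect`. For example (hypothetical capture list):

```python
import bisect

capture_bs = [1, 2, 4, 8, 16]  # hypothetical captured batch sizes
raw_bs = 5                     # actual number of requests in the batch
bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
assert bs == 8                 # replayed with the bs=8 graph; the extra slots are padding
```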
```python
    def capture_one_batch_size(self, bs: int, forward: Callable):
        graph = torch.cuda.CUDAGraph()
        stream = self.stream
        num_tokens = bs * self.num_tokens_per_bs

        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        extend_seq_lens = self.extend_seq_lens[:bs]
        accept_length = self.accept_length[:bs]
        out_cache_loc = self.out_cache_loc[:num_tokens]
        positions = self.positions[:num_tokens]
        hidden_states = self.hidden_states[:num_tokens]

        spec_info = EagleDraftInput(
            hidden_states=hidden_states,
            accept_length=accept_length,
        )
        spec_info.positions = None

        # Forward batch
        forward_batch = ForwardBatch(
            forward_mode=ForwardMode.DRAFT_EXTEND,
            batch_size=bs,
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=seq_lens.sum(),
            return_logprob=False,
            positions=positions,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=CaptureHiddenMode.LAST,
            attn_backend=self.eagle_worker.draft_extend_attn_backend,
            extend_seq_lens=extend_seq_lens,
            padded_static_len=self.padded_static_len,
        )

        self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_capture_cuda_graph(
            bs=bs,
            num_tokens=num_tokens,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=None,
            forward_mode=ForwardMode.DRAFT_EXTEND,
            spec_info=spec_info,
        )

        # Run and capture
        def run_once():
            # Backup two fields, which will be modified in-place in `draft_forward`.
            output_cache_loc_backup = forward_batch.out_cache_loc
            hidden_states_backup = forward_batch.spec_info.hidden_states

            ret = self.eagle_worker.draft_model_runner.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
            )

            forward_batch.out_cache_loc = output_cache_loc_backup
            forward_batch.spec_info.hidden_states = hidden_states_backup
            return ret

        for _ in range(2):
            torch.cuda.synchronize()
            self.model_runner.tp_group.barrier()

            run_once()

        with torch.cuda.graph(
            graph, pool=get_global_graph_memory_pool(), stream=stream
        ):
            out = run_once()

        set_global_graph_memory_pool(graph.pool())
        return graph, out
```
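`capture_one_batch_size` uses the standard PyTorch warmup-then-capture pattern: run the workload eagerly a couple of times (so lazy initialization and allocations happen outside the graph), then record it into a shared memory pool. A standalone sketch of that pattern, independent of sglang (requires a CUDA device):

```python
import torch

static_x = torch.zeros(8, device="cuda")  # static input buffer reused across replays

def run_once():
    return static_x * 2 + 1

for _ in range(2):              # warmup runs outside the graph
    torch.cuda.synchronize()
    run_once()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):   # a pool= argument lets several graphs share one memory pool
    static_out = run_once()

static_x.copy_(torch.arange(8.0, device="cuda"))
graph.replay()                  # re-runs the captured kernels; static_out now holds fresh results
```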
```python
    def replay(self, forward_batch: ForwardBatch):
        assert forward_batch.out_cache_loc is not None
        # batch_size and num_seqs can be different in case there are finished examples
        # in the batch, which will not be counted as num_seqs
        raw_bs = forward_batch.seq_lens.numel()
        num_tokens = forward_batch.input_ids.numel()
        assert raw_bs * self.num_tokens_per_bs == num_tokens

        index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
        if bs != raw_bs:
            self.accept_length.fill_(1)
            self.out_cache_loc.zero_()

        # Common inputs
        self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens[:raw_bs])
        self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens[:raw_bs])
        self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
        self.positions[:num_tokens].copy_(forward_batch.positions)
        self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
        self.accept_length[:raw_bs].copy_(
            forward_batch.spec_info.accept_length[:raw_bs]
        )
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)

        if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
                self.seq_lens_cpu.fill_(1)
            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

        forward_batch.spec_info.positions = None
        if bs != raw_bs:
            forward_batch.spec_info.accept_length = self.accept_length[:bs]

        self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
            bs=bs,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            seq_lens_sum=sum(forward_batch.seq_lens),
            encoder_lens=None,
            forward_mode=ForwardMode.DRAFT_EXTEND,
            spec_info=forward_batch.spec_info,
            seq_lens_cpu=self.seq_lens_cpu,
        )

        # Replay
        self.graphs[bs].replay()
        out = self.output_buffers[bs]
        if bs != raw_bs:
            forward_batch.spec_info.accept_length = self.accept_length[:raw_bs]
            out = LogitsProcessorOutput(
                next_token_logits=out.next_token_logits[:raw_bs],
                hidden_states=out.hidden_states[:raw_bs],
            )
        return out
```
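Note that `replay` never allocates: per-batch data is `copy_`-ed into the static buffers captured above, and when the batch was padded to a larger captured size, the returned `LogitsProcessorOutput` is sliced back down to the real batch size.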