From c211e7b669c72a35dc8c128f2af20ac928f73280 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 24 Nov 2024 04:47:10 -0800 Subject: [PATCH] Simplify batch update (#2154) --- README.md | 2 +- docs/backend/backend.md | 1 - docs/index.rst | 2 +- docs/references/hyperparameter_tuning.md | 1 - python/sglang/srt/managers/schedule_batch.py | 4 +- python/sglang/srt/managers/scheduler.py | 81 ++++++++++---------- python/sglang/test/few_shot_gsm8k.py | 2 +- 7 files changed, 47 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index 2132ed86026..ebb7f1288f6 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ SGLang is a fast serving framework for large language models and vision language It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language. The core features include: -- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ). +- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, overhead-free CPU scheduler, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (FP8/INT4/AWQ/GPTQ). - **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions. - **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte, mcdse) and reward models (Skywork), with easy extensibility for integrating new models. - **Active Community**: SGLang is open-source and backed by an active community with industry adoption. diff --git a/docs/backend/backend.md b/docs/backend/backend.md index cce345e1037..a2995455f3d 100644 --- a/docs/backend/backend.md +++ b/docs/backend/backend.md @@ -79,7 +79,6 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct ``` python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096 ``` -- To enable the experimental overlapped scheduler, add `--enable-overlap-schedule`. It overlaps CPU scheduler with GPU computation and can accelerate almost all workloads. This does not work for constrained decoding currently. - To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently. - To enable torchao quantization, add `--torchao-config int4wo-128`. It supports various quantization strategies. - To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments. diff --git a/docs/index.rst b/docs/index.rst index e81cdd14981..873999d2587 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -5,7 +5,7 @@ SGLang is a fast serving framework for large language models and vision language It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language. 
The core features include: -- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ). +- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, overhead-free CPU scheduler, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (FP8/INT4/AWQ/GPTQ). - **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions. - **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte) and reward models (Skywork), with easy extensibility for integrating new models. - **Active Community**: SGLang is open-source and backed by an active community with industry adoption. diff --git a/docs/references/hyperparameter_tuning.md b/docs/references/hyperparameter_tuning.md index 2729b968a23..92830f644e4 100644 --- a/docs/references/hyperparameter_tuning.md +++ b/docs/references/hyperparameter_tuning.md @@ -31,7 +31,6 @@ If you see out of memory (OOM) errors, you can try to tune the following paramet - You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding. ### Try Advanced Options -- To enable the experimental overlapped scheduler, add `--enable-overlap-schedule`. It overlaps CPU scheduler with GPU computation and can accelerate almost all workloads. This does not work for constrained decoding currently. - To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently. ### Tune `--schedule-policy` diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ad56be197e7..c9f0ea676f2 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -467,6 +467,7 @@ class ScheduleBatch: extend_lens: List[int] = None extend_num_tokens: int = None decoding_reqs: List[Req] = None + extend_logprob_start_lens: List[int] = None # For encoder-decoder encoder_cached: Optional[List[bool]] = None @@ -722,7 +723,6 @@ def mix_with_running(self, running_batch: "ScheduleBatch"): self.merge_batch(running_batch) self.input_ids = input_ids self.out_cache_loc = out_cache_loc - self.extend_num_tokens += running_bs # NOTE: prefix_indices is what has been cached, but we don't cache each decode step self.prefix_lens.extend( @@ -732,6 +732,8 @@ def mix_with_running(self, running_batch: "ScheduleBatch"): ] ) self.extend_lens.extend([1] * running_bs) + self.extend_num_tokens += running_bs + # TODO (lianmin): Revisit this. 
It should be seq_len - 1 self.extend_logprob_start_lens.extend([0] * running_bs) def check_decode_mem(self): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index dbe18c129a2..5e5b4c68573 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -13,7 +13,6 @@ # ============================================================================== """A scheduler that manages a tensor parallel GPU worker.""" -import dataclasses import logging import os import threading @@ -28,7 +27,7 @@ import zmq from sglang.global_config import global_config -from sglang.srt.configs.model_config import AttentionArch, ModelConfig +from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( @@ -302,6 +301,9 @@ def __init__( ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio + # Tells whether the current running batch is full so that we can skip + # the check of whether to prefill new requests. + # This is an optimization to reduce the overhead of the prefill check. self.batch_is_full = False # Init watchdog thread @@ -721,40 +723,30 @@ def check_memory(self): def get_next_batch_to_run(self): # Merge the prefill batch into the running batch - if ( - self.last_batch - and not self.last_batch.forward_mode.is_decode() - and not self.last_batch.is_empty() - ): + if self.last_batch and self.last_batch.forward_mode.is_extend(): if self.being_chunked_req: + # Move the chunked request out of the batch self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req) self.tree_cache.cache_unfinished_req(self.being_chunked_req) - # Inflight request keeps its rid but will get a new req_pool_idx. 
+ # Inflight request keeps its rid but will get a new req_pool_idx self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx) self.batch_is_full = False + if not self.last_batch.is_empty(): if self.running_batch is None: self.running_batch = self.last_batch else: self.running_batch.merge_batch(self.last_batch) - # Prefill first + # Run prefill first if possible new_batch = self.get_new_batch_prefill() if new_batch is not None: return new_batch - # Check memory - if self.running_batch is None: - return - # Run decode - before_bs = self.running_batch.batch_size() - self.update_running_batch() - if not self.running_batch: - self.batch_is_full = False + if self.running_batch is None: return None - if before_bs != self.running_batch.batch_size(): - self.batch_is_full = False + self.running_batch = self.update_running_batch(self.running_batch) return self.running_batch def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: @@ -866,15 +858,16 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: return new_batch - def update_running_batch(self): + def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: """Update the current running decoding batch.""" global test_retract - batch = self.running_batch + + initial_bs = batch.batch_size() batch.filter_batch() if batch.is_empty(): - self.running_batch = None - return + self.batch_is_full = False + return None # Check if decode out of memory if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10): @@ -900,11 +893,15 @@ def update_running_batch(self): jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func) self.waiting_queue.extend(jump_forward_reqs) if batch.is_empty(): - self.running_batch = None - return + self.batch_is_full = False + return None + + if batch.batch_size() < initial_bs: + self.batch_is_full = False # Update batch tensors batch.prepare_for_decode(self.enable_overlap) + return batch def run_batch(self, batch: ScheduleBatch): """Run a batch.""" @@ -979,8 +976,10 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): if req.is_retracted: continue + if self.is_mixed_chunk and self.enable_overlap and req.finished(): + raise ValueError("Unhandled error!") + if req.is_being_chunked <= 0: - # Inflight reqs' prefill is not finished req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_id) req.check_finished() @@ -990,14 +989,16 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): elif not batch.decoding_reqs or req not in batch.decoding_reqs: self.tree_cache.cache_unfinished_req(req) - if req.grammar is not None: - req.grammar.accept_token(next_token_id) - if req.return_logprob: + # TODO (lianmin): need to think the case w/ mixed chunked prefill logprob_pt += self.add_logprob_return_values( i, req, logprob_pt, next_token_ids, logits_output ) + + if req.grammar is not None: + req.grammar.accept_token(next_token_id) else: + # Inflight reqs' prefill is not finished req.is_being_chunked -= 1 if batch.next_batch_sampling_info: @@ -1015,18 +1016,18 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): continue req.embedding = embeddings[i] - if req.is_being_chunked > 0: - req.is_being_chunked -= 1 - else: - # Inflight reqs' prefill is not finished - # dummy output token for embedding models + if req.is_being_chunked <= 0: + # Dummy output token for embedding models req.output_ids.append(0) req.check_finished() - if req.finished(): - self.tree_cache.cache_finished_req(req) + if 
req.finished(): + self.tree_cache.cache_finished_req(req) + else: + self.tree_cache.cache_unfinished_req(req) else: - self.tree_cache.cache_unfinished_req(req) + # Inflight reqs' prefill is not finished + req.is_being_chunked -= 1 self.stream_output(batch.reqs) @@ -1061,9 +1062,6 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): req.output_ids.append(next_token_id) req.check_finished() - if req.grammar is not None: - req.grammar.accept_token(next_token_id) - if req.finished(): self.tree_cache.cache_finished_req(req) @@ -1074,6 +1072,9 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): if req.top_logprobs_num > 0: req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) + if req.grammar is not None: + req.grammar.accept_token(next_token_id) + if batch.next_batch_sampling_info: batch.next_batch_sampling_info.update_regex_vocab_mask() torch.cuda.current_stream().synchronize() diff --git a/python/sglang/test/few_shot_gsm8k.py b/python/sglang/test/few_shot_gsm8k.py index 8e6572da61a..9657e730084 100644 --- a/python/sglang/test/few_shot_gsm8k.py +++ b/python/sglang/test/few_shot_gsm8k.py @@ -135,7 +135,7 @@ def few_shot_gsm8k(s, question): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--num-shots", type=int, default=5) - parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--data-path", type=str) parser.add_argument("--num-questions", type=int, default=200) parser.add_argument("--max-new-tokens", type=int, default=512) parser.add_argument("--parallel", type=int, default=128)
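
---

For reference, the scheduling control flow that this patch converges on can be summarized in a short, self-contained sketch: prefill batches are scheduled first, and `update_running_batch` now takes the running batch as an argument and returns the filtered batch (or `None`), resetting `batch_is_full` whenever requests leave the batch. The `Batch` and `MiniScheduler` classes below are simplified stand-ins for illustration only, not SGLang's real `ScheduleBatch`/`Scheduler` API; chunked requests, retraction, jump-forward, and the forward-mode check are omitted, and `get_new_batch_prefill` is stubbed out.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Batch:
    """Hypothetical stand-in for sglang's ScheduleBatch (illustration only)."""

    reqs: List[str] = field(default_factory=list)

    def batch_size(self) -> int:
        return len(self.reqs)

    def is_empty(self) -> bool:
        return not self.reqs

    def filter_batch(self) -> None:
        # Drop finished requests (placeholder: requests tagged "done").
        self.reqs = [r for r in self.reqs if r != "done"]

    def merge_batch(self, other: "Batch") -> None:
        self.reqs.extend(other.reqs)


class MiniScheduler:
    """Illustrates the flow of get_next_batch_to_run / update_running_batch."""

    def __init__(self) -> None:
        self.running_batch: Optional[Batch] = None
        self.last_batch: Optional[Batch] = None
        self.batch_is_full = False

    def get_next_batch_to_run(self) -> Optional[Batch]:
        # 1. Merge the previous prefill (extend) batch into the running batch.
        if self.last_batch and not self.last_batch.is_empty():
            if self.running_batch is None:
                self.running_batch = self.last_batch
            else:
                self.running_batch.merge_batch(self.last_batch)

        # 2. Prefill has priority; a new prefill batch is returned immediately.
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            return new_batch

        # 3. Otherwise run decode on the (possibly shrunken) running batch.
        if self.running_batch is None:
            return None
        self.running_batch = self.update_running_batch(self.running_batch)
        return self.running_batch

    def update_running_batch(self, batch: Batch) -> Optional[Batch]:
        initial_bs = batch.batch_size()
        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None
        # If any request left the batch, there may be room to prefill again.
        if batch.batch_size() < initial_bs:
            self.batch_is_full = False
        return batch

    def get_new_batch_prefill(self) -> Optional[Batch]:
        # Placeholder: the real scheduler builds a prefill batch from the
        # waiting queue; returning None means "nothing to prefill".
        return None


if __name__ == "__main__":
    sched = MiniScheduler()
    sched.last_batch = Batch(reqs=["a", "b", "done"])
    out = sched.get_next_batch_to_run()
    print(out.reqs if out else None)  # ['a', 'b'] after merging and filtering
```

Passing the batch in and returning it (or `None`), rather than mutating `self.running_batch` in place, mirrors the patch's design: the data flow is explicit and every `batch_is_full = False` reset lives inside `update_running_batch`.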