Skip to content

Commit

Permalink
Simplify batch update (#2154)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Nov 24, 2024
1 parent d90c3d6 commit c211e7b
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 46 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ SGLang is a fast serving framework for large language models and vision language
It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
The core features include:

- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ).
- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, overhead-free CPU scheduler, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (FP8/INT4/AWQ/GPTQ).
- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
- **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte, mcdse) and reward models (Skywork), with easy extensibility for integrating new models.
- **Active Community**: SGLang is open-source and backed by an active community with industry adoption.
Expand Down
1 change: 0 additions & 1 deletion docs/backend/backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
```
- To enable the experimental overlapped scheduler, add `--enable-overlap-schedule`. It overlaps CPU scheduler with GPU computation and can accelerate almost all workloads. This does not work for constrained decoding currently.
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently.
- To enable torchao quantization, add `--torchao-config int4wo-128`. It supports various quantization strategies.
- To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ SGLang is a fast serving framework for large language models and vision language
It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
The core features include:

- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ).
- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, overhead-free CPU scheduler, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (FP8/INT4/AWQ/GPTQ).
- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
- **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte) and reward models (Skywork), with easy extensibility for integrating new models.
- **Active Community**: SGLang is open-source and backed by an active community with industry adoption.
Expand Down
1 change: 0 additions & 1 deletion docs/references/hyperparameter_tuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ If you see out of memory (OOM) errors, you can try to tune the following paramet
- You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding.

### Try Advanced Options
- To enable the experimental overlapped scheduler, add `--enable-overlap-schedule`. It overlaps CPU scheduler with GPU computation and can accelerate almost all workloads. This does not work for constrained decoding currently.
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently.

### Tune `--schedule-policy`
Expand Down
4 changes: 3 additions & 1 deletion python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ class ScheduleBatch:
extend_lens: List[int] = None
extend_num_tokens: int = None
decoding_reqs: List[Req] = None
extend_logprob_start_lens: List[int] = None

# For encoder-decoder
encoder_cached: Optional[List[bool]] = None
Expand Down Expand Up @@ -722,7 +723,6 @@ def mix_with_running(self, running_batch: "ScheduleBatch"):
self.merge_batch(running_batch)
self.input_ids = input_ids
self.out_cache_loc = out_cache_loc
self.extend_num_tokens += running_bs

# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens.extend(
Expand All @@ -732,6 +732,8 @@ def mix_with_running(self, running_batch: "ScheduleBatch"):
]
)
self.extend_lens.extend([1] * running_bs)
self.extend_num_tokens += running_bs
# TODO (lianmin): Revisit this. It should be seq_len - 1
self.extend_logprob_start_lens.extend([0] * running_bs)

def check_decode_mem(self):
Expand Down
81 changes: 41 additions & 40 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""

import dataclasses
import logging
import os
import threading
Expand All @@ -28,7 +27,7 @@
import zmq

from sglang.global_config import global_config
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
Expand Down Expand Up @@ -302,6 +301,9 @@ def __init__(
) / global_config.default_new_token_ratio_decay_steps
self.new_token_ratio = self.init_new_token_ratio

# Tells whether the current running batch is full so that we can skip
# the check of whether to prefill new requests.
# This is an optimization to reduce the overhead of the prefill check.
self.batch_is_full = False

# Init watchdog thread
Expand Down Expand Up @@ -721,40 +723,30 @@ def check_memory(self):

def get_next_batch_to_run(self):
# Merge the prefill batch into the running batch
if (
self.last_batch
and not self.last_batch.forward_mode.is_decode()
and not self.last_batch.is_empty()
):
if self.last_batch and self.last_batch.forward_mode.is_extend():
if self.being_chunked_req:
# Move the chunked request out of the batch
self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
self.tree_cache.cache_unfinished_req(self.being_chunked_req)
# Inflight request keeps its rid but will get a new req_pool_idx.
# Inflight request keeps its rid but will get a new req_pool_idx
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
self.batch_is_full = False

if not self.last_batch.is_empty():
if self.running_batch is None:
self.running_batch = self.last_batch
else:
self.running_batch.merge_batch(self.last_batch)

# Prefill first
# Run prefill first if possible
new_batch = self.get_new_batch_prefill()
if new_batch is not None:
return new_batch

# Check memory
if self.running_batch is None:
return

# Run decode
before_bs = self.running_batch.batch_size()
self.update_running_batch()
if not self.running_batch:
self.batch_is_full = False
if self.running_batch is None:
return None
if before_bs != self.running_batch.batch_size():
self.batch_is_full = False
self.running_batch = self.update_running_batch(self.running_batch)
return self.running_batch

def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
Expand Down Expand Up @@ -866,15 +858,16 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:

return new_batch

def update_running_batch(self):
def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
"""Update the current running decoding batch."""
global test_retract
batch = self.running_batch

initial_bs = batch.batch_size()

batch.filter_batch()
if batch.is_empty():
self.running_batch = None
return
self.batch_is_full = False
return None

# Check if decode out of memory
if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
Expand All @@ -900,11 +893,15 @@ def update_running_batch(self):
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self.waiting_queue.extend(jump_forward_reqs)
if batch.is_empty():
self.running_batch = None
return
self.batch_is_full = False
return None

if batch.batch_size() < initial_bs:
self.batch_is_full = False

# Update batch tensors
batch.prepare_for_decode(self.enable_overlap)
return batch

def run_batch(self, batch: ScheduleBatch):
"""Run a batch."""
Expand Down Expand Up @@ -979,8 +976,10 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
if req.is_retracted:
continue

if self.is_mixed_chunk and self.enable_overlap and req.finished():
raise ValueError("Unhandled error!")

if req.is_being_chunked <= 0:
# Inflight reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()
Expand All @@ -990,14 +989,16 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
self.tree_cache.cache_unfinished_req(req)

if req.grammar is not None:
req.grammar.accept_token(next_token_id)

if req.return_logprob:
# TODO (lianmin): need to think the case w/ mixed chunked prefill
logprob_pt += self.add_logprob_return_values(
i, req, logprob_pt, next_token_ids, logits_output
)

if req.grammar is not None:
req.grammar.accept_token(next_token_id)
else:
# Inflight reqs' prefill is not finished
req.is_being_chunked -= 1

if batch.next_batch_sampling_info:
Expand All @@ -1015,18 +1016,18 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
continue

req.embedding = embeddings[i]
if req.is_being_chunked > 0:
req.is_being_chunked -= 1
else:
# Inflight reqs' prefill is not finished
# dummy output token for embedding models
if req.is_being_chunked <= 0:
# Dummy output token for embedding models
req.output_ids.append(0)
req.check_finished()

if req.finished():
self.tree_cache.cache_finished_req(req)
if req.finished():
self.tree_cache.cache_finished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)
# Inflight reqs' prefill is not finished
req.is_being_chunked -= 1

self.stream_output(batch.reqs)

Expand Down Expand Up @@ -1061,9 +1062,6 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
req.output_ids.append(next_token_id)
req.check_finished()

if req.grammar is not None:
req.grammar.accept_token(next_token_id)

if req.finished():
self.tree_cache.cache_finished_req(req)

Expand All @@ -1074,6 +1072,9 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

if req.grammar is not None:
req.grammar.accept_token(next_token_id)

if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/test/few_shot_gsm8k.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def few_shot_gsm8k(s, question):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shots", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--data-path", type=str)
parser.add_argument("--num-questions", type=int, default=200)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--parallel", type=int, default=128)
Expand Down

0 comments on commit c211e7b

Please sign in to comment.