
Commit a8ccacc
[Frontend] Fix request length check and add option to disallow auto truncation in scheduler (#2876)
CatherineSue authored Jan 16, 2025
1 parent 0427416 commit a8ccacc
Showing 6 changed files with 154 additions and 17 deletions.
32 changes: 17 additions & 15 deletions python/sglang/srt/managers/scheduler.py
@@ -78,6 +78,7 @@
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
+from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -690,14 +691,16 @@ def handle_generate_request(
         # By default, only return the logprobs for output tokens
         req.logprob_start_len = len(req.origin_input_ids) - 1

-        # Truncate prompts that are too long
-        if len(req.origin_input_ids) > self.max_req_input_len:
-            logger.warning(
-                "Request length is longer than the KV cache pool size or "
-                "the max context length. Truncated. "
-                f"{len(req.origin_input_ids)=}, {self.max_req_input_len=}."
-            )
-            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+        # Validate prompt length
+        error_msg = validate_input_length(
+            req,
+            self.max_req_input_len,
+            self.server_args.allow_auto_truncate,
+        )
+
+        if error_msg:
+            self.waiting_queue.append(req)
+            return

         req.sampling_params.max_new_tokens = min(
             (
@@ -745,13 +748,12 @@ def handle_embedding_request(
         )
         req.tokenizer = self.tokenizer

-        # Truncate prompts that are too long
-        if len(req.origin_input_ids) >= self.max_req_input_len:
-            logger.warning(
-                "Request length is longer than the KV cache pool size or "
-                "the max context length. Truncated!!!"
-            )
-            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+        # Validate prompt length
+        validate_input_length(
+            req,
+            self.max_req_input_len,
+            self.server_args.allow_auto_truncate,
+        )

         self.waiting_queue.append(req)
20 changes: 18 additions & 2 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -292,12 +292,28 @@ async def _tokenize_one_request(
             SessionParams(**obj.session_params) if obj.session_params else None
         )

-        if obj.input_ids is not None and len(input_ids) >= self.context_len:
+        input_token_num = len(input_ids) if input_ids is not None else 0
+        if input_token_num >= self.context_len:
             raise ValueError(
-                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"The input ({input_token_num} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )

+        if (
+            obj.sampling_params.get("max_new_tokens") is not None
+            and obj.sampling_params.get("max_new_tokens") + input_token_num
+            >= self.context_len
+        ):
+            raise ValueError(
+                f"Requested token count exceeds the model's maximum context length "
+                f"of {self.context_len} tokens. You requested a total of "
+                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+                f"tokens: {input_token_num} tokens from the input messages and "
+                f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
+                f"completion. Please reduce the number of tokens in the input "
+                f"messages or the completion to fit within the limit."
+            )

         # Parse sampling parameters
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
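
The second check added above rejects requests whose prompt and requested completion cannot both fit in the context window. A standalone sketch of the same arithmetic, for illustration only (check_total_tokens is a made-up name, not part of sglang):

    from typing import Optional

    def check_total_tokens(
        input_token_num: int, max_new_tokens: Optional[int], context_len: int
    ) -> None:
        # Mirrors the first check: the prompt alone must fit.
        if input_token_num >= context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({context_len} tokens)."
            )
        # Mirrors the new check: prompt + completion must also fit.
        if max_new_tokens is not None and max_new_tokens + input_token_num >= context_len:
            raise ValueError(
                f"Requested {max_new_tokens + input_token_num} tokens total "
                f"({input_token_num} input + {max_new_tokens} completion), "
                f"which exceeds the context length of {context_len} tokens."
            )

    check_total_tokens(60, 30, 100)       # OK: 90 < 100
    try:
        check_total_tokens(60, 50, 100)   # 110 >= 100 -> rejected
    except ValueError as e:
        print(e)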
41 changes: 41 additions & 0 deletions python/sglang/srt/managers/utils.py
@@ -0,0 +1,41 @@
import logging
from typing import Optional

from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req

logger = logging.getLogger(__name__)


def validate_input_length(
    req: Req, max_req_input_len: int, allow_auto_truncate: bool
) -> Optional[str]:
    """Validate and potentially truncate input length.

    Args:
        req: The request containing input_ids to validate
        max_req_input_len: Maximum allowed input length
        allow_auto_truncate: Whether to truncate long inputs

    Returns:
        Error message if validation fails, None if successful
    """
    if len(req.origin_input_ids) >= max_req_input_len:
        if allow_auto_truncate:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated. "
                f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
            )
            req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
            return None
        else:
            error_msg = (
                f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
                f"the maximum allowed length ({max_req_input_len} tokens). "
                f"Use a shorter input or enable --allow-auto-truncate."
            )
            logger.error(error_msg)
            req.finished_reason = FINISH_ABORT(error_msg)
            return error_msg

    return None
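
The helper either truncates in place or aborts the request, depending on allow_auto_truncate. A self-contained sketch of the two code paths (FakeReq and the plain-string abort marker are illustrative stand-ins for sglang's Req and FINISH_ABORT, not the real classes):

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class FakeReq:
        origin_input_ids: List[int] = field(default_factory=list)
        finished_reason: Optional[str] = None

    def validate_sketch(
        req: FakeReq, max_req_input_len: int, allow_auto_truncate: bool
    ) -> Optional[str]:
        if len(req.origin_input_ids) >= max_req_input_len:
            if allow_auto_truncate:
                # Truncation path: trim in place and let scheduling continue.
                req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
                return None
            # Strict path: mark the request aborted and report an error.
            req.finished_reason = (
                f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
                f"the maximum allowed length ({max_req_input_len} tokens)."
            )
            return req.finished_reason
        return None

    req = FakeReq(origin_input_ids=list(range(200)))
    assert validate_sketch(req, 100, allow_auto_truncate=True) is None
    assert len(req.origin_input_ids) == 100   # truncated in place

    req = FakeReq(origin_input_ids=list(range(200)))
    assert validate_sketch(req, 100, allow_auto_truncate=False) is not None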
6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
@@ -157,6 +157,7 @@ class ServerArgs:
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
     enable_memory_saver: bool = False
+    allow_auto_truncate: bool = False

     def __post_init__(self):
         # Set missing default values
@@ -859,6 +860,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
             action="store_true",
             help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
         )
+        parser.add_argument(
+            "--allow-auto-truncate",
+            action="store_true",
+            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
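
With the flag off (the new default), over-long requests now fail fast instead of being silently truncated; passing --allow-auto-truncate restores the old truncating behavior. A typical launch invocation might look like the following (the model path is a placeholder, not from this commit):

    python -m sglang.launch_server --model-path <your-model> --allow-auto-truncate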
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -31,6 +31,7 @@
     "test_pytorch_sampling_backend.py",
     "test_radix_attention.py",
     "test_release_memory_occupation.py",
+    "test_request_length_validation.py",
     "test_retract_decode.py",
     "test_server_args.py",
     "test_session_control.py",
71 changes: 71 additions & 0 deletions test/srt/test_request_length_validation.py
@@ -0,0 +1,71 @@
import unittest

import openai

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestRequestLengthValidation(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        # Start server with auto truncate disabled
        cls.process = popen_launch_server(
            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=("--max-total-tokens", "1000", "--context-length", "100"),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_input_length_validation(self):
        client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")

        long_text = "hello " * 100  # Will tokenize to more than context length

        with self.assertRaises(openai.BadRequestError) as cm:
            client.chat.completions.create(
                model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                messages=[
                    {"role": "user", "content": long_text},
                ],
                temperature=0,
            )

        self.assertIn("is longer than the model's context length", str(cm.exception))

    def test_max_tokens_validation(self):
        client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")

        long_text = "hello "

        with self.assertRaises(openai.BadRequestError) as cm:
            client.chat.completions.create(
                model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                messages=[
                    {"role": "user", "content": long_text},
                ],
                temperature=0,
                max_tokens=500,
            )

        self.assertIn(
            "Requested token count exceeds the model's maximum context",
            str(cm.exception),
        )


if __name__ == "__main__":
    unittest.main()