Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,7 @@ async def health_generate(request: Request) -> Response:
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
rid = f"{HEALTH_CHECK_RID_PREFIX}_{time.time()}"

if _global_state.tokenizer_manager.is_image_gen:
gri = _global_state.tokenizer_manager.get_image_gen_health_check_request(
rid, sampling_params
)
elif _global_state.tokenizer_manager.is_generation:
if _global_state.tokenizer_manager.is_generation:
gri = GenerateReqInput(
rid=rid,
input_ids=[0],
Expand Down
14 changes: 14 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import copy
import uuid
from abc import ABC
from collections import Counter
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
Expand Down Expand Up @@ -58,6 +59,15 @@ def regenerate_rid(self):
self.rid = uuid.uuid4().hex
return self.rid

def _validate_rid_uniqueness(self):
"""Validate that request IDs within a batch are unique."""
if isinstance(self.rid, list) and len(set(self.rid)) != len(self.rid):
counts = Counter(self.rid)
duplicates = [rid for rid, count in counts.items() if count > 1]
raise ValueError(
f"Duplicate request IDs detected within the request: {duplicates}"
)


@dataclass
class BaseBatchReq(ABC):
Expand Down Expand Up @@ -276,6 +286,8 @@ def normalize_batch_and_arguments(self):
else:
self._normalize_batch_inputs()

self._validate_rid_uniqueness()

def _validate_inputs(self):
"""Validate that the input configuration is valid."""
if (
Expand Down Expand Up @@ -853,6 +865,8 @@ def normalize_batch_and_arguments(self):

self._normalize_lora_paths(self.batch_size)

self._validate_rid_uniqueness()

def _normalize_lora_paths(self, num):
"""Normalize LoRA paths for batch processing."""
if self.lora_path is not None:
Expand Down
46 changes: 19 additions & 27 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ class ReqState:
finished: bool
event: asyncio.Event
obj: Union[GenerateReqInput, EmbeddingReqInput]

# For performance metrics
time_stats: APIServerReqTimeStats
last_completion_tokens: int = 1
ttft_observed: bool = False
Expand Down Expand Up @@ -216,9 +218,6 @@ def __init__(
# Init metric collector and watchdog
self.init_metric_collector_watchdog()

if self.enable_metrics:
start_cpu_monitor_thread("tokenizer")

# Init request dispatcher
self.init_request_dispatcher()

Expand All @@ -231,7 +230,6 @@ def init_model_config(self):
self.served_model_name = server_args.served_model_name
self.model_config = model_config_class.from_server_args(server_args)
self.is_generation = self.model_config.is_generation
self.is_image_gen = getattr(self.model_config, "is_image_gen", False)
self.context_len = self.model_config.context_len
self.image_token_id = self.model_config.image_token_id
self.max_req_input_len = None # Will be set later in engine.py
Expand Down Expand Up @@ -339,10 +337,6 @@ def init_running_status(self):
self.gracefully_exit = False
self.last_receive_tstamp = real_time()

# For load balancing
self.current_load = 0
self.current_load_lock = asyncio.Lock()

# Session
self.session_futures = {} # session_id -> asyncio event

Expand Down Expand Up @@ -441,6 +435,8 @@ def init_metric_collector_watchdog(self):
collect_tokens_histogram=self.server_args.collect_tokens_histogram,
)

start_cpu_monitor_thread("tokenizer")

if self.server_args.gc_warning_threshold_secs > 0.0:
configure_gc_warning(self.server_args.gc_warning_threshold_secs)
self.soft_watchdog = Watchdog.create(
Expand Down Expand Up @@ -489,7 +485,7 @@ async def generate_request(
# Normalize the request
obj.normalize_batch_and_arguments()
self._set_default_priority(obj)
self._validate_rid(obj)
self._validate_rid_not_in_flight(obj)

if isinstance(obj, GenerateReqInput) and obj.routed_dp_rank is not None:
dp_size = self.server_args.dp_size
Expand Down Expand Up @@ -771,20 +767,16 @@ async def _tokenize_one_request(
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
)

def _validate_rid(self, obj: Union[GenerateReqInput, EmbeddingReqInput]) -> None:
"""Validate the request ID (rid) uniqueness."""
rid = obj.rid
if rid is None:
def _validate_rid_not_in_flight(
self, obj: Union[GenerateReqInput, EmbeddingReqInput]
) -> None:
"""Validate that request IDs are not already in flight."""
if obj.rid is None:
return
ids = rid if isinstance(rid, list) else [rid]
if len(ids) != len(set(ids)):
raise ValueError(
f"Duplicate request IDs detected within the request: {ids}"
)

for i in ids:
if i in self.rid_to_state:
raise ValueError(f"Duplicate request ID detected: {i}")
rids = obj.rid if isinstance(obj.rid, list) else [obj.rid]
conflicts = set(rids) & self.rid_to_state.keys()
if conflicts:
raise ValueError(f"Duplicate request IDs detected: {list(conflicts)}")

def _validate_one_request(
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
Expand Down Expand Up @@ -2312,14 +2304,14 @@ def _req_stats_init(

external_trace_header = None
if self.server_args.enable_trace:
if request:
external_trace_header = extract_trace_headers(request.headers)
obj.external_trace_header = external_trace_header
elif obj.external_trace_header:
# When the request comes form the rust grpc server or Engine there isn't a
if obj.external_trace_header:
# When the request comes from the rust grpc server or Engine there isn't a
# real request object but we still need to propagate the trace context from
# the trace context that is explicitly passed in
external_trace_header = obj.external_trace_header
elif request:
external_trace_header = extract_trace_headers(request.headers)
obj.external_trace_header = external_trace_header

if not hasattr(obj, "is_single") or obj.is_single:
time_stats = APIServerReqTimeStats(disagg_mode=self.disaggregation_mode)
Expand Down
Loading