
Commit 071a1f5

[Minor] clean up multimodal processor and tokenizer manager (#7624)
1 parent 7c0db3a commit 071a1f5

9 files changed: +141 -159 lines changed

python/sglang/srt/entrypoints/http_server.py

Lines changed: 11 additions & 12 deletions
@@ -353,8 +353,7 @@ async def generate_from_file_request(file: UploadFile, request: Request):
     obj = GenerateReqInput(
         input_embeds=input_embeds,
         sampling_params={
-            "repetition_penalty": 1.2,
-            "temperature": 0.2,
+            "temperature": 0.0,
             "max_new_tokens": 512,
         },
     )
@@ -393,16 +392,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


-@app.api_route(
-    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
-)
-async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
-    """Endpoint for reranking documents based on query relevance."""
-    return await raw_request.app.state.openai_serving_rerank.handle_request(
-        request, raw_request
-    )
-
-
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
@@ -841,6 +830,16 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
     )


+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
+
+
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST

python/sglang/srt/managers/io_struct.py

Lines changed: 4 additions & 4 deletions
@@ -22,17 +22,16 @@
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union

+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
+from sglang.srt.sampling.sampling_params import SamplingParams

-# handle serialization of Image for pydantic
+# Handle serialization of Image for pydantic
 if TYPE_CHECKING:
     from PIL.Image import Image
 else:
     Image = Any

-from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling.sampling_params import SamplingParams
-

 @dataclass
 class SessionParams:
@@ -182,6 +181,7 @@ def _handle_parallel_sampling(self):
         # Determine parallel sample count
         if self.sampling_params is None:
             self.parallel_sample_num = 1
+            return
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
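
The early `return` added above short-circuits `_handle_parallel_sampling` when no sampling parameters were supplied. A minimal standalone sketch of the pattern (not the actual GenerateReqInput class; what follows the branch in the real method is assumed here):

class _ReqSketch:
    """Toy stand-in for GenerateReqInput, only for illustrating the branch."""

    def __init__(self, sampling_params=None):
        self.sampling_params = sampling_params
        self.parallel_sample_num = 1

    def handle_parallel_sampling(self):
        # Determine parallel sample count
        if self.sampling_params is None:
            self.parallel_sample_num = 1
            return  # nothing else to read; skip the dict/list handling below
        if isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # a list of per-request sampling params (details not shown in the diff)
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)


req = _ReqSketch({"n": 4})
req.handle_parallel_sampling()
assert req.parallel_sample_num == 4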

python/sglang/srt/managers/multimodal_processor.py

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@ def get_dummy_processor():
     return DummyMultimodalProcessor()


-@lru_cache()
 def import_processors():
     package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
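
Dropping `@lru_cache()` leaves `import_processors` as a plain function; Python's module cache already makes repeated imports cheap. A rough sketch of the dynamic-import pattern the function follows (only the first lines appear in this diff, so the loop body is an assumption):

import importlib
import pkgutil


def import_processors(package_name="sglang.srt.multimodal.processors"):
    # Import every submodule of the processors package so that each
    # model-specific processor gets a chance to register itself.
    package = importlib.import_module(package_name)
    for _, module_name, _ in pkgutil.iter_modules(package.__path__):
        importlib.import_module(f"{package_name}.{module_name}")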

python/sglang/srt/managers/schedule_batch.py

Lines changed: 22 additions & 20 deletions
@@ -180,46 +180,48 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or others
+    A single multimodal data, from a single image/video/audio or others.
+
+    We put the common fields first and the model-specific fields last.
     """

     modality: Modality
-
     hash: int = None
     pad_value: int = None
-
-    aspect_ratio_id: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
-
     image_sizes: Tuple[int, int] = None
     image_offsets: Optional[list] = None

     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
     pixel_values: Union[torch.Tensor, np.ndarray] = None
+    audio_features: Union[torch.Tensor, np.ndarray] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+    # For qwen-vl
     image_grid_thw: Union[torch.Tensor, np.ndarray] = None
-    video_grid_thws: Union[torch.Tensor, np.ndarray] = None
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None

+    # For deepseek-vl
     image_emb_mask: Optional[torch.Tensor] = None
     image_spatial_crop: Optional[torch.Tensor] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None

+    # For minicpmv
     # [num_images, (n, w, h)]
     tgt_size: Tuple[int, int] = None

-    # kimi-vl related
-    image_grid_hws: Optional[List[torch.Tensor]] = None
+    # For mllama
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

-    audio_features: Union[torch.Tensor, np.ndarray] = None
-    audio_feature_lens: Optional[List[torch.Tensor]] = None
-    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    # For kimi-vl
+    image_grid_hws: Optional[List[torch.Tensor]] = None

-    # gemma3n related
+    # For gemma3n
     input_features: Optional[torch.Tensor] = None
     input_features_mask: Optional[torch.Tensor] = None

-    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
-
     @staticmethod
     def is_empty_list(l):
         if l is None:
@@ -339,10 +341,6 @@ class MultimodalInputs:
     image_pad_len: Optional[list] = None
     num_image_tokens: Optional[int] = None

-    # QWen2-VL related
-    mrope_positions: Optional[torch.Tensor] = None
-    mrope_position_delta: Optional[torch.Tensor] = None
-
     # image
     im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
@@ -358,6 +356,10 @@ class MultimodalInputs:
     audio_start_id: Optional[int] = None
     audio_end_id: Optional[int] = None

+    # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
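
With the fields regrouped, a caller fills in the common fields for every item and only the extras for the model at hand. A hypothetical construction (values are placeholders; an `IMAGE` member on `Modality` and keyword construction are assumptions not shown in this commit):

import torch

from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

# Common fields first, then only the model-specific extras that apply
# (here the qwen-vl grid); every other model-specific field stays None.
item = MultimodalDataItem(
    modality=Modality.IMAGE,
    pixel_values=torch.zeros(1, 3, 336, 336),
    image_grid_thw=torch.tensor([[1, 24, 24]]),
)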

python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 54 additions & 69 deletions
@@ -150,7 +150,9 @@ class ReqState:

     # For streaming output
     last_output_offset: int = 0
+
     # For incremental state update.
+    # TODO(lianmin): do not initialize some lists if not needed.
     text: str = ""
     output_ids: List[int] = dataclasses.field(default_factory=list)
     input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
@@ -199,7 +201,6 @@ def __init__(
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
         self.model_config = ModelConfig.from_server_args(server_args)
-
         self.is_generation = self.model_config.is_generation
         self.is_image_gen = self.model_config.is_image_gen
         self.context_len = self.model_config.context_len
@@ -251,19 +252,36 @@ def __init__(
         self.dump_requests_threshold = 1000
         self.dump_request_list: List[Tuple] = []
         self.log_request_metadata = self.get_log_request_metadata()
+        self.asyncio_tasks = set()
+        self.session_futures = {}  # session_id -> asyncio event
+        self.max_req_input_len = None

         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
-        self.asyncio_tasks = set()

-        # For session info
-        self.session_futures = {}  # session_id -> asyncio event
+        # For pd disaggregtion
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Start kv boostrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
+            kv_bootstrap_server_class = get_kv_class(
+                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
+                self.server_args.disaggregation_bootstrap_port
+            )

-        # Set after scheduler is initialized
-        self.max_req_input_len = None
+        # For load balancing
+        self.current_load = 0
+        self.current_load_lock = asyncio.Lock()

         # Metrics
         if self.enable_metrics:
@@ -393,66 +411,29 @@ def __init__(
             ]
         )

-        # For pd disaggregtion
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-
-        self.current_load = 0
-        self.current_load_lock = asyncio.Lock()
-
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
         created_time = time.time()
-
         self.auto_create_handle_loop()
+        obj.normalize_batch_and_arguments()

         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
                 "This model does not appear to be an embedding model by default. "
                 "Please add `--is-embedding` when launching the server or try another model."
             )

-        obj.normalize_batch_and_arguments()
-
-        if isinstance(obj, GenerateReqInput):
-            return_hidden_states = obj.return_hidden_states
-            has_return_hidden_states = return_hidden_states == True or (
-                isinstance(return_hidden_states, list) and any(return_hidden_states)
-            )
-            if (
-                not self.server_args.enable_return_hidden_states
-                and has_return_hidden_states
-            ):
-                raise ValueError(
-                    "return_hidden_states=True requires the server to be started "
-                    "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
-                )
-
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
                 f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
             )

         async with self.model_update_lock.reader_lock:
-            is_single = obj.is_single
-            if is_single:
+            if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
                 state = self._send_one_request(obj, tokenized_obj, created_time)
                 async for response in self._wait_one_response(obj, state, request):
@@ -514,12 +495,12 @@ async def _tokenize_one_request(
         else:
             image_inputs: Optional[Dict] = None

-        self._validate_token_len(obj, input_ids)
+        self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
         )

-    def _validate_token_len(
+    def _validate_one_request(
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
     ) -> None:
         """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
@@ -548,6 +529,24 @@ def _validate_token_len(
             )
             raise ValueError(error_msg)

+        if isinstance(obj, GenerateReqInput):
+            if (
+                obj.return_hidden_states
+                and not self.server_args.enable_return_hidden_states
+            ):
+                raise ValueError(
+                    "The server is not configured to return the hidden states. "
+                    "Please set `--enable-return-hidden-states` to enable this feature."
+                )
+            if (
+                obj.custom_logit_processor
+                and not self.server_args.enable_custom_logit_processor
+            ):
+                raise ValueError(
+                    "The server is not configured to enable custom logit processor. "
+                    "Please set `--enable-custom-logits-processor` to enable this feature."
+                )
+
     def _create_tokenized_object(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -558,24 +557,6 @@ def _create_tokenized_object(
         token_type_ids: Optional[List[int]] = None,
     ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
         """Create a tokenized request object from common parameters."""
-
-        if self.is_generation:
-            return_logprob = obj.return_logprob
-            logprob_start_len = obj.logprob_start_len
-            top_logprobs_num = obj.top_logprobs_num
-            token_ids_logprob = obj.token_ids_logprob
-            session_params = (
-                SessionParams(**obj.session_params) if obj.session_params else None
-            )
-            if (
-                obj.custom_logit_processor
-                and not self.server_args.enable_custom_logit_processor
-            ):
-                raise ValueError(
-                    "The server is not configured to enable custom logit processor. "
-                    "Please set `--enable-custom-logits-processor` to enable this feature."
-                )
-
         # Parse sampling parameters
         # Note: if there are preferred sampling params, we use them if they are not
         # explicitly passed in sampling_params
@@ -589,16 +570,20 @@ def _create_tokenized_object(

         # Build return object
         if isinstance(obj, GenerateReqInput):
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
+            )
+
             tokenized_obj = TokenizedGenerateReqInput(
                 obj.rid,
                 input_text,
                 input_ids,
                 image_inputs,
                 sampling_params,
-                return_logprob,
-                logprob_start_len,
-                top_logprobs_num,
-                token_ids_logprob,
+                obj.return_logprob,
+                obj.logprob_start_len,
+                obj.top_logprobs_num,
+                obj.token_ids_logprob,
                 obj.stream,
                 bootstrap_host=obj.bootstrap_host,
                 bootstrap_port=obj.bootstrap_port,
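
The net effect of the tokenizer_manager changes is that per-request validation now happens in one place, before any tokenized object is built or sent. A condensed sketch of the consolidated hook, abbreviated from the hunks above (the token-length message is paraphrased, not the exact string from the file):

def _validate_one_request(self, obj, input_ids):
    # Reject requests whose prompt alone already exceeds the context window.
    if len(input_ids) >= self.context_len:
        raise ValueError(
            f"The input ({len(input_ids)} tokens) is longer than the "
            f"model's context length ({self.context_len} tokens)."
        )

    # Feature-flag checks moved here from generate_request and
    # _create_tokenized_object.
    if isinstance(obj, GenerateReqInput):
        if obj.return_hidden_states and not self.server_args.enable_return_hidden_states:
            raise ValueError(
                "The server is not configured to return the hidden states. "
                "Please set `--enable-return-hidden-states` to enable this feature."
            )
        if obj.custom_logit_processor and not self.server_args.enable_custom_logit_processor:
            raise ValueError(
                "The server is not configured to enable custom logit processor. "
                "Please set `--enable-custom-logits-processor` to enable this feature."
            )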

python/sglang/srt/multimodal/processors/base_processor.py

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ def __init__(self, hf_config, server_args, _processor):
         self._processor = _processor
         self.arch = hf_config.architectures[0]
         self.server_args = server_args
+
         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330
