210 changes: 200 additions & 10 deletions python/sglang/srt/multimodal/processors/base_processor.py
@@ -1,10 +1,12 @@
import asyncio
import concurrent
import concurrent.futures
import dataclasses
import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
@@ -13,14 +15,8 @@
from transformers import BaseImageProcessorFast

from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import (
get_bool_env_var,
is_npu,
load_audio,
load_image,
load_video,
logger,
)
from sglang.srt.multimodal.processors.video_utils import make_video_input
from sglang.srt.utils import get_bool_env_var, is_npu, load_audio, load_image, logger
from sglang.srt.utils.cuda_ipc_transport_utils import (
MM_FEATURE_CACHE_SIZE,
CudaIpcTensorTransportProxy,
@@ -182,6 +178,13 @@ def __init__(
mp_context=mp.get_context("fork"),
max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
)
mp_ctx_name = os.environ.get("SGLANG_MP_CTX", "spawn")
self.video_executor = concurrent.futures.ProcessPoolExecutor(
mp_context=mp.get_context(mp_ctx_name),
max_workers=int(os.environ.get("SGLANG_VIDEO_WORKERS", 8)),
# initializer=_video_worker_init,
# initargs=(int(os.environ.get("SGLANG_VIDEO_GPU_ID", "0")),),
)

# Mapping from attribute names to modality types
self.ATTR_NAME_TO_MODALITY = {
@@ -335,19 +338,94 @@ def _load_single_item(
if isinstance(data, dict):
return data
try:
if modality == Modality.VIDEO:
return make_video_input(
data,
frame_count_limit=frame_count_limit,
request_timeout_env="REQUEST_TIMEOUT",
)
if modality == Modality.IMAGE:
img, _ = load_image(data)
if discard_alpha_channel and img.mode != "RGB":
img = img.convert("RGB")
return img
elif modality == Modality.VIDEO:
return load_video(data, frame_count_limit)
elif modality == Modality.AUDIO:
return load_audio(data, audio_sample_rate)

except Exception as e:
raise RuntimeError(f"Error while loading data {data}: {e}")

def submit_data_loading_tasks_async(
self,
text_parts: List[str],
multimodal_tokens: MultimodalSpecialTokens,
data_iterators: dict[Modality, Iterator[Any]],
discard_alpha_channel: bool = True,
image_estimated_frames_iter: Optional[iter] = None,
image_scaling_factor: float = 1.0,
max_image_frames: int = 30,
audio_sample_rate: Optional[int] = None,
) -> Tuple[List, List]:
"""
Load multimodal data in parallel using iterators.
"""
loop = asyncio.get_running_loop()
futures = []
task_info = []

for text_part in text_parts:
modality = multimodal_tokens.get_modality_of_token(text_part)
if modality is not None:
data_iterator = data_iterators.get(modality)
if data_iterator is None:
raise ValueError(f"No data iterator found for token: {text_part}")

try:
data = next(data_iterator)
except StopIteration:
raise ValueError(
f"Mismatch: More '{text_part}' tokens found than corresponding data items provided."
)

frame_count_limit = None
if modality == Modality.IMAGE and image_estimated_frames_iter:
try:
estimated_frames = next(image_estimated_frames_iter)
# Use the pre-calculated scaling factor and max frames
frame_count_limit = max(
1, int(estimated_frames * image_scaling_factor)
)
# Ensure we don't exceed the absolute max (redundant if scaling_factor handles it)
# frame_count_limit = min(frame_count_limit, max_image_frames)
except StopIteration:
raise ValueError(
"Mismatch between image tokens and estimated frame counts."
)

fn = partial(
BaseMultimodalProcessor._load_single_item,
data,
modality,
frame_count_limit,
audio_sample_rate,
discard_alpha_channel,
)
futures.append(loop.run_in_executor(self.io_executor, fn))
task_info.append((modality, data, frame_count_limit))

for modality, iterator in data_iterators.items():
try:
next(iterator)
logger.warning(
f"Warning: More {modality.name.lower()} data items provided than corresponding tokens found in the prompt."
)
except StopIteration:
pass
except Exception:
pass

return futures, task_info

def submit_data_loading_tasks(
self,
text_parts: List[str],
@@ -419,6 +497,118 @@ def submit_data_loading_tasks(

return futures, task_info

async def load_mm_data_async(
self,
prompt: str,
multimodal_tokens: MultimodalSpecialTokens,
image_data: Optional[list] = None,
video_data: Optional[list] = None,
audio_data: Optional[list] = None,
return_text: Optional[bool] = True,
discard_alpha_channel: bool = True,
audio_sample_rate: Optional[int] = None,
) -> BaseMultiModalProcessorOutput:
Comment on lines +500 to +510 (Contributor, medium):
There is significant code duplication between this new load_mm_data_async method and the existing load_mm_data method. The logic for processing the loaded data (from line 532 onwards) is nearly identical, with the main difference being await next(futures_iter) versus next(futures_iter).result().

To improve maintainability and reduce redundancy, consider refactoring the result-processing logic into a separate, private helper method. This helper could take the loaded data as an argument.

For example:

def _process_loaded_mm_data(self, text_parts, multimodal_tokens_pattern, task_info_iter, loaded_data):
    # ... common result processing logic ...

async def load_mm_data_async(self, ...):
    # ... submission logic ...
    results = await asyncio.gather(*futures)
    return self._process_loaded_mm_data(text_parts, multimodal_tokens, task_info_iter, results)

def load_mm_data(self, ...):
    # ... submission logic ...
    results = [f.result() for f in futures]
    return self._process_loaded_mm_data(text_parts, multimodal_tokens, task_info_iter, results)

This would make the code cleaner and easier to maintain, especially before extending this pattern to other models.
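A fuller sketch of the suggested helper, folding in the per-modality handling from the hunk below; the parameter list and the omission of the surrounding try/except are assumptions for illustration, not part of this PR:

def _process_loaded_mm_data(
    self, text_parts, multimodal_tokens, multimodal_tokens_pattern, task_info, results
):
    # Hypothetical shared helper; mirrors the branches in load_mm_data_async below.
    images, videos, audios = [], [], []
    new_text_parts = []
    task_info_iter, results_iter = iter(task_info), iter(results)
    for text_part in text_parts:
        if not multimodal_tokens_pattern.match(text_part):
            new_text_parts.append(text_part)  # plain text passes through unchanged
            continue
        modality, raw_data, _frame_limit = next(task_info_iter)
        result = next(results_iter)
        # Precomputed items arrive as dicts and keep their already-expanded tokens.
        is_precomputed = isinstance(raw_data, dict)
        if modality == Modality.IMAGE:
            mm_tokens = text_part if is_precomputed else multimodal_tokens.image_token
            frames = result if isinstance(result, list) else [result]
            if frames:
                images += frames
                new_text_parts.append(mm_tokens * len(frames))
        elif modality == Modality.VIDEO:
            mm_tokens = text_part if is_precomputed else multimodal_tokens.video_token
            videos.append(result)
            new_text_parts.append(mm_tokens)
        elif modality == Modality.AUDIO:
            mm_tokens = text_part if is_precomputed else multimodal_tokens.audio_token
            audios.append(result)
            new_text_parts.append(mm_tokens)
    return BaseMultiModalProcessorOutput(
        images=images,
        audios=audios,
        videos=videos,
        input_text="".join(new_text_parts),
    )

With this in place, load_mm_data_async would await asyncio.gather(*futures) and load_mm_data would collect f.result() for each future, with both passing the collected results to the helper.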

"""
Each frame of video/image will be replaced by a single image token

Args:
multimodal_tokens (list[str]): list of special tokens, each denoting one multimodal data item,
    e.g. an image token or an audio token
discard_alpha_channel: if True, discards the alpha channel in the returned images

"""
multimodal_tokens_pattern = multimodal_tokens.get_combined_regex()

if isinstance(prompt, list) and return_text:
assert len(prompt) and isinstance(prompt[0], int)
prompt = self._processor.tokenizer.decode(prompt)
else:
prompt = prompt
Comment on lines +525 to +526 (Contributor, medium):

This else block is redundant as it assigns the prompt variable to itself. It can be removed to improve code clarity.
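
A minimal sketch of the simplified form the comment suggests, keeping only the decode branch (no behavior change intended):

if isinstance(prompt, list) and return_text:
    # token IDs were passed in; decode them back to text
    assert len(prompt) and isinstance(prompt[0], int)
    prompt = self._processor.tokenizer.decode(prompt)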


assert isinstance(prompt, str)
# split text into list of normal text and special tokens
text_parts = re.split(multimodal_tokens_pattern, prompt)

# collect all data
data_iterators = {}
if multimodal_tokens.image_token and image_data:
data_iterators[Modality.IMAGE] = iter(image_data)
if multimodal_tokens.video_token and video_data:
data_iterators[Modality.VIDEO] = iter(video_data)
if multimodal_tokens.audio_token and audio_data:
data_iterators[Modality.AUDIO] = iter(audio_data)

# futures: the futures of loaded data
# task_info: modality, raw_data, and other metadata of each data
futures, task_info = self.submit_data_loading_tasks_async(
text_parts=text_parts,
multimodal_tokens=multimodal_tokens,
data_iterators=data_iterators,
discard_alpha_channel=discard_alpha_channel,
audio_sample_rate=audio_sample_rate,
)
task_info_iter = iter(task_info)
futures_iter = iter(futures)

# Process results
images, videos, audios = [], [], []
new_text_parts = []
for text_part in text_parts:
try:
if multimodal_tokens_pattern.match(text_part):
modality, raw_data, frame_limit = next(task_info_iter)
is_precomputed = isinstance(raw_data, dict)
result = await next(futures_iter)

if modality == Modality.IMAGE:
# If data is already processed it will be a
# dictionary (precomputed). In this case we want to keep the
# expanded tokens in text_part. Otherwise, we will
# call the processor code, so keep only a single image
# token.
mm_tokens = (
text_part
if is_precomputed
else multimodal_tokens.image_token
)
frames = [result] if not isinstance(result, list) else result
if frames:
# only for minicpmv
images += frames
new_text_parts += mm_tokens * len(frames)
elif modality == Modality.VIDEO:
# load as video
mm_tokens = (
text_part
if is_precomputed
else multimodal_tokens.video_token
)
videos += [result]
new_text_parts += mm_tokens
elif modality == Modality.AUDIO:
# audio
mm_tokens = (
text_part
if is_precomputed
else multimodal_tokens.audio_token
)
audios += [result]
new_text_parts += mm_tokens
else:
# normal text
new_text_parts += [text_part]

except Exception as e:
raise RuntimeError(
f"An exception occurred while loading multimodal data: {e}"
)
return BaseMultiModalProcessorOutput(
images=images,
audios=audios,
videos=videos,
input_text="".join(new_text_parts),
)

def load_mm_data(
self,
prompt: str,
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/multimodal/processors/dots_vlm.py
@@ -10,7 +10,7 @@
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
from sglang.srt.multimodal.processors.qwen_vl import resize_image_async
from sglang.srt.multimodal.processors.qwen_vl_video_worker import resize_image_async


class DotsVLMImageProcessor(BaseMultimodalProcessor):
6 changes: 2 additions & 4 deletions python/sglang/srt/multimodal/processors/points_v15_chat.py
@@ -6,10 +6,8 @@
from PIL import Image

from sglang.srt.models.points_v15_chat import POINTSV15ChatModel
from sglang.srt.multimodal.processors.qwen_vl import (
QwenVLImageProcessor,
resize_image_async,
)
from sglang.srt.multimodal.processors.qwen_vl import QwenVLImageProcessor
from sglang.srt.multimodal.processors.qwen_vl_video_worker import resize_image_async


class POINTSV15ChatProcessor(QwenVLImageProcessor):