Refactor vl video path to full async mode #12517
Changes from all commits: c4e4a77, 7afdc47, 3b55810, 39b07c9, b48b0c9
@@ -1,10 +1,12 @@
 import asyncio
 import concurrent
 import concurrent.futures
 import dataclasses
 import multiprocessing as mp
 import os
 import re
 from abc import ABC, abstractmethod
 from functools import partial
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

 import numpy as np
@@ -13,14 +15,8 @@
 from transformers import BaseImageProcessorFast

 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-from sglang.srt.utils import (
-    get_bool_env_var,
-    is_npu,
-    load_audio,
-    load_image,
-    load_video,
-    logger,
-)
+from sglang.srt.multimodal.processors.video_utils import make_video_input
+from sglang.srt.utils import get_bool_env_var, is_npu, load_audio, load_image, logger
 from sglang.srt.utils.cuda_ipc_transport_utils import (
     MM_FEATURE_CACHE_SIZE,
     CudaIpcTensorTransportProxy,
@@ -182,6 +178,13 @@ def __init__(
             mp_context=mp.get_context("fork"),
             max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
         )
+        mp_ctx_name = os.environ.get("SGLANG_MP_CTX", "spawn")
+        self.video_executor = concurrent.futures.ProcessPoolExecutor(
+            mp_context=mp.get_context(mp_ctx_name),
+            max_workers=int(os.environ.get("SGLANG_VIDEO_WORKERS", 8)),
+            # initializer=_video_worker_init,
+            # initargs=(int(os.environ.get("SGLANG_VIDEO_GPU_ID", "0")),),
+        )

         # Mapping from attribute names to modality types
         self.ATTR_NAME_TO_MODALITY = {
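Worth noting: the new video pool defaults to the "spawn" start method, while the existing CPU pool forks. A minimal sketch of the pattern, assuming workers may later initialize CUDA; the decoder below is a hypothetical stand-in:

    import concurrent.futures
    import multiprocessing as mp
    import os

    def decode_video_stub(path: str) -> int:
        # Hypothetical stand-in for a real frame decoder.
        return len(path)

    if __name__ == "__main__":
        # "spawn" starts each worker with a fresh interpreter, avoiding the
        # undefined behavior of forking a process that has initialized CUDA.
        ctx = mp.get_context(os.environ.get("SGLANG_MP_CTX", "spawn"))
        with concurrent.futures.ProcessPoolExecutor(mp_context=ctx, max_workers=2) as pool:
            print(list(pool.map(decode_video_stub, ["a.mp4", "b.mp4"])))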
@@ -335,19 +338,94 @@ def _load_single_item(
         if isinstance(data, dict):
             return data
         try:
+            if modality == Modality.VIDEO:
+                return make_video_input(
+                    data,
+                    frame_count_limit=frame_count_limit,
+                    request_timeout_env="REQUEST_TIMEOUT",
+                )
             if modality == Modality.IMAGE:
                 img, _ = load_image(data)
                 if discard_alpha_channel and img.mode != "RGB":
                     img = img.convert("RGB")
                 return img
-            elif modality == Modality.VIDEO:
-                return load_video(data, frame_count_limit)
             elif modality == Modality.AUDIO:
                 return load_audio(data, audio_sample_rate)
         except Exception as e:
             raise RuntimeError(f"Error while loading data {data}: {e}")

+    def submit_data_loading_tasks_async(
+        self,
+        text_parts: List[str],
+        multimodal_tokens: MultimodalSpecialTokens,
+        data_iterators: dict[Modality, Iterator[Any]],
+        discard_alpha_channel: bool = True,
+        image_estimated_frames_iter: Optional[iter] = None,
+        image_scaling_factor: float = 1.0,
+        max_image_frames: int = 30,
+        audio_sample_rate: Optional[int] = None,
+    ) -> Tuple[List, List]:
+        """
+        load multimodal data parallelly using iterators.
+        """
+        loop = asyncio.get_running_loop()
+        futures = []
+        task_info = []
+
+        for text_part in text_parts:
+            modality = multimodal_tokens.get_modality_of_token(text_part)
+            if modality is not None:
+                data_iterator = data_iterators.get(modality)
+                if data_iterator is None:
+                    raise ValueError(f"No data iterator found for token: {text_part}")
+
+                try:
+                    data = next(data_iterator)
+                except StopIteration:
+                    raise ValueError(
+                        f"Mismatch: More '{text_part}' tokens found than corresponding data items provided."
+                    )
+
+                frame_count_limit = None
+                if modality == Modality.IMAGE and image_estimated_frames_iter:
+                    try:
+                        estimated_frames = next(image_estimated_frames_iter)
+                        # Use the pre-calculated scaling factor and max frames
+                        frame_count_limit = max(
+                            1, int(estimated_frames * image_scaling_factor)
+                        )
+                        # Ensure we don't exceed the absolute max (redundant if scaling_factor handles it)
+                        # frame_count_limit = min(frame_count_limit, max_image_frames)
+                    except StopIteration:
+                        raise ValueError(
+                            "Mismatch between image tokens and estimated frame counts."
+                        )
+
+                fn = partial(
+                    BaseMultimodalProcessor._load_single_item,
+                    data,
+                    modality,
+                    frame_count_limit,
+                    audio_sample_rate,
+                    discard_alpha_channel,
+                )
+                futures.append(loop.run_in_executor(self.io_executor, fn))
+                task_info.append((modality, data, frame_count_limit))
+
+        for modality, iterator in data_iterators.items():
+            try:
+                next(iterator)
+                logger.warning(
+                    f"Warning: More {modality.name.lower()} data items provided than corresponding tokens found in the prompt."
+                )
+            except StopIteration:
+                pass
+            except Exception:
+                pass
+
+        return futures, task_info
+
     def submit_data_loading_tasks(
         self,
         text_parts: List[str],
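The async submission path wraps each blocking loader in an asyncio future via run_in_executor, so callers await results instead of calling future.result(). A runnable sketch of the pattern; slow_load is a hypothetical stand-in for load_image/load_audio, and gather is used here only to show the futures are awaitable (the PR awaits them one by one while walking the prompt):

    import asyncio
    import concurrent.futures
    from functools import partial

    def slow_load(item: str) -> str:
        # Hypothetical stand-in for a blocking loader such as load_image.
        return item.upper()

    async def main() -> None:
        loop = asyncio.get_running_loop()
        with concurrent.futures.ThreadPoolExecutor() as io_executor:
            # run_in_executor returns asyncio futures, so the event loop stays
            # free while worker threads do the blocking I/O.
            futures = [
                loop.run_in_executor(io_executor, partial(slow_load, s))
                for s in ("img0", "img1")
            ]
            print(await asyncio.gather(*futures))  # ['IMG0', 'IMG1']

    if __name__ == "__main__":
        asyncio.run(main())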
@@ -419,6 +497,118 @@ def submit_data_loading_tasks(

         return futures, task_info

+    async def load_mm_data_async(
+        self,
+        prompt: str,
+        multimodal_tokens: MultimodalSpecialTokens,
+        image_data: Optional[list] = None,
+        video_data: Optional[list] = None,
+        audio_data: Optional[list] = None,
+        return_text: Optional[bool] = True,
+        discard_alpha_channel: bool = True,
+        audio_sample_rate: Optional[int] = None,
+    ) -> BaseMultiModalProcessorOutput:
Review comment on lines +500 to +510 (Contributor):

There is significant code duplication between this new load_mm_data_async and the existing load_mm_data. To improve maintainability and reduce redundancy, consider refactoring the result-processing logic into a separate, private helper method. This helper could take the loaded data as an argument. For example:

    def _process_loaded_mm_data(self, text_parts, multimodal_tokens_pattern, task_info_iter, loaded_data):
        # ... common result processing logic ...

    async def load_mm_data_async(self, ...):
        # ... submission logic ...
        results = await asyncio.gather(*futures)
        return self._process_loaded_mm_data(text_parts, multimodal_tokens, task_info_iter, results)

    def load_mm_data(self, ...):
        # ... submission logic ...
        results = [f.result() for f in futures]
        return self._process_loaded_mm_data(text_parts, multimodal_tokens, task_info_iter, results)

This would make the code cleaner and easier to maintain, especially before extending this pattern to other models.
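A fleshed-out version of that suggestion might look like the sketch below. Every name here is hypothetical, and the helper is reduced to a free function so it runs standalone; the real method would also need the precomputed-dict and per-modality token handling shown in the diff:

    import re
    from typing import Any, Iterator, List, Tuple

    def process_loaded_mm_data(
        text_parts: List[str],
        token_pattern: re.Pattern,
        task_info_iter: Iterator[Tuple[str, Any, Any]],
        results_iter: Iterator[Any],
    ) -> Tuple[List[Tuple[str, Any]], str]:
        # Shared tail of both paths: pair each multimodal token with its
        # loaded result, keep plain text as-is. The sync path would feed it
        # (f.result() for f in futures); the async path would feed it
        # iter(await asyncio.gather(*futures)).
        loaded, new_text_parts = [], []
        for part in text_parts:
            if token_pattern.match(part):
                modality, _raw, _limit = next(task_info_iter)
                loaded.append((modality, next(results_iter)))
            else:
                new_text_parts.append(part)
        return loaded, "".join(new_text_parts)

    parts = re.split(r"(<image>|<video>)", "see <image> and <video>")
    print(process_loaded_mm_data(
        parts,
        re.compile(r"<image>|<video>"),
        iter([("image", None, None), ("video", None, None)]),
        iter(["IMG0", "VID0"]),
    ))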
+        """
+        Each frame of video/image will be replaced by a single image token
+
+        Args:
+            multimodal_tokens (list[str]): list of special token which denoting a single multimodal data
+                e.g. image token or audio token
+            discard_alpha_channel: if True, discards the alpha channel in the returned images
+
+        """
+        multimodal_tokens_pattern = multimodal_tokens.get_combined_regex()
+
+        if isinstance(prompt, list) and return_text:
+            assert len(prompt) and isinstance(prompt[0], int)
+            prompt = self._processor.tokenizer.decode(prompt)
+        else:
+            prompt = prompt
+        assert isinstance(prompt, str)
+        # split text into list of normal text and special tokens
+        text_parts = re.split(multimodal_tokens_pattern, prompt)
+
+        # collect all data
+        data_iterators = {}
+        if multimodal_tokens.image_token and image_data:
+            data_iterators[Modality.IMAGE] = iter(image_data)
+        if multimodal_tokens.video_token and video_data:
+            data_iterators[Modality.VIDEO] = iter(video_data)
+        if multimodal_tokens.audio_token and audio_data:
+            data_iterators[Modality.AUDIO] = iter(audio_data)
+
+        # futures: the futures of loaded data
+        # task_info: modality, raw_data, and other metadata of each data
+        futures, task_info = self.submit_data_loading_tasks_async(
+            text_parts=text_parts,
+            multimodal_tokens=multimodal_tokens,
+            data_iterators=data_iterators,
+            discard_alpha_channel=discard_alpha_channel,
+            audio_sample_rate=audio_sample_rate,
+        )
+        task_info_iter = iter(task_info)
+        futures_iter = iter(futures)
+
+        # Process results
+        images, videos, audios = [], [], []
+        new_text_parts = []
+        for text_part in text_parts:
+            try:
+                if multimodal_tokens_pattern.match(text_part):
+                    modality, raw_data, frame_limit = next(task_info_iter)
+                    is_precomputed = isinstance(raw_data, dict)
+                    result = await next(futures_iter)
+
+                    if modality == Modality.IMAGE:
+                        # If data is already processed it will be a
+                        # dictionary(precomputed). In this case we want to keep the
+                        # expanded tokens in text_part. Otherwise, we will
+                        # call the processor code, so keep only a single image
+                        # token.
+                        mm_tokens = (
+                            text_part
+                            if is_precomputed
+                            else multimodal_tokens.image_token
+                        )
+                        frames = [result] if not isinstance(result, list) else result
+                        if frames:
+                            # only for minicpmv
+                            images += frames
+                            new_text_parts += mm_tokens * len(frames)
+                    elif modality == Modality.VIDEO:
+                        # load as video
+                        mm_tokens = (
+                            text_part
+                            if is_precomputed
+                            else multimodal_tokens.video_token
+                        )
+                        videos += [result]
+                        new_text_parts += mm_tokens
+                    elif modality == Modality.AUDIO:
+                        # audio
+                        mm_tokens = (
+                            text_part
+                            if is_precomputed
+                            else multimodal_tokens.audio_token
+                        )
+                        audios += [result]
+                        new_text_parts += mm_tokens
+                else:
+                    # normal text
+                    new_text_parts += [text_part]
+
+            except Exception as e:
+                raise RuntimeError(
+                    f"An exception occurred while loading multimodal data: {e}"
+                )
+        return BaseMultiModalProcessorOutput(
+            images=images,
+            audios=audios,
+            videos=videos,
+            input_text="".join(new_text_parts),
+        )
     def load_mm_data(
         self,
         prompt: str,
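For reference, the core pattern of load_mm_data_async — submit everything first, then await each future in prompt order while walking the split text — reduces to a short, self-contained example. The tokens and loader here are illustrative only:

    import asyncio
    import re

    async def fake_load(item: str) -> str:
        # Hypothetical stand-in for an executor-backed loader.
        await asyncio.sleep(0)
        return item.upper()

    async def main() -> None:
        pattern = re.compile(r"(<image>|<video>)")
        # The capturing group keeps the matched tokens in the split result.
        parts = re.split(pattern, "describe <image> then <video>")
        # Submit first, then await in prompt order -- the futures_iter pattern.
        futures = iter([asyncio.ensure_future(fake_load(d)) for d in ("img0", "vid0")])
        out = [await next(futures) if pattern.match(p) else p for p in parts]
        print("".join(out))  # -> "describe IMG0 then VID0"

    asyncio.run(main())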