diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py
index 4da24e522c98..02d1054b48e4 100644
--- a/python/sglang/srt/disaggregation/encode_server.py
+++ b/python/sglang/srt/disaggregation/encode_server.py
@@ -435,8 +435,13 @@ def _load_single_item(
        return data
    try:
        if modality == Modality.IMAGE:
-            img, _ = load_image(data)
-            if discard_alpha_channel and img.mode != "RGB":
+            img, _ = load_image(data, False)
+            if (
+                discard_alpha_channel
+                and not isinstance(img, torch.Tensor)
+                and img.mode != "RGB"
+            ):
+                # Needed only when `img` is a PIL image
                img = img.convert("RGB")
            return img
        elif modality == Modality.VIDEO:
diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py
index 859c225be13d..32847bf74af8 100644
--- a/python/sglang/srt/multimodal/processors/base_processor.py
+++ b/python/sglang/srt/multimodal/processors/base_processor.py
@@ -173,6 +173,7 @@ def get_combined_regex(self) -> re.Pattern:
 class BaseMultimodalProcessor(ABC):
     models = []
+    gpu_image_decode = True  # Enable GPU decoding by default

     def __init__(
         self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
     ):
@@ -468,8 +469,9 @@ def get_estimated_frames_list(self, image_data):

         return estimated_frames_list

-    @staticmethod
+    @classmethod
     def _load_single_item(
+        cls,
         data,
         modality: Modality,
         frame_count_limit=None,
@@ -481,7 +483,8 @@ def _load_single_item(
         """
         If data is processor_output or precomputed embedding, return directly.
-        Static method that can be pickled for multiprocessing"""
+        Class method that can be pickled for multiprocessing
+        """
         if isinstance(data, dict):
             data_format = data.get("format")
             if data_format in (
@@ -493,8 +496,13 @@ def _load_single_item(
                 return data
         try:
             if modality == Modality.IMAGE:
-                img, _ = load_image(data)
-                if discard_alpha_channel and img.mode != "RGB":
+                img, _ = load_image(data, cls.gpu_image_decode)
+                if (
+                    discard_alpha_channel
+                    and not isinstance(img, torch.Tensor)
+                    and img.mode != "RGB"
+                ):
+                    # Needed only when `img` is a PIL image
                     img = img.convert("RGB")
                 return img
             elif modality == Modality.VIDEO:
@@ -535,7 +543,7 @@ def _submit_mm_data_loading_tasks_simple(
                 type(data),
             )
             future = self.io_executor.submit(
-                BaseMultimodalProcessor._load_single_item,
+                self.__class__._load_single_item,
                 data,
                 modality,
                 None,  # frame_count_limit: no consider for fast path
@@ -595,7 +603,7 @@ def submit_data_loading_tasks(

                 futures.append(
                     self.io_executor.submit(
-                        BaseMultimodalProcessor._load_single_item,
+                        self.__class__._load_single_item,
                         data,
                         modality,
                         frame_count_limit,
diff --git a/python/sglang/srt/multimodal/processors/internvl.py b/python/sglang/srt/multimodal/processors/internvl.py
index 955198730e92..90624f86f122 100644
--- a/python/sglang/srt/multimodal/processors/internvl.py
+++ b/python/sglang/srt/multimodal/processors/internvl.py
@@ -27,6 +27,7 @@

 class InternVLProcessor(BaseMultimodalProcessor):
     models = [InternVLChatModel, InternS1ForConditionalGeneration]
+    gpu_image_decode = False  # InternVL HF processor does not support tensor inputs

     IMAGENET_MEAN = [0.485, 0.456, 0.406]
     IMAGENET_STD = [0.229, 0.224, 0.225]
diff --git a/python/sglang/srt/multimodal/processors/kimi_k25.py b/python/sglang/srt/multimodal/processors/kimi_k25.py
index d8bb9ceb3a8b..cef3e6933499 100644
--- a/python/sglang/srt/multimodal/processors/kimi_k25.py
+++ b/python/sglang/srt/multimodal/processors/kimi_k25.py
@@ -16,6 +16,7 @@
 # Compatible with KimiVLForConditionalGeneration
 class KimiK2_5VLImageProcessor(SGLangBaseProcessor):
     models = [KimiK25ForConditionalGeneration]
+    gpu_image_decode = False  # KimiK2.5VL HF processor does not support tensor inputs

     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
diff --git a/python/sglang/srt/multimodal/processors/kimi_vl.py b/python/sglang/srt/multimodal/processors/kimi_vl.py
index cd7cfe2fd3ae..b466f1b40994 100644
--- a/python/sglang/srt/multimodal/processors/kimi_vl.py
+++ b/python/sglang/srt/multimodal/processors/kimi_vl.py
@@ -13,6 +13,7 @@
 # Compatible with KimiVLForConditionalGeneration
 class KimiVLImageProcessor(SGLangBaseProcessor):
     models = [KimiVLForConditionalGeneration]
+    gpu_image_decode = False  # KimiVL HF processor does not support tensor inputs

     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
diff --git a/python/sglang/srt/multimodal/processors/llava.py b/python/sglang/srt/multimodal/processors/llava.py
index 8111f03afbad..55a9fa686a18 100644
--- a/python/sglang/srt/multimodal/processors/llava.py
+++ b/python/sglang/srt/multimodal/processors/llava.py
@@ -33,6 +33,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         LlavaQwenForCausalLM,
         LlavaMistralForCausalLM,
     ]
+    gpu_image_decode = False  # Llava processes loaded image as PIL image explicitly

     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
@@ -49,7 +50,7 @@ def _process_single_image_task(

         try:
             url = image_data.url if isinstance(image_data, ImageData) else image_data
-            image, image_size = load_image(url)
+            image, image_size = load_image(url, False)
             if image_size is not None:
                 # It is a video with multiple images
                 image_hash = hash(url)
diff --git a/python/sglang/srt/multimodal/processors/minicpm.py b/python/sglang/srt/multimodal/processors/minicpm.py
index 2a375c9dabb4..bad2cbe3d027 100644
--- a/python/sglang/srt/multimodal/processors/minicpm.py
+++ b/python/sglang/srt/multimodal/processors/minicpm.py
@@ -19,6 +19,7 @@
 class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     models = [MiniCPMV, MiniCPMO]
     support_dynamic_frame_expansion = True
+    gpu_image_decode = False  # MiniCPM HF processor does not support tensor inputs

     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
diff --git a/python/sglang/srt/multimodal/processors/nano_nemotron_vl.py b/python/sglang/srt/multimodal/processors/nano_nemotron_vl.py
index 83d72441f861..98986090f979 100644
--- a/python/sglang/srt/multimodal/processors/nano_nemotron_vl.py
+++ b/python/sglang/srt/multimodal/processors/nano_nemotron_vl.py
@@ -35,6 +35,9 @@

 class NanoNemotronVLImageProcessor(BaseMultimodalProcessor):
     models = [NemotronH_Nano_VL_V2]
+    gpu_image_decode = (
+        False  # NanoNemotronVL processes loaded image as PIL image explicitly
+    )

     def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py
index 47b1513e8fd6..ed40fc01785f 100644
--- a/python/sglang/srt/multimodal/processors/pixtral.py
+++ b/python/sglang/srt/multimodal/processors/pixtral.py
@@ -19,6 +19,7 @@

 class PixtralProcessor(BaseMultimodalProcessor):
     models = [PixtralVisionModel, PixtralForConditionalGeneration]
+    gpu_image_decode = False  # Pixtral processes loaded image as PIL image explicitly

     PAD_TOKEN = ""
     DEFAULT_IMAGE_TOKEN = "[IMG]"
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index 57d85eb2f136..d58c33affdbd 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -86,6 +86,7 @@
 from torch import nn
 from torch.library import Library
 from torch.utils._contextlib import _DecoratorContextManager
+from torchvision.io import decode_jpeg
 from typing_extensions import Literal

 from sglang.srt.environ import envs
@@ -763,64 +764,111 @@ class ImageData:
     max_dynamic_patch: Optional[int] = None


+image_extension_names = (".png", ".jpg", ".jpeg", ".webp", ".gif")
+
+
+def is_jpeg_with_cuda(image_bytes: bytes = b"", gpu_image_decode: bool = True) -> bool:
+    """
+    Check three conditions:
+    1. whether CUDA is available.
+    2. whether input is recognized as JPEG.
+    3. whether GPU image decode is enabled (some models such as CPM forcibly disable this).
+    """
+    if not is_cuda() or not gpu_image_decode:
+        return False
+    if image_bytes != b"":
+        return image_bytes.startswith(b"\xff\xd8") and image_bytes.endswith(b"\xff\xd9")
+    return False
+
+
+def _load_image(
+    image_bytes: bytes = b"",
+    image_file: str = "",
+    gpu_image_decode: bool = True,
+) -> Union[torch.Tensor, Image.Image]:
+    """
+    Try to decode JPEG with nvJPEG on GPU and return a torch device tensor,
+    otherwise fallback to decode with PIL on CPU and return a PIL Image.
+    Keep the fallback path since nvJPEG may fail on some JPEG images that are not strictly compliant with the standard, while PIL is more tolerant.
+    """
+    if image_file != "":
+        image_bytes = get_image_bytes(image_file)
+    if is_jpeg_with_cuda(image_bytes, gpu_image_decode):
+        try:
+            encoded_image = torch.frombuffer(image_bytes, dtype=torch.uint8)
+            image_tensor = decode_jpeg(encoded_image, device="cuda")
+            return image_tensor
+        except Exception as e:
+            logger.warning(
+                f"Failed to decode JPEG on GPU, falling back to CPU. Error: {e}"
+            )
+    return Image.open(BytesIO(image_bytes))
+
+
 def load_image(
     image_file: Union[Image.Image, str, ImageData, bytes],
-) -> tuple[Image.Image, tuple[int, int]]:
+    gpu_image_decode: bool = True,
+) -> tuple[Union[torch.Tensor, Image.Image], Optional[tuple[int, int]]]:
+    """
+    Load image from multiple input formats, including:
+    ImageData, PIL Image, bytes, URL, file path, or base64 string.
+    """
     if isinstance(image_file, ImageData):
         image_file = image_file.url

-    image = image_size = None
+    image = None
+    image_size: Optional[tuple[int, int]] = None
     if isinstance(image_file, Image.Image):
         image = image_file
         image_size = (image.width, image.height)
     elif isinstance(image_file, bytes):
-        image = Image.open(BytesIO(image_file))
-    elif image_file.startswith("http://") or image_file.startswith("https://"):
-        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
-        response = requests.get(image_file, stream=True, timeout=timeout)
-        try:
-            response.raise_for_status()
-            image = Image.open(response.raw)
-            image.load()  # Force loading to avoid issues after closing the stream
-        finally:
-            response.close()
-    elif image_file.startswith("file://"):
-        image_file = unquote(urlparse(image_file).path)
-        image = Image.open(image_file)
-    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
-        image = Image.open(image_file)
-    elif image_file.startswith("data:"):
-        image_file = image_file.split(",")[1]
-        image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
-    elif isinstance(image_file, str):
-        image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
+        image = _load_image(image_bytes=image_file, gpu_image_decode=gpu_image_decode)
+    elif isinstance(image_file, str) and image_file.startswith(("http://", "https://")):
+        image = _load_image(image_file=image_file, gpu_image_decode=gpu_image_decode)
+    elif isinstance(image_file, str) and image_file.startswith("file://"):
+        image = _load_image(
+            image_file=unquote(urlparse(image_file).path),
+            gpu_image_decode=gpu_image_decode,
+        )
+    elif isinstance(image_file, str) and image_file.lower().endswith(
+        image_extension_names
+    ):
+        image = _load_image(image_file=image_file, gpu_image_decode=gpu_image_decode)
+    elif isinstance(image_file, str) and image_file.startswith("data:"):
+        image = _load_image(image_file=image_file, gpu_image_decode=gpu_image_decode)
+    elif isinstance(
+        image_file, str
+    ):  # Other formats, try to decode as base64 by default
+        image = _load_image(image_file=image_file, gpu_image_decode=gpu_image_decode)
     else:
         raise ValueError(f"Invalid image: {image_file}")
-
     return image, image_size


-def get_image_bytes(image_file: Union[str, bytes]):
+def get_image_bytes(image_file: Union[str, bytes]) -> bytes:
+    """Normalize various image inputs into raw bytes."""
     if isinstance(image_file, bytes):
         return image_file
-    elif image_file.startswith("http://") or image_file.startswith("https://"):
+    if image_file.startswith(("http://", "https://")):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
         response = requests.get(image_file, timeout=timeout)
-        return response.content
-    elif image_file.startswith("file://"):
-        image_file = unquote(urlparse(image_file).path)
-        with open(image_file, "rb") as f:
-            return f.read()
-    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
+        try:
+            response.raise_for_status()
+            result = response.content
+        finally:
+            response.close()
+        return result
+    if image_file.startswith("file://"):
+        # file:// URLs must be converted to a filesystem path before open()
+        image_file = unquote(urlparse(image_file).path)
+    if image_file.startswith("/") or image_file.lower().endswith(image_extension_names):
         with open(image_file, "rb") as f:
             return f.read()
-    elif image_file.startswith("data:"):
-        image_file = image_file.split(",")[1]
+    if isinstance(image_file, str) and image_file.startswith("data:"):
+        _, encoded = image_file.split(",", 1)
+        return pybase64.b64decode(encoded, validate=True)
+    if isinstance(image_file, str):
         return pybase64.b64decode(image_file, validate=True)
-    elif isinstance(image_file, str):
-        return pybase64.b64decode(image_file, validate=True)
-    else:
-        raise NotImplementedError(f"Invalid image: {image_file}")
+    raise NotImplementedError(f"Invalid image: {image_file}")


 def _normalize_video_input(