Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/sglang/srt/disaggregation/encode_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,13 @@ def _load_single_item(
return data
try:
if modality == Modality.IMAGE:
img, _ = load_image(data)
if discard_alpha_channel and img.mode != "RGB":
img, _ = load_image(data, False)
if (
discard_alpha_channel
and not isinstance(img, torch.Tensor)
and img.mode != "RGB"
):
# Needed only when `img` is a PIL image
img = img.convert("RGB")
return img
elif modality == Modality.VIDEO:
Expand Down
20 changes: 14 additions & 6 deletions python/sglang/srt/multimodal/processors/base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def get_combined_regex(self) -> re.Pattern:

class BaseMultimodalProcessor(ABC):
models = []
gpu_image_decode = True # Enable GPU decoding by default

def __init__(
self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
Expand Down Expand Up @@ -468,8 +469,9 @@ def get_estimated_frames_list(self, image_data):

return estimated_frames_list

@staticmethod
@classmethod
def _load_single_item(
cls,
data,
modality: Modality,
frame_count_limit=None,
Expand All @@ -481,7 +483,8 @@ def _load_single_item(

If data is processor_output or precomputed embedding, return directly.

Static method that can be pickled for multiprocessing"""
Class method that can be pickled for multiprocessing
"""
if isinstance(data, dict):
data_format = data.get("format")
if data_format in (
Expand All @@ -493,8 +496,13 @@ def _load_single_item(
return data
try:
if modality == Modality.IMAGE:
img, _ = load_image(data)
if discard_alpha_channel and img.mode != "RGB":
img, _ = load_image(data, cls.gpu_image_decode)
if (
discard_alpha_channel
and not isinstance(img, torch.Tensor)
and img.mode != "RGB"
):
# Needed only when `img` is a PIL image
img = img.convert("RGB")
return img
elif modality == Modality.VIDEO:
Expand Down Expand Up @@ -535,7 +543,7 @@ def _submit_mm_data_loading_tasks_simple(
type(data),
)
future = self.io_executor.submit(
BaseMultimodalProcessor._load_single_item,
self.__class__._load_single_item,
data,
modality,
None, # frame_count_limit: no consider for fast path
Expand Down Expand Up @@ -595,7 +603,7 @@ def submit_data_loading_tasks(

futures.append(
self.io_executor.submit(
BaseMultimodalProcessor._load_single_item,
self.__class__._load_single_item,
data,
modality,
frame_count_limit,
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/multimodal/processors/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

class InternVLProcessor(BaseMultimodalProcessor):
models = [InternVLChatModel, InternS1ForConditionalGeneration]
gpu_image_decode = False # InternVL HF processor does not support tensor inputs

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/multimodal/processors/kimi_k25.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# Compatible with KimiVLForConditionalGeneration
class KimiK2_5VLImageProcessor(SGLangBaseProcessor):
models = [KimiK25ForConditionalGeneration]
gpu_image_decode = False # KimiK2.5VL HF processor does not support tensor inputs

def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/multimodal/processors/kimi_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# Compatible with KimiVLForConditionalGeneration
class KimiVLImageProcessor(SGLangBaseProcessor):
models = [KimiVLForConditionalGeneration]
gpu_image_decode = False # KimiVL HF processor does not support tensor inputs

def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/multimodal/processors/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
LlavaQwenForCausalLM,
LlavaMistralForCausalLM,
]
gpu_image_decode = False # Llava processes loaded image as PIL image explicitly

def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
Expand All @@ -49,7 +50,7 @@ def _process_single_image_task(

try:
url = image_data.url if isinstance(image_data, ImageData) else image_data
image, image_size = load_image(url)
image, image_size = load_image(url, False)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(url)
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/multimodal/processors/minicpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
models = [MiniCPMV, MiniCPMO]
support_dynamic_frame_expansion = True
gpu_image_decode = False # MiniCPM HF processor does not support tensor inputs

def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/multimodal/processors/nano_nemotron_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@

class NanoNemotronVLImageProcessor(BaseMultimodalProcessor):
models = [NemotronH_Nano_VL_V2]
gpu_image_decode = (
False # NanoNemotronVL processes loaded image as PIL image explicitly
)

def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/multimodal/processors/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

class PixtralProcessor(BaseMultimodalProcessor):
models = [PixtralVisionModel, PixtralForConditionalGeneration]
gpu_image_decode = False # Pixtral processes loaded image as PIL image explicitly

PAD_TOKEN = "<pad>"
DEFAULT_IMAGE_TOKEN = "[IMG]"
Expand Down
120 changes: 83 additions & 37 deletions python/sglang/srt/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
from torch import nn
from torch.library import Library
from torch.utils._contextlib import _DecoratorContextManager
from torchvision.io import decode_jpeg
from typing_extensions import Literal

from sglang.srt.environ import envs
Expand Down Expand Up @@ -763,64 +764,109 @@ class ImageData:
max_dynamic_patch: Optional[int] = None


image_extension_names = (".png", ".jpg", ".jpeg", ".webp", ".gif")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need a mm_utils.py in this folder after this PR
cc @yhyang201

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it



def is_jpeg_with_cuda(image_bytes: bytes = b"", gpu_image_decode: bool = True) -> bool:
    """
    Decide whether the payload should be decoded as JPEG on the GPU.

    All of the following must hold:
    1. CUDA is available.
    2. GPU image decode is enabled (some models such as CPM forcibly disable this).
    3. The bytes carry both the JPEG SOI (0xFFD8) and EOI (0xFFD9) markers.
    """
    if is_cuda() and gpu_image_decode and image_bytes:
        # Cheap signature sniff: JPEG streams begin with SOI and end with EOI.
        return image_bytes.startswith(b"\xff\xd8") and image_bytes.endswith(b"\xff\xd9")
    return False


def _load_image(
    image_bytes: bytes = b"",
    image_file: str = "",
    gpu_image_decode: bool = True,
) -> Union[torch.Tensor, Image.Image]:
    """
    Decode an image, preferring nvJPEG on the GPU for JPEG payloads.

    Returns a CUDA ``torch.Tensor`` when GPU decoding succeeds; otherwise a
    PIL ``Image`` decoded on the CPU. The PIL fallback is kept on purpose:
    nvJPEG may reject JPEGs that are not strictly standard-compliant, while
    PIL is more tolerant.
    """
    if image_file:
        # A path/URL/base64 reference was given; normalize it to raw bytes.
        image_bytes = get_image_bytes(image_file)
    if is_jpeg_with_cuda(image_bytes, gpu_image_decode):
        try:
            buffer = torch.frombuffer(image_bytes, dtype=torch.uint8)
            return decode_jpeg(buffer, device="cuda")
        except Exception as e:
            logger.warning(
                f"Failed to decode JPEG on GPU, falling back to CPU. Error: {e}"
            )
    return Image.open(BytesIO(image_bytes))


def load_image(
image_file: Union[Image.Image, str, ImageData, bytes],
) -> tuple[Image.Image, tuple[int, int]]:
gpu_image_decode: bool = True,
) -> tuple[Union[torch.Tensor, Image.Image], Optional[tuple[int, int]]]:
"""
Load image from multiple input formats, including:
ImageData, PIL Image, bytes, URL, file path, or base64 string.
"""
if isinstance(image_file, ImageData):
image_file = image_file.url

image = image_size = None
image = None
image_size: Optional[tuple[int, int]] = None
if isinstance(image_file, Image.Image):
image = image_file
image_size = (image.width, image.height)
elif isinstance(image_file, bytes):
image = Image.open(BytesIO(image_file))
elif image_file.startswith("http://") or image_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
response = requests.get(image_file, stream=True, timeout=timeout)
try:
response.raise_for_status()
image = Image.open(response.raw)
image.load() # Force loading to avoid issues after closing the stream
finally:
response.close()
elif image_file.startswith("file://"):
image_file = unquote(urlparse(image_file).path)
image = Image.open(image_file)
elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
image = Image.open(image_file)
elif image_file.startswith("data:"):
image_file = image_file.split(",")[1]
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
elif isinstance(image_file, str):
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
image = _load_image(image_bytes=image_file, gpu_image_decode=gpu_image_decode)
elif isinstance(image_file, str) and image_file.startswith(("http://", "https://")):
image = _load_image(image_file=image_file, gpu_image_decode=gpu_image_decode)
elif isinstance(image_file, str) and image_file.startswith("file://"):
image = _load_image(
image_file=unquote(urlparse(image_file).path),
gpu_image_decode=gpu_image_decode,
)
elif isinstance(image_file, str) and image_file.lower().endswith(
image_extension_names
):
image = _load_image(image_file=image_file, gpu_image_decode=gpu_image_decode)
elif isinstance(image_file, str) and image_file.startswith("data:"):
image = _load_image(image_file=image_file, gpu_image_decode=gpu_image_decode)
elif isinstance(
image_file, str
): # Other formats, try to decode as base64 by default
image = _load_image(image_file=image_file, gpu_image_decode=gpu_image_decode)
else:
raise ValueError(f"Invalid image: {image_file}")

return image, image_size


def get_image_bytes(image_file: Union[str, bytes]):
def get_image_bytes(image_file: Union[str, bytes]) -> bytes:
"""Normalize various image inputs into raw bytes."""
if isinstance(image_file, bytes):
return image_file
elif image_file.startswith("http://") or image_file.startswith("https://"):
if image_file.startswith(("http://", "https://")):
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
response = requests.get(image_file, timeout=timeout)
return response.content
elif image_file.startswith("file://"):
image_file = unquote(urlparse(image_file).path)
with open(image_file, "rb") as f:
return f.read()
elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
try:
response.raise_for_status()
result = response.content
finally:
response.close()
return result
if image_file.startswith(("file://", "/")):
with open(image_file, "rb") as f:
return f.read()
elif image_file.startswith("data:"):
image_file = image_file.split(",")[1]
if isinstance(image_file, str) and image_file.startswith("data:"):
_, encoded = image_file.split(",", 1)
return pybase64.b64decode(encoded, validate=True)
if isinstance(image_file, str):
return pybase64.b64decode(image_file, validate=True)
elif isinstance(image_file, str):
return pybase64.b64decode(image_file, validate=True)
else:
raise NotImplementedError(f"Invalid image: {image_file}")
raise NotImplementedError(f"Invalid image: {image_file}")


def _normalize_video_input(
Expand Down
Loading