Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions benchmark/mmmu/bench_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ def eval_mmmu(args):
try:
# check if the model belongs to the InternVL family
if "InternVL" in args.model_path:
from internvl_utils import load_image
from transformers import AutoTokenizer

from sglang.srt.multimodal.internvl_utils import image_to_pixel_values

tokenizer = AutoTokenizer.from_pretrained(args.model_path)
model = AutoModel.from_pretrained(
args.model_path,
Expand Down Expand Up @@ -80,7 +81,11 @@ def eval_mmmu(args):
assert image is not None

if "InternVL" in args.model_path:
pixel_values = load_image(sample["image_path"]).to(torch.bfloat16).cuda()
image = PIL.Image.open(sample["image_path"]).convert("RGB")
pixel_values = image_to_pixel_values(
image, input_size=448, max_num=12, use_thumbnail=True
)
pixel_values = pixel_values.to(device="cuda", dtype=torch.bfloat16)
contents = ""
if prefix:
contents += prefix
Expand Down
2 changes: 2 additions & 0 deletions docs/supported_models/multimodal_language_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ in the GitHub search bar.
| **DotsVLM** (General/OCR) | `rednote-hilab/dots.vlm1.inst` | RedNote's vision-language model built on a 1.2B vision encoder and DeepSeek V3 LLM, featuring NaViT vision encoder trained from scratch with dynamic resolution support and enhanced OCR capabilities through structured image data training. | |
| **DotsVLM-OCR** | `rednote-hilab/dots.ocr` | Specialized OCR variant of DotsVLM optimized for optical character recognition tasks with enhanced text extraction and document understanding capabilities. | Don't use `--trust-remote-code` |
| **NVILA** (8B, 15B, Lite-2B, Lite-8B, Lite-15B) | `Efficient-Large-Model/NVILA-8B` | `chatml` | NVILA explores the full stack efficiency of multi-modal design, achieving cheaper training, faster deployment and better performance. |
| **NVIDIA Nemotron Nano 2.0 VL** | `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16` | NVIDIA Nemotron Nano v2 VL enables multi-image reasoning and video understanding, along with strong document intelligence, visual Q&A and summarization capabilities. It builds on Nemotron Nano V2, a hybrid Mamba-Transformer LLM, in order to achieve higher inference throughput in long document and video scenarios. | Use `--trust-remote-code`. You may need to adjust `--max-mamba-cache-size` [default is 512] to fit memory constraints. |
| **JetVLM** | | JetVLM is a vision-language model designed for high-performance multimodal understanding and generation tasks built upon Jet-Nemotron. | Coming soon |

## Video Input Support
Expand All @@ -57,6 +58,7 @@ SGLang supports video input for Vision-Language Models (VLMs), enabling temporal
| **GLM-4v** (4.5V, 4.1V, MOE) | `zai-org/GLM-4.5V` | Video clips are read with Decord, converted to tensors, and passed to the model alongside metadata for rotary-position handling. |
| **NVILA** (Full & Lite) | `Efficient-Large-Model/NVILA-8B` | The runtime samples eight frames per clip and attaches them to the multimodal request when `video_data` is present. |
| **LLaVA video variants** (LLaVA-NeXT-Video, LLaVA-OneVision) | `lmms-lab/LLaVA-NeXT-Video-7B` | The processor routes video prompts to the LlavaVid video-enabled architecture, and the provided example shows how to query it with `sgl.video(...)` clips. |
| **NVIDIA Nemotron Nano 2.0 VL** | `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16` | For video, the processor is configured to sample at 2 FPS, at a max of 128 frames, as per model training. |
| **JetVLM** | | The runtime samples eight frames per clip and attaches them to the multimodal request when `video_data` is present. |

Use `sgl.video(path, num_frames)` when building prompts to attach clips from your SGLang programs.
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
from sglang.srt.configs.nano_nemotron_vl import NemotronH_Nano_VL_V2_Config
from sglang.srt.configs.nemotron_h import NemotronHConfig
from sglang.srt.configs.olmo3 import Olmo3Config
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
Expand Down Expand Up @@ -40,6 +41,7 @@
"DotsOCRConfig",
"FalconH1Config",
"NemotronHConfig",
"NemotronH_Nano_VL_V2_Config",
"JetNemotronConfig",
"JetVLMConfig",
]
1 change: 1 addition & 0 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
"Mistral3ForConditionalGeneration",
"MultiModalityCausalLM",
"MllamaForConditionalGeneration",
"NemotronH_Nano_VL_V2",
"Qwen2AudioForConditionalGeneration",
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
Expand Down
114 changes: 114 additions & 0 deletions python/sglang/srt/configs/nano_nemotron_vl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16/blob/cb5a65ff10232128389d882d805fa609427544f1/configuration.py

from typing import Any

from transformers.configuration_utils import PretrainedConfig

from sglang.srt.configs.nemotron_h import NemotronHConfig
from sglang.srt.configs.radio import RadioConfig
from sglang.srt.multimodal.internvl_utils import IMAGENET_MEAN, IMAGENET_STD


def float_triplet(seq: Any) -> tuple[float, float, float]:
    """Validate *seq* and return it as a ``(float, float, float)`` tuple.

    Used to normalize ``norm_mean`` / ``norm_std`` config values, which may
    arrive as lists (deserialized from JSON) or tuples.

    Args:
        seq: Any iterable expected to yield exactly three floats.

    Returns:
        The three values as a tuple.

    Raises:
        ValueError: If *seq* does not contain exactly three items.
        TypeError: If any item is not a float.
    """
    values = tuple(seq)
    # Explicit raises instead of `assert`: validation must survive `python -O`,
    # which strips assert statements.
    if len(values) != 3:
        raise ValueError(f"expected three floats, got {len(values)} items")
    if not all(isinstance(v, float) for v in values):
        raise TypeError(f"expected three floats, got {values!r}")
    a, b, c = values
    return a, b, c


class NemotronH_Nano_VL_V2_Config(PretrainedConfig):
    """Composite Hugging Face config for NVIDIA Nemotron Nano v2 VL.

    Couples a :class:`NemotronHConfig` language-model config (``llm_config``)
    with the raw vision-tower config dict (``raw_vision_config``), which is
    converted on demand into a :class:`RadioConfig` via
    :meth:`create_radio_config`.
    """

    model_type = "NemotronH_Nano_VL_V2"
    # Composed of sub-model configs (LLM + vision) rather than a flat config.
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        llm_config=None,
        force_image_size: int = 512,
        patch_size: int = 16,
        downsample_ratio=0.5,
        template=None,
        ps_version="v2",
        image_tag_type="internvl",
        projector_hidden_size=4096,
        vit_hidden_size=1280,
        video_pruning_rate: float = 0.0,
        video_context_token: str = "<video>",
        img_context_token: str = "<image>",
        img_start_token: str = "<img>",
        img_end_token: str = "</img>",
        norm_mean: tuple[float, float, float] | list[float] = IMAGENET_MEAN,
        norm_std: tuple[float, float, float] | list[float] = IMAGENET_STD,
        use_thumbnail: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Two construction paths: loading from a checkpoint JSON (llm_config
        # is a dict) and internal instantiation by transformers (llm_config
        # and vision_config are both None).
        if llm_config is not None:
            self.llm_config = NemotronHConfig(**llm_config)
            assert isinstance(vision_config, dict), "vision_config must be a dictionary"
            self.raw_vision_config = vision_config
        else:
            assert vision_config is None
            self.llm_config = NemotronHConfig()
            self.raw_vision_config = {}

        # Assign configuration values
        # Prefer sizes from the checkpoint's vision config, falling back to
        # the constructor defaults. Some checkpoints store these as [H, W]
        # lists, in which case the first entry is used.
        vision_image_size = self.raw_vision_config.get("image_size", force_image_size)
        vision_patch_size = self.raw_vision_config.get("patch_size", patch_size)
        self.image_size = int(
            vision_image_size[0]
            if isinstance(vision_image_size, list)
            else vision_image_size
        )
        self.patch_size = int(
            vision_patch_size[0]
            if isinstance(vision_patch_size, list)
            else vision_patch_size
        )

        self.downsample_ratio = downsample_ratio
        self.video_context_token = video_context_token
        self.img_context_token = img_context_token
        self.template = template  # TODO move out of here and into the tokenizer
        self.ps_version = ps_version  # Pixel shuffle version
        self.image_tag_type = image_tag_type  # TODO: into the tokenizer too?
        self.projector_hidden_size = projector_hidden_size
        self.vit_hidden_size = vit_hidden_size
        self.video_pruning_rate = video_pruning_rate

        # Per-channel image normalization stats, validated to three floats.
        self.norm_mean = float_triplet(norm_mean)
        self.norm_std = float_triplet(norm_std)
        self.use_thumbnail = use_thumbnail
        self.img_start_token = img_start_token
        self.img_end_token = img_end_token

    def create_radio_config(self) -> RadioConfig:
        """Build a :class:`RadioConfig` from the raw vision-tower config.

        Reads the backbone name and register-token count from
        ``raw_vision_config["args"]``; raises ``KeyError`` if those required
        keys are missing (i.e. the checkpoint shipped no usable vision config).
        """
        config = self.raw_vision_config
        model_name = config["args"]["model"]
        reg_tokens = config["args"].get("register_multiple")
        # preferred_resolution is stored as a list; only the first entry is
        # used here — assumes a square resolution. TODO confirm for non-square.
        image_size = config.get("preferred_resolution", [224])[0]
        radio_config = RadioConfig(
            patch_size=self.patch_size,
            norm_mean=self.norm_mean,
            norm_std=self.norm_std,
            model_name=model_name,
            reg_tokens=reg_tokens,
            image_size=image_size,
        )
        return radio_config
106 changes: 106 additions & 0 deletions python/sglang/srt/configs/radio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/radio.py

"""Radio vision model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

# Architecture dimensions for the supported timm ViT backbones, keyed by
# model name. Tuple layout:
# (hidden_size, num_hidden_layers, num_attention_heads, intermediate_size).
VIT_TIMM_DIM_BY_NAME: dict[str, tuple[int, int, int, int]] = {
    "vit_small_patch16_224": (384, 12, 6, 1536),
    "vit_base_patch16_224": (768, 12, 12, 3072),
    "vit_large_patch16_224": (1024, 24, 16, 4096),
    "vit_huge_patch16_224": (1280, 32, 16, 5120),
}

# Default per-channel (RGB) image-normalization statistics from OpenAI CLIP.
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


class RadioConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a Radio
    vision model. It is used to instantiate a Radio model according to the
    specified arguments, defining the model architecture.

    Args:
        model_name: Name of the vision transformer model
            (e.g., "vit_base_patch16_224"). Used to determine architecture
            dimensions from `VIT_TIMM_DIM_BY_NAME`.
        image_size: The size (resolution) of each image.
        patch_size: The size (resolution) of each patch.
        qkv_bias: Whether to add a bias to the queries, keys and values.
        qk_normalization: Whether to apply normalization to queries and keys.
        norm_type: The normalization type to use.
        layer_norm_eps: The epsilon used by the layer normalization layers.
        initializer_factor: A factor for initializing all weight matrices.
        hidden_act: The non-linear activation function in the encoder.
        max_img_size: Maximum image size for position embeddings.
        norm_mean: Mean values for image normalization (RGB channels).
            Defaults to (0.48145466, 0.4578275, 0.40821073).
        norm_std: Standard deviation values for image normalization
            (RGB channels). Defaults to (0.26862954, 0.26130258, 0.27577711).
        reg_tokens: Number of register tokens to use.

    Raises:
        ValueError: If `model_name` is not one of the backbones listed in
            `VIT_TIMM_DIM_BY_NAME`.
    """

    model_type = "radio"

    def __init__(
        self,
        model_name: str,
        image_size: int = 224,
        patch_size: int = 16,
        qkv_bias: bool = True,
        qk_normalization: bool = False,
        norm_type: str = "layer_norm",
        layer_norm_eps: float = 1e-6,
        initializer_factor: float = 1.0,
        hidden_act: str = "gelu",
        max_img_size: int = 2048,
        norm_mean: tuple[float, float, float] | list = OPENAI_CLIP_MEAN,
        norm_std: tuple[float, float, float] | list = OPENAI_CLIP_STD,
        reg_tokens: int | None = None,
        drop_path_rate: float = 0.0,
        dropout: float = 0.0,
        **kwargs,
    ):
        self.model_name = model_name
        # Resolve the backbone architecture from the lookup table; fail with
        # an actionable message instead of a bare KeyError for typos or
        # unsupported checkpoints.
        try:
            (
                self.hidden_size,
                self.num_hidden_layers,
                self.num_attention_heads,
                self.intermediate_size,
            ) = VIT_TIMM_DIM_BY_NAME[model_name]
        except KeyError:
            raise ValueError(
                f"Unsupported Radio vision model {model_name!r}; expected one "
                f"of {sorted(VIT_TIMM_DIM_BY_NAME)}"
            ) from None
        self.image_size = image_size
        self.patch_size = patch_size
        self.qkv_bias = qkv_bias
        self.qk_normalization = qk_normalization
        self.norm_type = norm_type
        self.layer_norm_eps = layer_norm_eps
        self.initializer_factor = initializer_factor
        self.hidden_act = hidden_act
        self.max_img_size = max_img_size
        # Store as lists so the config JSON-serializes cleanly.
        self.norm_mean = (
            list(norm_mean) if isinstance(norm_mean, (tuple, list)) else norm_mean
        )
        self.norm_std = (
            list(norm_std) if isinstance(norm_std, (tuple, list)) else norm_std
        )
        self.reg_tokens = reg_tokens
        self.drop_path_rate = drop_path_rate
        self.dropout = dropout
        # PretrainedConfig consumes any remaining HF-standard kwargs; called
        # last so it cannot clobber the attributes set above.
        super().__init__(**kwargs)
3 changes: 3 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
JetNemotronConfig,
JetVLMConfig,
KimiLinearConfig,
NemotronH_Nano_VL_V2_Config,
NemotronHConfig,
Qwen3NextConfig,
)
Expand Down Expand Up @@ -1474,6 +1475,8 @@ def mamba2_config(self):
config = self.model_config.hf_config
if isinstance(config, FalconH1Config | NemotronHConfig):
return config
if isinstance(config, NemotronH_Nano_VL_V2_Config):
return config.llm_config
return None

@property
Expand Down
Loading
Loading