From 73ee71203b1ad02fded6706dfd3ec605b7bae35d Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Tue, 18 Nov 2025 11:52:56 +0200 Subject: [PATCH 1/6] modularize gsm8 kand mmmu test classes --- python/sglang/test/gsm8k_mixin.py | 44 ++++ python/sglang/test/mmmu_vlm_mixin.py | 233 ++++++++++++++++++ .../models/test_nvidia_nemotron_nano_v2.py | 54 +--- test/srt/models/test_vlm_models.py | 231 +---------------- 4 files changed, 291 insertions(+), 271 deletions(-) create mode 100644 python/sglang/test/gsm8k_mixin.py create mode 100644 python/sglang/test/mmmu_vlm_mixin.py diff --git a/python/sglang/test/gsm8k_mixin.py b/python/sglang/test/gsm8k_mixin.py new file mode 100644 index 000000000000..dc09d8a55d5e --- /dev/null +++ b/python/sglang/test/gsm8k_mixin.py @@ -0,0 +1,44 @@ +from abc import ABC +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class GSM8KMixin(ABC): + accuracy: float + model: str + other_args: list[str] = [] + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=cls.other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreaterEqual(metrics["accuracy"], self.accuracy) diff --git a/python/sglang/test/mmmu_vlm_mixin.py b/python/sglang/test/mmmu_vlm_mixin.py new file mode 100644 index 000000000000..e621a00873f5 --- /dev/null +++ b/python/sglang/test/mmmu_vlm_mixin.py @@ -0,0 +1,233 @@ +import glob +import json +import os +import subprocess +from abc import ABC +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + +# Set default mem_fraction_static to 0.8 +DEFAULT_MEM_FRACTION_STATIC = 0.8 + + +class MMMUVLMMixin(ABC): + parsed_args = None # Class variable to store args + other_args = [] + mmmu_args = [] + + @classmethod + def setUpClass(cls): + # Removed argument parsing from here + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.time_out = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + + if cls.parsed_args is None: + cls.parsed_args = SimpleNamespace( + mem_fraction_static=DEFAULT_MEM_FRACTION_STATIC + ) + + # Set OpenAI API key and base URL environment variables. Needed for lmm-evals to work. + os.environ["OPENAI_API_KEY"] = cls.api_key + os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1" + + def run_mmmu_eval( + self, + model_version: str, + output_path: str, + *, + env: dict | None = None, + ): + """ + Evaluate a VLM on the MMMU validation set with lmms‑eval. + Only `model_version` (checkpoint) and `chat_template` vary; + We are focusing only on the validation set due to resource constraints. + """ + # -------- fixed settings -------- + model = "openai_compatible" + tp = 1 + tasks = "mmmu_val" + batch_size = 32 + log_suffix = "openai_compatible" + os.makedirs(output_path, exist_ok=True) + + # -------- compose --model_args -------- + model_args = f'model_version="{model_version}",' f"tp={tp}" + + # -------- build command list -------- + cmd = [ + "python3", + "-m", + "lmms_eval", + "--model", + model, + "--model_args", + model_args, + "--tasks", + tasks, + "--batch_size", + str(batch_size), + "--log_samples", + "--log_samples_suffix", + log_suffix, + "--output_path", + str(output_path), + *self.mmmu_args, + ] + + subprocess.run( + cmd, + check=True, + timeout=3600, + ) + + def _run_vlm_mmmu_test( + self, + model, + output_path, + test_name="", + custom_env=None, + log_level="info", + capture_output=False, + ): + """ + Common method to run VLM MMMU benchmark test. + + Args: + model: Model to test + output_path: Path for output logs + test_name: Optional test name for logging + custom_env: Optional custom environment variables + log_level: Log level for server (default: "info") + capture_output: Whether to capture server stdout/stderr + """ + print(f"\nTesting model: {model.model}{test_name}") + + process = None + mmmu_accuracy = 0 # Initialize to handle potential exceptions + server_output = "" + + try: + # Prepare environment variables + process_env = os.environ.copy() + if custom_env: + process_env.update(custom_env) + # if test vlm with cuda_ipc feature, open this env_var + process_env["SGLANG_USE_CUDA_IPC_TRANSPORT"] = "1" + + # Prepare stdout/stderr redirection if needed + stdout_file = None + stderr_file = None + if capture_output: + stdout_file = open("/tmp/server_stdout.log", "w") + stderr_file = open("/tmp/server_stderr.log", "w") + + # Launch server for testing + process = popen_launch_server( + model.model, + base_url=self.base_url, + timeout=self.time_out, + api_key=self.api_key, + other_args=[ + "--trust-remote-code", + "--cuda-graph-max-bs", + "32", + "--enable-multimodal", + "--mem-fraction-static", + str(self.parsed_args.mem_fraction_static), # Use class variable + "--log-level", + log_level, + *self.other_args, + ], + env=process_env, + return_stdout_stderr=( + (stdout_file, stderr_file) if capture_output else None + ), + ) + + # Run evaluation + self.run_mmmu_eval(model.model, output_path) + + # Get the result file + # Search recursively for JSON result files (lmms-eval v0.4.1+ creates subdirectories) + result_files = glob.glob(f"{output_path}/**/*.json", recursive=True) + if not result_files: + result_files = glob.glob(f"{output_path}/*.json") + + if not result_files: + raise FileNotFoundError(f"No JSON result files found in {output_path}") + + result_file_path = result_files[0] + + with open(result_file_path, "r") as f: + result = json.load(f) + print(f"Result{test_name}\n: {result}") + + # Process the result + mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"] + print( + f"Model {model.model} achieved accuracy{test_name}: {mmmu_accuracy:.4f}" + ) + + # Capture server output if requested + if capture_output and process: + server_output = self._read_output_from_files() + + # Assert performance meets expected threshold + self.assertGreaterEqual( + mmmu_accuracy, + model.mmmu_accuracy, + f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f}){test_name}", + ) + + return server_output + + except Exception as e: + print(f"Error testing {model.model}{test_name}: {e}") + self.fail(f"Test failed for {model.model}{test_name}: {e}") + + finally: + # Ensure process cleanup happens regardless of success/failure + if process is not None and process.poll() is None: + print(f"Cleaning up process {process.pid}") + try: + kill_process_tree(process.pid) + except Exception as e: + print(f"Error killing process: {e}") + + # clean up temporary files + if capture_output: + if stdout_file: + stdout_file.close() + if stderr_file: + stderr_file.close() + for filename in ["/tmp/server_stdout.log", "/tmp/server_stderr.log"]: + try: + if os.path.exists(filename): + os.remove(filename) + except Exception as e: + print(f"Error removing {filename}: {e}") + + def _read_output_from_files(self): + output_lines = [] + + log_files = [ + ("/tmp/server_stdout.log", "[STDOUT]"), + ("/tmp/server_stderr.log", "[STDERR]"), + ] + for filename, tag in log_files: + try: + if os.path.exists(filename): + with open(filename, "r") as f: + for line in f: + output_lines.append(f"{tag} {line.rstrip()}") + except Exception as e: + print(f"Error reading {tag.lower()} file: {e}") + + return "\n".join(output_lines) diff --git a/test/srt/models/test_nvidia_nemotron_nano_v2.py b/test/srt/models/test_nvidia_nemotron_nano_v2.py index 4b414fbacf8a..d0b3ab1177f3 100644 --- a/test/srt/models/test_nvidia_nemotron_nano_v2.py +++ b/test/srt/models/test_nvidia_nemotron_nano_v2.py @@ -1,61 +1,27 @@ import unittest -from types import SimpleNamespace -from sglang.srt.utils import is_blackwell, kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - popen_launch_server, -) +from sglang.srt.utils import is_blackwell +from sglang.test.gsm8k_mixin import GSM8KMixin +from sglang.test.test_utils import CustomTestCase -class TestNvidiaNemotronNanoV2(CustomTestCase): +class TestNvidiaNemotronNanoV2BF16(GSM8KMixin, CustomTestCase): model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2" accuracy = 0.87 + other_args = ["--max-mamba-cache-size", "256"] - @classmethod - def setUpClass(cls): - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--max-mamba-cache-size", - "256", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreaterEqual(metrics["accuracy"], self.accuracy) - - -class TestNvidiaNemotronNanoV2FP8(TestNvidiaNemotronNanoV2): + +class TestNvidiaNemotronNanoV2FP8(GSM8KMixin, CustomTestCase): accuracy = 0.87 model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8" + other_args = ["--max-mamba-cache-size", "256"] @unittest.skipIf(not is_blackwell(), "NVFP4 only supported on blackwell") -class TestNvidiaNemotronNanoV2NVFP4(TestNvidiaNemotronNanoV2): +class TestNvidiaNemotronNanoV2NVFP4(GSM8KMixin, CustomTestCase): accuracy = 0.855 model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-NVFP4" + other_args = ["--max-mamba-cache-size", "256"] if __name__ == "__main__": diff --git a/test/srt/models/test_vlm_models.py b/test/srt/models/test_vlm_models.py index 42f1bd1b6f82..b26ff0831f7a 100644 --- a/test/srt/models/test_vlm_models.py +++ b/test/srt/models/test_vlm_models.py @@ -1,21 +1,12 @@ import argparse -import glob -import json -import os import random -import subprocess import sys import unittest from types import SimpleNamespace -from sglang.srt.utils import is_hip, kill_process_tree -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - is_in_ci, - popen_launch_server, -) +from sglang.srt.utils import is_hip +from sglang.test.mmmu_vlm_mixin import DEFAULT_MEM_FRACTION_STATIC, MMMUVLMMixin +from sglang.test.test_utils import CustomTestCase, is_in_ci _is_hip = is_hip() # VLM models for testing @@ -28,222 +19,8 @@ SimpleNamespace(model="openbmb/MiniCPM-V-2_6", mmmu_accuracy=0.4), ] -# Set default mem_fraction_static to 0.8 -DEFAULT_MEM_FRACTION_STATIC = 0.8 - - -class TestVLMModels(CustomTestCase): - parsed_args = None # Class variable to store args - - @classmethod - def setUpClass(cls): - # Removed argument parsing from here - cls.base_url = DEFAULT_URL_FOR_TEST - cls.api_key = "sk-123456" - cls.time_out = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH - - if cls.parsed_args is None: - cls.parsed_args = SimpleNamespace( - mem_fraction_static=DEFAULT_MEM_FRACTION_STATIC - ) - - # Set OpenAI API key and base URL environment variables. Needed for lmm-evals to work. - os.environ["OPENAI_API_KEY"] = cls.api_key - os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1" - - def run_mmmu_eval( - self, - model_version: str, - output_path: str, - *, - env: dict | None = None, - ): - """ - Evaluate a VLM on the MMMU validation set with lmms‑eval. - Only `model_version` (checkpoint) and `chat_template` vary; - We are focusing only on the validation set due to resource constraints. - """ - # -------- fixed settings -------- - model = "openai_compatible" - tp = 1 - tasks = "mmmu_val" - batch_size = 32 - log_suffix = "openai_compatible" - os.makedirs(output_path, exist_ok=True) - - # -------- compose --model_args -------- - model_args = f'model_version="{model_version}",' f"tp={tp}" - - # -------- build command list -------- - cmd = [ - "python3", - "-m", - "lmms_eval", - "--model", - model, - "--model_args", - model_args, - "--tasks", - tasks, - "--batch_size", - str(batch_size), - "--log_samples", - "--log_samples_suffix", - log_suffix, - "--output_path", - str(output_path), - ] - - subprocess.run( - cmd, - check=True, - timeout=3600, - ) - - def _run_vlm_mmmu_test( - self, - model, - output_path, - test_name="", - custom_env=None, - log_level="info", - capture_output=False, - ): - """ - Common method to run VLM MMMU benchmark test. - - Args: - model: Model to test - output_path: Path for output logs - test_name: Optional test name for logging - custom_env: Optional custom environment variables - log_level: Log level for server (default: "info") - capture_output: Whether to capture server stdout/stderr - """ - print(f"\nTesting model: {model.model}{test_name}") - - process = None - mmmu_accuracy = 0 # Initialize to handle potential exceptions - server_output = "" - - try: - # Prepare environment variables - process_env = os.environ.copy() - if custom_env: - process_env.update(custom_env) - # if test vlm with cuda_ipc feature, open this env_var - process_env["SGLANG_USE_CUDA_IPC_TRANSPORT"] = "1" - - # Prepare stdout/stderr redirection if needed - stdout_file = None - stderr_file = None - if capture_output: - stdout_file = open("/tmp/server_stdout.log", "w") - stderr_file = open("/tmp/server_stderr.log", "w") - - # Launch server for testing - process = popen_launch_server( - model.model, - base_url=self.base_url, - timeout=self.time_out, - api_key=self.api_key, - other_args=[ - "--trust-remote-code", - "--cuda-graph-max-bs", - "32", - "--enable-multimodal", - "--mem-fraction-static", - str(self.parsed_args.mem_fraction_static), # Use class variable - "--log-level", - log_level, - ], - env=process_env, - return_stdout_stderr=( - (stdout_file, stderr_file) if capture_output else None - ), - ) - - # Run evaluation - self.run_mmmu_eval(model.model, output_path) - - # Get the result file - # Search recursively for JSON result files (lmms-eval v0.4.1+ creates subdirectories) - result_files = glob.glob(f"{output_path}/**/*.json", recursive=True) - if not result_files: - result_files = glob.glob(f"{output_path}/*.json") - - if not result_files: - raise FileNotFoundError(f"No JSON result files found in {output_path}") - - result_file_path = result_files[0] - - with open(result_file_path, "r") as f: - result = json.load(f) - print(f"Result{test_name}\n: {result}") - - # Process the result - mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"] - print( - f"Model {model.model} achieved accuracy{test_name}: {mmmu_accuracy:.4f}" - ) - - # Capture server output if requested - if capture_output and process: - server_output = self._read_output_from_files() - - # Assert performance meets expected threshold - self.assertGreaterEqual( - mmmu_accuracy, - model.mmmu_accuracy, - f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f}){test_name}", - ) - - return server_output - - except Exception as e: - print(f"Error testing {model.model}{test_name}: {e}") - self.fail(f"Test failed for {model.model}{test_name}: {e}") - - finally: - # Ensure process cleanup happens regardless of success/failure - if process is not None and process.poll() is None: - print(f"Cleaning up process {process.pid}") - try: - kill_process_tree(process.pid) - except Exception as e: - print(f"Error killing process: {e}") - - # clean up temporary files - if capture_output: - if stdout_file: - stdout_file.close() - if stderr_file: - stderr_file.close() - for filename in ["/tmp/server_stdout.log", "/tmp/server_stderr.log"]: - try: - if os.path.exists(filename): - os.remove(filename) - except Exception as e: - print(f"Error removing {filename}: {e}") - - def _read_output_from_files(self): - output_lines = [] - - log_files = [ - ("/tmp/server_stdout.log", "[STDOUT]"), - ("/tmp/server_stderr.log", "[STDERR]"), - ] - for filename, tag in log_files: - try: - if os.path.exists(filename): - with open(filename, "r") as f: - for line in f: - output_lines.append(f"{tag} {line.rstrip()}") - except Exception as e: - print(f"Error reading {tag.lower()} file: {e}") - - return "\n".join(output_lines) +class TestVLMModels(MMMUVLMMixin, CustomTestCase): def test_vlm_mmmu_benchmark(self): """Test VLM models against MMMU benchmark.""" models_to_test = MODELS From a9db7b00759d9ed22b88a2e7047958d1158869f8 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Tue, 18 Nov 2025 12:10:28 +0200 Subject: [PATCH 2/6] support radio model --- python/sglang/srt/configs/radio.py | 106 ++++++ python/sglang/srt/models/radio.py | 532 +++++++++++++++++++++++++++++ 2 files changed, 638 insertions(+) create mode 100644 python/sglang/srt/configs/radio.py create mode 100644 python/sglang/srt/models/radio.py diff --git a/python/sglang/srt/configs/radio.py b/python/sglang/srt/configs/radio.py new file mode 100644 index 000000000000..cc6df58e0ff2 --- /dev/null +++ b/python/sglang/srt/configs/radio.py @@ -0,0 +1,106 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/radio.py + +"""Radio vision model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VIT_TIMM_DIM_BY_NAME: dict[str, tuple[int, int, int, int]] = { + "vit_small_patch16_224": (384, 12, 6, 1536), + "vit_base_patch16_224": (768, 12, 12, 3072), + "vit_large_patch16_224": (1024, 24, 16, 4096), + "vit_huge_patch16_224": (1280, 32, 16, 5120), +} + +OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) + + +class RadioConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a Radio + vision model. It is used to instantiate a Radio model according to the + specified arguments, defining the model architecture. + + Args: + model_name: Name of the vision transformer model + (e.g., "vit_base_patch16_224"). Used to determine architecture + dimensions from `VIT_TIMM_DIM_BY_NAME`. + image_size: The size (resolution) of each image. + patch_size: The size (resolution) of each patch. + qkv_bias: Whether to add a bias to the queries, keys and values. + qk_normalization: Whether to apply normalization to queries and keys. + norm_type: The normalization type to use. + layer_norm_eps: The epsilon used by the layer normalization layers. + initializer_factor: A factor for initializing all weight matrices. + hidden_act: The non-linear activation function in the encoder. + max_img_size: Maximum image size for position embeddings. + norm_mean: Mean values for image normalization (RGB channels). + Defaults to (0.48145466, 0.4578275, 0.40821073)). + norm_std: Standard deviation values for image normalization + (RGB channels). Defaults to (0.26862954, 0.26130258, 0.27577711)). + reg_tokens: Number of register tokens to use. + """ + + model_type = "radio" + + def __init__( + self, + model_name: str, + image_size: int = 224, + patch_size: int = 16, + qkv_bias: bool = True, + qk_normalization: bool = False, + norm_type: str = "layer_norm", + layer_norm_eps: float = 1e-6, + initializer_factor: float = 1.0, + hidden_act: str = "gelu", + max_img_size: int = 2048, + norm_mean: tuple[float, float, float] | list = OPENAI_CLIP_MEAN, + norm_std: tuple[float, float, float] | list = OPENAI_CLIP_STD, + reg_tokens: int | None = None, + drop_path_rate: float = 0.0, + dropout: float = 0.0, + **kwargs, + ): + self.model_name = model_name + ( + self.hidden_size, + self.num_hidden_layers, + self.num_attention_heads, + self.intermediate_size, + ) = VIT_TIMM_DIM_BY_NAME[model_name] + self.image_size = image_size + self.patch_size = patch_size + self.qkv_bias = qkv_bias + self.qk_normalization = qk_normalization + self.norm_type = norm_type + self.layer_norm_eps = layer_norm_eps + self.initializer_factor = initializer_factor + self.hidden_act = hidden_act + self.max_img_size = max_img_size + self.norm_mean = ( + list(norm_mean) if isinstance(norm_mean, (tuple, list)) else norm_mean + ) + self.norm_std = ( + list(norm_std) if isinstance(norm_std, (tuple, list)) else norm_std + ) + self.reg_tokens = reg_tokens + self.drop_path_rate = drop_path_rate + self.dropout = dropout + super().__init__(**kwargs) diff --git a/python/sglang/srt/models/radio.py b/python/sglang/srt/models/radio.py new file mode 100644 index 000000000000..2cd233141c15 --- /dev/null +++ b/python/sglang/srt/models/radio.py @@ -0,0 +1,532 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/radio.py + +import math +from collections.abc import Iterable +from itertools import repeat +from typing import TypeAlias + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutput + +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + replace_prefix, + replace_substrings, +) +from sglang.srt.models.internvl import InternVisionEncoder + +input_dim_t: TypeAlias = int | tuple[int, int] +norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor + + +def _ntuple(n): + def parse(x): + if isinstance(x, Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +class ClsToken(nn.Module): + def __init__( + self, + ndim: int, + num_tokens: int = 1, + enabled: bool = True, + register_multiple: int | None = None, + num_registers: int | None = None, + ): + super().__init__() + + self.ndim = ndim + self.enabled = enabled + self.num_registers = 0 + self.num_tokens = num_tokens + if enabled: + if num_registers: + self.num_registers = num_registers + elif register_multiple: + self.num_registers = register_multiple - ( + num_tokens % register_multiple + ) + + scale = ndim**-0.5 + self.token = nn.Parameter( + torch.randn(num_tokens + self.num_registers, ndim) * scale + ) + + else: + self.token = None + + self.num_patches = self.num_tokens + self.num_registers + + def forward(self, x: torch.Tensor): + if self.token is None: + return x + + token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1) + x = torch.cat( + [ + token, + x, + ], + dim=1, + ) + + return x + + +class ViTPatchGenerator(nn.Module): + def __init__( + self, + # config: PretrainedConfig, + patch_size: int, + embed_dim: int, + input_dims: input_dim_t, + abs_pos: bool = True, + normalize_patches: bool = False, + cls_token: bool = False, + max_input_dims: input_dim_t | None = None, + pos_dropout: float = 0.0, + return_pos_enc: bool = False, + num_cls_tokens: int = 1, + register_multiple: int | None = None, + num_registers: int | None = None, + patch_bias: bool = False, + device=None, + dtype=None, + ): + super().__init__() + if isinstance(input_dims, int): + input_dims = (input_dims, input_dims) + + if max_input_dims is None: + max_input_dims = input_dims + if isinstance(max_input_dims, int): + max_input_dims = (max_input_dims, max_input_dims) + + max_input_dims = tuple( + int(math.ceil(d / patch_size) * patch_size) for d in max_input_dims + ) + + self.cpe_mode = max_input_dims != input_dims + self.pos_dropout = pos_dropout + self.return_pos_enc = return_pos_enc + + factory = dict(device=device, dtype=dtype) + + self.patch_size = patch_size + self.abs_pos = abs_pos + self.embed_dim = embed_dim + + self.num_rows = max_input_dims[0] // patch_size + self.num_cols = max_input_dims[1] // patch_size + self.input_dims = tuple(d // patch_size for d in input_dims) + self.num_patches = self.num_rows * self.num_cols + self.max_input_dims = max_input_dims + + self.im_to_patches = Im2Patches(patch_size) + self.embedder = ViTPatchLinear( + patch_size, embed_dim, bias=patch_bias, **factory + ) + + if abs_pos: + scale = embed_dim**-0.5 + self.pos_embed = nn.Parameter( + torch.randn(1, self.num_patches, embed_dim, **factory) * scale + ) + + self.cls_token = ClsToken( + embed_dim, + num_tokens=num_cls_tokens, + enabled=cls_token, + register_multiple=register_multiple, + num_registers=num_registers, + ) + + self.patch_normalizer = ( + nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + patches = self.embed_patches(x) + patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:]) + patches = self.cls_token(patches) + patches = self.patch_normalizer(patches) + if self.return_pos_enc: + return patches, pos_enc + return patches + + @property + def apply_cls_token(self): + return self.cls_token.enabled + + @property + def num_cls_tokens(self): + return self.cls_token.num_tokens + + @property + def num_cls_patches(self): + return self.cls_token.num_patches + + @property + def num_registers(self): + return self.cls_token.num_registers + + @property + def num_skip(self): + return self.num_cls_tokens + self.num_registers + + def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter): + if src_embed.shape != targ_embed.shape: + src_size = int(math.sqrt(src_embed.shape[1])) + + assert ( + src_size**2 == src_embed.shape[1] + ), "Unable to interpolate non-square embedding" + + src_embed = rearrange( + src_embed, "b (h w) c -> b c h w", h=src_size, w=src_size + ) + src_embed = F.interpolate( + src_embed, + size=(self.num_rows, self.num_cols), + mode="bicubic", + align_corners=True, + antialias=False, + ) + src_embed = rearrange(src_embed, "b c h w -> b (h w) c") + targ_embed.data.copy_(src_embed) + + def _load_projection( + self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor + ): + if src_proj_weight.shape != targ_proj_weight.shape: + src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3)) + + assert (src_patch_size**2) * 3 == src_proj_weight.shape[ + 1 + ], "Unable to interpolate non-square patch size" + + src_proj_weight = rearrange( + src_proj_weight, + "b (c h w) -> b c h w", + c=3, + h=src_patch_size, + w=src_patch_size, + ) + src_proj_weight = F.interpolate( + src_proj_weight, + size=(self.patch_size, self.patch_size), + mode="bicubic", + align_corners=True, + antialias=False, + ) + src_proj_weight = rearrange(src_proj_weight, "b c h w -> b (c h w)") + targ_proj_weight.data.copy_(src_proj_weight) + + def embed_patches(self, x: torch.Tensor) -> torch.Tensor: + patches = self.im_to_patches(x) + patches = self.embedder(patches) + return patches + + def apply_pos_enc( + self, + patches: torch.Tensor, + patch_idxs: torch.Tensor | None = None, + input_size: tuple[int, int] | None = None, + ) -> torch.Tensor: + if not self.abs_pos: + return patches + + pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size) + + if self.training and self.pos_dropout > 0: + keeps = ( + torch.rand( + patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device + ) + > self.pos_dropout + ) + pos_enc_drop = torch.where(keeps, pos_enc, 0) + else: + pos_enc_drop = pos_enc + + return patches + pos_enc_drop, pos_enc + + def get_pos_enc( + self, + batch_size: int, + patch_idxs: torch.Tensor | None = None, + input_size: tuple[int, int] | None = None, + ) -> torch.Tensor: + if input_size is None: + input_dims = self.input_dims + else: + input_dims = tuple(d // self.patch_size for d in input_size) + + pos_embed = self._get_pos_embeddings(batch_size, input_dims) + + if patch_idxs is None: + return pos_embed + + exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1]) + + pos_embed = torch.gather( + pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs + ) + return pos_embed + + def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int, int]): + if (self.num_rows, self.num_cols) == input_dims: + return self.pos_embed + + pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute( + 0, 3, 1, 2 + ) + + def window_select(pos_embed): + if input_dims[0] < pos_embed.shape[-2]: + pos_embed = pos_embed[..., : input_dims[0], :] + if input_dims[1] < pos_embed.shape[-1]: + pos_embed = pos_embed[..., :, : input_dims[1]] + return pos_embed + + if self.cpe_mode: + if self.training: + min_scale = math.sqrt(0.1) + scale = ( + torch.rand(batch_size, 1, 1, device=pos_embed.device) + * (1 - min_scale) + + min_scale + ) + aspect_min = math.log(3 / 4) + aspect_max = -aspect_min + aspect = torch.exp( + torch.rand(batch_size, 1, 1, device=pos_embed.device) + * (aspect_max - aspect_min) + + aspect_min + ) + + scale_x = scale * aspect + scale_y = scale * (1 / aspect) + scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1) + + pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * ( + 1 - scale_xy + ) + + lin_x = torch.linspace( + 0, 1, steps=input_dims[1], device=pos_embed.device + )[None, None].expand(batch_size, input_dims[0], -1) + lin_y = torch.linspace( + 0, 1, steps=input_dims[0], device=pos_embed.device + )[None, :, None].expand(batch_size, -1, input_dims[1]) + + lin_xy = torch.stack([lin_x, lin_y], dim=-1) + + grid_xy = lin_xy * scale_xy + pos_xy + + # Convert to [-1, 1] range + grid_xy.mul_(2).sub_(1) + + pos_embed = F.grid_sample( + pos_embed.float().expand(batch_size, -1, -1, -1), + grid=grid_xy, + mode="bilinear", + padding_mode="zeros", + align_corners=True, + ).to(pos_embed.dtype) + else: + max_dim = max(input_dims) + pos_embed = F.interpolate( + pos_embed.float(), + size=(max_dim, max_dim), + align_corners=True, + mode="bilinear", + ).to(pos_embed.dtype) + + pos_embed = window_select(pos_embed) + else: + pos_embed = window_select(pos_embed) + + if pos_embed.shape[-2:] != input_dims: + pos_embed = F.interpolate( + pos_embed.float(), size=input_dims, align_corners=True, mode="bilinear" + ).to(pos_embed.dtype) + + pos_embed = pos_embed.flatten(2).permute(0, 2, 1) + + return pos_embed + + +class Im2Patches(nn.Module): + def __init__(self, patch_size: int): + super().__init__() + self.patch_size = patch_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.patch_size == 1: + patches = x.flatten(2) + patches = patches.permute(0, 2, 1) + return patches + + py = x.shape[-2] // self.patch_size + px = x.shape[-1] // self.patch_size + patches = rearrange( + x, + "b c (py yy) (px xx) -> b (py px) (c yy xx)", + py=py, + yy=self.patch_size, + px=px, + xx=self.patch_size, + ) + return patches + + +class ViTPatchLinear(nn.Linear): + def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory): + super().__init__(3 * (patch_size**2), embed_dim, bias=bias, **factory) + self.patch_size = patch_size + + +class RadioInternVisionModel(nn.Module): + packed_modules_mapping = { + "qkv": ["qkv"], + } + + def __init__( + self, + config: PretrainedConfig = None, + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__() + + self.config = config + self.img_size, self.grid_size, self.num_patches = self._init_img_size( + to_2tuple(config.patch_size), config.image_size + ) + max_img_size = int( + round(config.max_img_size / config.patch_size) * config.patch_size + ) + self.patch_generator = ViTPatchGenerator( + config.patch_size, + config.hidden_size, + input_dims=self.img_size, + max_input_dims=max_img_size, + cls_token=True, + register_multiple=config.reg_tokens, + ) + + self.encoder = InternVisionEncoder(config=config, quant_config=quant_config) + + def _init_img_size(self, patch_size, img_size: int | tuple[int, int]): + if img_size is None: + return None, None, None + img_size = to_2tuple(img_size) + grid_size = tuple([s // p for s, p in zip(img_size, patch_size)]) + num_patches = grid_size[0] * grid_size[1] + return img_size, grid_size, num_patches + + def get_input_embeddings(self): + return self.embeddings + + def forward(self, x: torch.Tensor) -> torch.FloatTensor: + assert self.patch_generator is not None + hidden_states = self.patch_generator(x) + encoder_outputs = self.encoder.forward(inputs_embeds=hidden_states) + assert isinstance(encoder_outputs, BaseModelOutput) + return encoder_outputs.last_hidden_state + + +class RadioModel(nn.Module): + packed_modules_mapping = { + "qkv": ["qkv"], + } + + def __init__( + self, + config: PretrainedConfig, + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__() + + self.config = config + self.model = RadioInternVisionModel( + config=config, + quant_config=quant_config, + ) + + def forward( + self, + pixel_values: torch.Tensor | None = None, + pixel_embeds: torch.Tensor | None = None, + ) -> torch.FloatTensor: + y = self.model(pixel_values) + return self._extract_final(y) + + def load_weights(self, weights) -> set[str]: + remap_substrings = { + "attn": "attn.attn", + "qkv": "qkv_proj", + "blocks": "encoder.layers", + } + remap_prefixes = { + "radio_model.": "", + } + + loaded_params: set[str] = set() + params_dict = dict(self.named_parameters()) + + if isinstance(weights, dict): + weights_list = list(weights.items()) + else: + weights_list = list(weights) + + for name, weight in weights_list: + if not name.startswith("radio_model."): + # Skip non-radio weights + continue + name = replace_substrings(name, remap_substrings) + name = replace_prefix(name, remap_prefixes) + if name and name in params_dict: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, weight) + loaded_params.add(name) + + return loaded_params + + def _extract_final(self, y: torch.Tensor): + # Remove CLS + REGISTERS tokens + patch_gen = getattr(self.model, "patch_generator", None) + if patch_gen is not None: + all_feat = y[:, patch_gen.num_skip :] + + return all_feat From b5b10bfb9ad56e62fd3a7423d82edf3a8632c6ba Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Tue, 18 Nov 2025 12:11:30 +0200 Subject: [PATCH 3/6] move internvl_utils --- benchmark/mmmu/bench_hf.py | 9 ++++- .../sglang/srt/multimodal}/internvl_utils.py | 39 ++++++++++++++----- 2 files changed, 37 insertions(+), 11 deletions(-) rename {benchmark/mmmu => python/sglang/srt/multimodal}/internvl_utils.py (78%) diff --git a/benchmark/mmmu/bench_hf.py b/benchmark/mmmu/bench_hf.py index 949b63b802a7..c841f44466d7 100644 --- a/benchmark/mmmu/bench_hf.py +++ b/benchmark/mmmu/bench_hf.py @@ -36,9 +36,10 @@ def eval_mmmu(args): try: # check if the model is belongs to internvl if "InternVL" in args.model_path: - from internvl_utils import load_image from transformers import AutoTokenizer + from sglang.srt.multimodal.internvl_utils import image_to_pixel_values + tokenizer = AutoTokenizer.from_pretrained(args.model_path) model = AutoModel.from_pretrained( args.model_path, @@ -80,7 +81,11 @@ def eval_mmmu(args): assert image is not None if "InternVL" in args.model_path: - pixel_values = load_image(sample["image_path"]).to(torch.bfloat16).cuda() + image = PIL.Image.open(sample["image_path"]).convert("RGB") + pixel_values = image_to_pixel_values( + image, input_size=448, max_num=12, use_thumbnail=True + ) + pixel_values = pixel_values.to(device="cuda", dtype=torch.bfloat16) contents = "" if prefix: contents += prefix diff --git a/benchmark/mmmu/internvl_utils.py b/python/sglang/srt/multimodal/internvl_utils.py similarity index 78% rename from benchmark/mmmu/internvl_utils.py rename to python/sglang/srt/multimodal/internvl_utils.py index 44c62c99aa41..0fbef1c7c048 100644 --- a/benchmark/mmmu/internvl_utils.py +++ b/python/sglang/srt/multimodal/internvl_utils.py @@ -8,14 +8,18 @@ IMAGENET_STD = (0.229, 0.224, 0.225) -def build_transform(input_size): - MEAN, STD = IMAGENET_MEAN, IMAGENET_STD +def build_transform( + input_size, + *, + mean: tuple[float, float, float], + std: tuple[float, float, float], +): transform = T.Compose( [ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), - T.Normalize(mean=MEAN, std=STD), + T.Normalize(mean=mean, std=std), ] ) return transform @@ -38,8 +42,13 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_ def dynamic_preprocess( - image, min_num=1, max_num=12, image_size=448, use_thumbnail=False -): + image: Image.Image, + *, + min_num: int, + max_num: int, + image_size: int, + use_thumbnail: bool, +) -> list[Image.Image]: orig_width, orig_height = image.size aspect_ratio = orig_width / orig_height @@ -83,12 +92,24 @@ def dynamic_preprocess( return processed_images -def load_image(image_file, input_size=448, max_num=12): - image = Image.open(image_file).convert("RGB") - transform = build_transform(input_size=input_size) +def image_to_pixel_values( + image: Image.Image, + *, + input_size: int, + min_num_tiles: int = 1, + max_num_tiles: int, + use_thumbnail: bool, + mean: tuple[float, float, float] = IMAGENET_MEAN, + std: tuple[float, float, float] = IMAGENET_STD, +) -> torch.Tensor: images = dynamic_preprocess( - image, image_size=input_size, use_thumbnail=True, max_num=max_num + image, + min_num=min_num_tiles, + max_num=max_num_tiles, + image_size=input_size, + use_thumbnail=use_thumbnail, ) + transform = build_transform(input_size, mean=mean, std=std) pixel_values = [transform(image) for image in images] pixel_values = torch.stack(pixel_values) return pixel_values From f973053abda6c8182dc0f1d487c351fc8e65bb08 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Tue, 18 Nov 2025 12:13:11 +0200 Subject: [PATCH 4/6] support NemotronH_Nano_VL_V2 --- .../multimodal_language_models.md | 2 + python/sglang/srt/configs/__init__.py | 2 + python/sglang/srt/configs/model_config.py | 1 + python/sglang/srt/configs/nano_nemotron_vl.py | 106 ++++++++ .../sglang/srt/model_executor/model_runner.py | 3 + python/sglang/srt/models/nano_nemotron_vl.py | 251 ++++++++++++++++++ python/sglang/srt/models/nemotron_h.py | 9 +- .../multimodal/processors/nano_nemotron_vl.py | 195 ++++++++++++++ python/sglang/srt/utils/common.py | 24 ++ .../sglang/srt/utils/hf_transformers_utils.py | 4 + .../models/test_nvidia_nemotron_nano_v2_vl.py | 31 +++ test/srt/run_suite.py | 2 + test/srt/test_video_utils.py | 59 ++++ 13 files changed, 683 insertions(+), 6 deletions(-) create mode 100644 python/sglang/srt/configs/nano_nemotron_vl.py create mode 100644 python/sglang/srt/models/nano_nemotron_vl.py create mode 100644 python/sglang/srt/multimodal/processors/nano_nemotron_vl.py create mode 100644 test/srt/models/test_nvidia_nemotron_nano_v2_vl.py create mode 100644 test/srt/test_video_utils.py diff --git a/docs/supported_models/multimodal_language_models.md b/docs/supported_models/multimodal_language_models.md index 3414d6c48d3a..90c877518de0 100644 --- a/docs/supported_models/multimodal_language_models.md +++ b/docs/supported_models/multimodal_language_models.md @@ -45,6 +45,7 @@ in the GitHub search bar. | **DotsVLM** (General/OCR) | `rednote-hilab/dots.vlm1.inst` | RedNote's vision-language model built on a 1.2B vision encoder and DeepSeek V3 LLM, featuring NaViT vision encoder trained from scratch with dynamic resolution support and enhanced OCR capabilities through structured image data training. | | | **DotsVLM-OCR** | `rednote-hilab/dots.ocr` | Specialized OCR variant of DotsVLM optimized for optical character recognition tasks with enhanced text extraction and document understanding capabilities. | Don't use `--trust-remote-code` | | **NVILA** (8B, 15B, Lite-2B, Lite-8B, Lite-15B) | `Efficient-Large-Model/NVILA-8B` | `chatml` | NVILA explores the full stack efficiency of multi-modal design, achieving cheaper training, faster deployment and better performance. | +| **NVIDIA Nemotron Nano 2.0 VL** | `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16` | NVIDIA Nemotron Nano v2 VL enables multi-image reasoning and video understanding, along with strong document intelligence, visual Q&A and summarization capabilities. It builds on Nemotron Nano V2, a hybrid Mamba-Transformer LLM, in order to achieve higher inference throughput in long document and video scenarios. | Use `--trust-remote-code`. You may need to adjust `--max-mamba-cache-size` [default is 512] to fit memory constraints. | | **JetVLM** | | JetVLM is an vision-language model designed for high-performance multimodal understanding and generation tasks built upon Jet-Nemotron. | Coming soon | ## Video Input Support @@ -57,6 +58,7 @@ SGLang supports video input for Vision-Language Models (VLMs), enabling temporal | **GLM-4v** (4.5V, 4.1V, MOE) | `zai-org/GLM-4.5V` | Video clips are read with Decord, converted to tensors, and passed to the model alongside metadata for rotary-position handling. | | **NVILA** (Full & Lite) | `Efficient-Large-Model/NVILA-8B` | The runtime samples eight frames per clip and attaches them to the multimodal request when `video_data` is present. | | **LLaVA video variants** (LLaVA-NeXT-Video, LLaVA-OneVision) | `lmms-lab/LLaVA-NeXT-Video-7B` | The processor routes video prompts to the LlavaVid video-enabled architecture, and the provided example shows how to query it with `sgl.video(...)` clips. | +| **NVIDIA Nemotron Nano 2.0 VL** | `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16` | For video, the processor is configured to sample at 2 FPS, at a max of 128 frames, as per model training. | | **JetVLM** | | The runtime samples eight frames per clip and attaches them to the multimodal request when `video_data` is present. | Use `sgl.video(path, num_frames)` when building prompts to attach clips from your SGLang programs. diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index 623d91d7352f..b35cc1dc5f23 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -12,6 +12,7 @@ from sglang.srt.configs.kimi_vl import KimiVLConfig from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig from sglang.srt.configs.longcat_flash import LongcatFlashConfig +from sglang.srt.configs.nano_nemotron_vl import NemotronH_Nano_VL_V2_Config from sglang.srt.configs.nemotron_h import NemotronHConfig from sglang.srt.configs.olmo3 import Olmo3Config from sglang.srt.configs.qwen3_next import Qwen3NextConfig @@ -40,6 +41,7 @@ "DotsOCRConfig", "FalconH1Config", "NemotronHConfig", + "NemotronH_Nano_VL_V2_Config", "JetNemotronConfig", "JetVLMConfig", ] diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 77c2d81cde89..d04dc07280e8 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -937,6 +937,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal "Mistral3ForConditionalGeneration", "MultiModalityCausalLM", "MllamaForConditionalGeneration", + "NemotronH_Nano_VL_V2", "Qwen2AudioForConditionalGeneration", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration", diff --git a/python/sglang/srt/configs/nano_nemotron_vl.py b/python/sglang/srt/configs/nano_nemotron_vl.py new file mode 100644 index 000000000000..79e63edd409f --- /dev/null +++ b/python/sglang/srt/configs/nano_nemotron_vl.py @@ -0,0 +1,106 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Adapted from https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16/blob/cb5a65ff10232128389d882d805fa609427544f1/configuration.py + +from transformers.configuration_utils import PretrainedConfig +from transformers.dynamic_module_utils import get_class_from_dynamic_module +from transformers.utils import logging + +from sglang.srt.configs.nemotron_h import NemotronHConfig +from sglang.srt.multimodal.internvl_utils import IMAGENET_MEAN, IMAGENET_STD + +logger = logging.get_logger(__name__) + + +class NemotronH_Nano_VL_V2_Config(PretrainedConfig): + model_type = "NemotronH_Nano_VL_V2" + is_composition = True + + def __init__( + self, + vision_config=None, + llm_config=None, + force_image_size: int = 512, + patch_size: int = 16, + downsample_ratio=0.5, + template=None, + ps_version="v2", + image_tag_type="internvl", + projector_hidden_size=4096, + vit_hidden_size=1280, + attn_implementation="flash_attention_2", + video_pruning_rate: float = 0.0, + video_context_token: str = "