diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 792018669148..17ba7bcfbc23 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -486,13 +486,12 @@ def __init__( customized_position_embedding_applier: Callable[ [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor] ] = None, + use_data_parallel: bool = False, **kwargs, ): super().__init__() - attn_tp_rank = get_attention_tp_rank() - attn_tp_size = get_attention_tp_size() - self.tp_size = attn_tp_size - self.tp_rank = attn_tp_rank + self.tp_size = 1 if use_data_parallel else get_attention_tp_size() + self.tp_rank = 0 if use_data_parallel else get_attention_tp_rank() self.dropout = dropout self.head_size = embed_dim // num_heads self.hidden_size_per_attention_head = dist_utils.divide( diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index edcad66830c6..63c0a42f2a44 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -40,6 +40,10 @@ Qwen2_5_VisionRotaryEmbedding, ) +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from sglang.srt.distributed.parallel_state import get_pp_group from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.layernorm import RMSNorm @@ -62,6 +66,8 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2 import Qwen2Model from sglang.srt.models.utils import permute_inv +from sglang.srt.multimodal.mm_utils import run_dp_sharded_mrope_vision_model +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix logger = logging.getLogger(__name__) @@ -76,14 +82,21 @@ def __init__( hidden_act="silu", quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = 
False, ): super().__init__() + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) + self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank() self.gate_up_proj = MergedColumnParallelLinear( input_size=in_features, output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] bias=bias, quant_config=quant_config, prefix=add_prefix("gate_up_proj", prefix), + tp_size=self.tp_size, + tp_rank=self.tp_rank, ) self.down_proj = RowParallelLinear( hidden_features, @@ -91,6 +104,8 @@ def __init__( bias=bias, quant_config=quant_config, prefix=add_prefix("down_proj", prefix), + tp_size=self.tp_size, + tp_rank=self.tp_rank, ) self.act = ACT2FN[hidden_act] @@ -115,6 +130,7 @@ def __init__( prefix: str = "", num_dummy_heads: int = 0, rms_norm_eps: float = 1e-6, + use_data_parallel: bool = False, ) -> None: super().__init__() self.norm1 = RMSNorm(dim, eps=rms_norm_eps) @@ -130,6 +146,7 @@ def __init__( quant_config=quant_config, prefix=add_prefix("attn", prefix), num_dummy_heads=num_dummy_heads, + use_data_parallel=use_data_parallel, ) self.mlp = Qwen2_5_VLMLP( dim, @@ -137,6 +154,7 @@ def __init__( hidden_act=hidden_act, quant_config=quant_config, prefix=add_prefix("mlp", prefix), + use_data_parallel=use_data_parallel, ) def forward( @@ -180,10 +198,13 @@ def __init__( spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) self.ln_q = RMSNorm(context_dim, eps=1e-6) + tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size() + tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank() self.mlp = nn.ModuleList( [ ColumnParallelLinear( @@ -192,6 +213,8 @@ def __init__( bias=True, quant_config=quant_config, prefix=add_prefix("mlp.0", prefix), + tp_size=tp_size, + tp_rank=tp_rank, ), nn.GELU(), RowParallelLinear( @@ -200,6 +223,8 @@ def 
__init__( bias=True, quant_config=quant_config, prefix=add_prefix("mlp.2", prefix), + tp_size=tp_size, + tp_rank=tp_rank, ), ] ) @@ -225,6 +250,7 @@ def __init__( norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -241,6 +267,8 @@ def __init__( self.window_size = vision_config.window_size self.patch_size = vision_config.patch_size mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8 + self.use_data_parallel = use_data_parallel + self.out_hidden_size = vision_config.out_hidden_size self.patch_embed = Qwen2_5_VisionPatchEmbed( patch_size=patch_size, temporal_patch_size=temporal_patch_size, @@ -261,6 +289,7 @@ def __init__( norm_layer=norm_layer, quant_config=quant_config, prefix=add_prefix(f"blocks.{i}", prefix), + use_data_parallel=use_data_parallel, ) for i in range(depth) ] @@ -271,6 +300,7 @@ def __init__( spatial_merge_size=spatial_merge_size, quant_config=quant_config, prefix=add_prefix("merger", prefix), + use_data_parallel=use_data_parallel, ) def get_window_index(self, grid_thw): @@ -461,6 +491,7 @@ def __init__( self.pp_group = get_pp_group() self.config = config + self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder self.visual = Qwen2_5_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), @@ -468,6 +499,7 @@ def __init__( # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. 
quant_config=quant_config, prefix=add_prefix("visual", prefix), + use_data_parallel=self.use_data_parallel, ) self.model = Qwen2Model( @@ -510,7 +542,12 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0) assert pixel_values.dim() == 2, pixel_values.dim() assert image_grid_thw.dim() == 2, image_grid_thw.dim() - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, image_grid_thw.tolist(), rope_type="rope_3d" + ) + else: + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) return image_embeds def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: @@ -521,7 +558,12 @@ def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0) assert pixel_values.dim() == 2, pixel_values.dim() assert video_grid_thw.dim() == 2, video_grid_thw.dim() - video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, video_grid_thw.tolist(), rope_type="rope_3d" + ) + else: + video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw) return video_embeds def get_input_embeddings(self): diff --git a/python/sglang/srt/multimodal/mm_utils.py b/python/sglang/srt/multimodal/mm_utils.py index c399be806183..12ed1893436f 100644 --- a/python/sglang/srt/multimodal/mm_utils.py +++ b/python/sglang/srt/multimodal/mm_utils.py @@ -28,14 +28,22 @@ """ import ast +import itertools import math import re from io import BytesIO +from typing import Literal import numpy as np import pybase64 +import torch from PIL import Image +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from 
sglang.srt.distributed.communication_op import tensor_model_parallel_all_gather from sglang.srt.utils import flatten_nested_list @@ -347,3 +355,263 @@ def process_images(images, image_processor, model_cfg): if all(x.shape == new_images[0].shape for x in new_images): new_images = np.stack(new_images, axis=0) return new_images + + +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/vision.py +def get_dp_encoder_lb_assignment( + sizes: list[int], + num_gpus: int = 2, +) -> tuple[list[int], list[int], list[int]]: + """ + Generate load balancing assignment and metadata + for distributing data across GPUs. + The load is determined by the total image sizes, + not the number of images. + + Args: + sizes: The size of each image + num_gpus: Number of GPUs to balance across + + Returns: + shuffle_indices: + Indices to reorder data for balanced loading + gpu_sample_counts: + Number of samples assigned to each GPU + grouped_sizes_per_gpu: + Total size assigned to each GPU + + Example: + ``` + sizes = [1000, 100, 200, 50] + num_gpus = 2 + ``` + + """ + + n_samples = len(sizes) + + # Handle edge cases + if n_samples == 0: + return [], [0] * num_gpus, [0] * num_gpus + + # Use greedy algorithm - balance by total size, not sample count + gpu_assignments = [list[int]() for _ in range(num_gpus)] + gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count + + # Sort indices by size (largest first for better load balancing) + # sizes = [1000, 100, 200, 50] + # large_to_small_indices = [0, 2, 1, 3] + large_to_small_indices = sorted( + range(n_samples), key=lambda i: sizes[i], reverse=True + ) + + for idx in large_to_small_indices: + # Find GPU with minimum current load (by total size) + min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i]) + gpu_assignments[min_gpu].append(idx) + gpu_loads[min_gpu] += sizes[idx] + + # Create shuffle indices and counts + shuffle_indices = list[int]() + gpu_sample_counts = list[int]() + for gpu_id in 
range(num_gpus): + # GPU_0 = [1000] = [0] + # GPU_1 = [200, 100, 50] = [2, 1, 3] + # shuffle_indices = [0, 2, 1, 3] + shuffle_indices.extend(gpu_assignments[gpu_id]) + # GPU_0 = [1] + # GPU_1 = [3] + # gpu_sample_counts = [1, 3] + gpu_sample_counts.append(len(gpu_assignments[gpu_id])) + + return (shuffle_indices, gpu_sample_counts, gpu_loads) + + +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/vision.py +def run_dp_sharded_mrope_vision_model( + vision_model: torch.nn.Module, + pixel_values: torch.Tensor, + grid_thw_list: list, + *, + rope_type: Literal["rope_3d", "rope_2d"], +): + """Run a vision model with data parallelism (DP) sharding. + The function will shard the input image tensor on the + first dimension and run the vision model. + This function is used to run the vision model with mrope. + + Args: + vision_model (torch.nn.Module): Vision model. + pixel_values (torch.Tensor): Image/Video input tensor. + grid_thw_list: List of grid dimensions for each image + rope_type: Type of rope used in the vision model. + Different rope types have different dimension to do ViT. 
+ "rope_3d" for 3D rope (e.g., Qwen2.5-VL) + "rope_2d" for 2D rope (e.g., Kimi-VL) + Returns: + torch.Tensor: Output image embeddings + + Example: + ``` + vision_model.out_hidden_size = 64 + vision_model.spatial_merge_size = 2 + pixel_values.shape = (1350, channel) + grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]] + tp_size = 2 + ``` + + """ + tp_size = get_tensor_model_parallel_world_size() + + # GPU_0 tp_rank_local = 0 + # GPU_1 tp_rank_local = 1 + tp_rank_local = get_tensor_model_parallel_rank() + + # patches_per_image = [1000, 100, 200, 50] + patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list] + # print(f"{patches_per_image = }") + # patches_per_image = [0, 1000, 1100, 1300, 1350] + cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)] + + # Get load balancing assignment with all metadata + # image_to_tp_rank = [0, 2, 1, 3] + # gpu_sample_counts = [1, 3] + # grouped_pixel_values_len = [1000, 350] + (image_to_tp_rank, gpu_sample_counts, grouped_pixel_values_len) = ( + get_dp_encoder_lb_assignment(patches_per_image, tp_size) + ) + + # cu_gpu_sample_counts = [0, 1, 4] + cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)] + + # GPU_0 image_idxs_local = [0] + # GPU_1 image_idxs_local = [2, 1, 3] + image_idxs_local = image_to_tp_rank[ + cum_gpu_sample_counts[tp_rank_local] : cum_gpu_sample_counts[tp_rank_local + 1] + ] + + # Get the pixel values for the local images based on the image_idxs_local + if len(image_idxs_local) > 0: + pixel_values_local = torch.cat( + [ + pixel_values[cum_patches_per_image[i] : cum_patches_per_image[i + 1]] + for i in image_idxs_local + ] + ) + else: + # Handle case where this rank has no images + pixel_values_local = torch.empty( + (0, pixel_values.shape[1]), + device=pixel_values.device, + dtype=pixel_values.dtype, + ) + # embed_dim_reduction_factor = 2 * 2 + if rope_type == "rope_2d": + embed_dim_reduction_factor = ( + vision_model.merge_kernel_size[0] * 
vision_model.merge_kernel_size[1] + ) + else: + embed_dim_reduction_factor = ( + vision_model.spatial_merge_size * vision_model.spatial_merge_size + ) + + # Find the max length across all ranks + # The output embedding of every DP rank has to be + # padded to this length for tensor_model_parallel_all_gather + # to work + max_len_per_rank = max(grouped_pixel_values_len) // embed_dim_reduction_factor + local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] + + # Run the vision model on the local pixel_values_local + if rope_type == "rope_2d": + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model( + pixel_values_local, torch.tensor(local_grid_thw_list) + ) + if isinstance(image_embeds_local, list): + image_embeds_local = torch.cat(image_embeds_local, dim=0) + else: + out_dim = getattr(vision_model.config, "hidden_size", None) + image_embeds_local = torch.empty( + (0, embed_dim_reduction_factor, out_dim), + device=pixel_values.device, + dtype=pixel_values.dtype, + ) + else: + if pixel_values_local.shape[0] > 0: + # print(f"{local_grid_thw_list = }", flush=True) + image_embeds_local = vision_model( + pixel_values_local, torch.tensor(local_grid_thw_list) + ) + else: + # Handle empty case + image_embeds_local = torch.empty( + (0, vision_model.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype, + ) + + # Pad the output based on max_len_per_rank + # for tensor_model_parallel_all_gather to work + current_len = image_embeds_local.shape[0] + if current_len < max_len_per_rank: + padding_size = max_len_per_rank - current_len + if rope_type == "rope_2d": + padding = torch.empty( + ( + padding_size, + image_embeds_local.shape[1], + image_embeds_local.shape[2], + ), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device, + ) + else: + padding = torch.empty( + (padding_size, image_embeds_local.shape[1]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device, + ) + image_embeds_local_padded = 
torch.cat([image_embeds_local, padding], dim=0) + else: + image_embeds_local_padded = image_embeds_local + + # Do all_gather to collect embeddings from all ranks + gathered_embeds = tensor_model_parallel_all_gather(image_embeds_local_padded, dim=0) + + # Remove padding and reconstruct per-rank embeddings + rank_embeddings = list[torch.Tensor]() + for rank in range(tp_size): + start_idx = rank * max_len_per_rank + end_idx = start_idx + ( + grouped_pixel_values_len[rank] // embed_dim_reduction_factor + ) + rank_embeddings.append(gathered_embeds[start_idx:end_idx]) + + patches_per_output_image = [ + (patch_size // embed_dim_reduction_factor) for patch_size in patches_per_image + ] + + # Reconstruct embeddings in the original order + original_order_embeddings = [None] * len(grid_thw_list) + current_idx = 0 + for rank in range(tp_size): + count = gpu_sample_counts[rank] + if count > 0: + # Get images assigned to this rank in shuffled order + # GPU_0 = image_idxs_local [0] + # GPU_1 = image_idxs_local [2, 1, 3] + rank_images = image_to_tp_rank[current_idx : current_idx + count] + + rank_embed = rank_embeddings[rank] + # Split rank embeddings back to individual images + embed_start = 0 + for img_idx in rank_images: + img_patches = patches_per_output_image[img_idx] + original_order_embeddings[img_idx] = rank_embed[ + embed_start : embed_start + img_patches + ] + embed_start += img_patches + current_idx += count + out_embeddings = torch.cat(original_order_embeddings, dim=0) + return out_embeddings diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e798e01e5220..b713c23c7e45 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -578,6 +578,9 @@ class ServerArgs: decrypted_config_file: Optional[str] = None decrypted_draft_config_file: Optional[str] = None + # For encoder dp + mm_enable_dp_encoder: bool = False + # For forward hooks hooks: Optional[List[dict[str, Any]]] = None @@ -3728,6 +3731,12 @@ def 
add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.decrypted_draft_config_file, help="The path of the decrypted draft config file.", ) + parser.add_argument( + "--mm-enable-dp-encoder", + action="store_true", + default=ServerArgs.mm_enable_dp_encoder, + help="Enabling data parallelism for mm encoder. The dp size will be set to the tp size automatically.", + ) # For registering hooks parser.add_argument( diff --git a/test/srt/nightly/test_encoder_dp.py b/test/srt/nightly/test_encoder_dp.py new file mode 100644 index 000000000000..cb47634cf41b --- /dev/null +++ b/test/srt/nightly/test_encoder_dp.py @@ -0,0 +1,270 @@ +import argparse +import glob +import json +import os +import random +import subprocess +import sys +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, +) + +MODELS = [ + SimpleNamespace(model="Qwen/Qwen2.5-VL-72B-Instruct", mmmu_accuracy=0.55), +] + + +# Set default mem_fraction_static to 0.8 +DEFAULT_MEM_FRACTION_STATIC = 0.8 + + +class TestVLMEncoderDP(CustomTestCase): + parsed_args = None # Class variable to store args + + @classmethod + def setUpClass(cls): + # Removed argument parsing from here + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.time_out = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + + if cls.parsed_args is None: + cls.parsed_args = SimpleNamespace( + mem_fraction_static=DEFAULT_MEM_FRACTION_STATIC + ) + + # Set OpenAI API key and base URL environment variables. Needed for lmm-evals to work. + os.environ["OPENAI_API_KEY"] = cls.api_key + os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1" + + def run_mmmu_eval( + self, + model_version: str, + output_path: str, + *, + env: dict | None = None, + ): + """ + Evaluate a VLM on the MMMU validation set with lmms‑eval. 
+ Only `model_version` (checkpoint) and `chat_template` vary; + We are focusing only on the validation set due to resource constraints. + """ + # -------- fixed settings -------- + model = "openai_compatible" + tp = 1 + tasks = "mmmu_val" + batch_size = 32 + log_suffix = "openai_compatible" + os.makedirs(output_path, exist_ok=True) + + # -------- compose --model_args -------- + model_args = f'model_version="{model_version}",' f"tp={tp}" + + # -------- build command list -------- + cmd = [ + "python3", + "-m", + "lmms_eval", + "--model", + model, + "--model_args", + model_args, + "--tasks", + tasks, + "--batch_size", + str(batch_size), + "--log_samples", + "--log_samples_suffix", + log_suffix, + "--output_path", + str(output_path), + ] + + subprocess.run( + cmd, + check=True, + timeout=3600, + ) + + def _run_vlm_mmmu_test( + self, + model, + output_path, + test_name="", + custom_env=None, + log_level="info", + capture_output=False, + ): + """ + Common method to run VLM MMMU benchmark test. 
+ + Args: + model: Model to test + output_path: Path for output logs + test_name: Optional test name for logging + custom_env: Optional custom environment variables + log_level: Log level for server (default: "info") + capture_output: Whether to capture server stdout/stderr + """ + print(f"\nTesting model: {model.model}{test_name}") + + process = None + mmmu_accuracy = 0 # Initialize to handle potential exceptions + server_output = "" + + try: + # Prepare environment variables + process_env = os.environ.copy() + if custom_env: + process_env.update(custom_env) + # if test vlm with cuda_ipc feature, open this env_var + process_env["SGLANG_USE_CUDA_IPC_TRANSPORT"] = "1" + + # Prepare stdout/stderr redirection if needed + stdout_file = None + stderr_file = None + if capture_output: + stdout_file = open("/tmp/server_stdout.log", "w") + stderr_file = open("/tmp/server_stderr.log", "w") + + # Launch server for testing + process = popen_launch_server( + model.model, + base_url=self.base_url, + timeout=self.time_out, + api_key=self.api_key, + other_args=[ + "--trust-remote-code", + "--cuda-graph-max-bs", + "32", + "--mm-enable-dp-encoder", + "--tp=4", + "--mem-fraction-static", + str(self.parsed_args.mem_fraction_static), # Use class variable + "--log-level", + log_level, + ], + env=process_env, + return_stdout_stderr=( + (stdout_file, stderr_file) if capture_output else None + ), + ) + + # Run evaluation + self.run_mmmu_eval(model.model, output_path) + + # Get the result file + # Search recursively for JSON result files (lmms-eval v0.4.1+ creates subdirectories) + result_files = glob.glob(f"{output_path}/**/*.json", recursive=True) + if not result_files: + result_files = glob.glob(f"{output_path}/*.json") + + if not result_files: + raise FileNotFoundError(f"No JSON result files found in {output_path}") + + result_file_path = result_files[0] + + with open(result_file_path, "r") as f: + result = json.load(f) + print(f"Result{test_name}\n: {result}") + + # Process the result 
+ mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"] + print( + f"Model {model.model} achieved accuracy{test_name}: {mmmu_accuracy:.4f}" + ) + + # Capture server output if requested + if capture_output and process: + server_output = self._read_output_from_files() + + # Assert performance meets expected threshold + self.assertGreaterEqual( + mmmu_accuracy, + model.mmmu_accuracy, + f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f}){test_name}", + ) + + return server_output + + except Exception as e: + print(f"Error testing {model.model}{test_name}: {e}") + self.fail(f"Test failed for {model.model}{test_name}: {e}") + + finally: + # Ensure process cleanup happens regardless of success/failure + if process is not None and process.poll() is None: + print(f"Cleaning up process {process.pid}") + try: + kill_process_tree(process.pid) + except Exception as e: + print(f"Error killing process: {e}") + + # clean up temporary files + if capture_output: + if stdout_file: + stdout_file.close() + if stderr_file: + stderr_file.close() + for filename in ["/tmp/server_stdout.log", "/tmp/server_stderr.log"]: + try: + if os.path.exists(filename): + os.remove(filename) + except Exception as e: + print(f"Error removing {filename}: {e}") + + def _read_output_from_files(self): + output_lines = [] + + log_files = [ + ("/tmp/server_stdout.log", "[STDOUT]"), + ("/tmp/server_stderr.log", "[STDERR]"), + ] + for filename, tag in log_files: + try: + if os.path.exists(filename): + with open(filename, "r") as f: + for line in f: + output_lines.append(f"{tag} {line.rstrip()}") + except Exception as e: + print(f"Error reading {tag.lower()} file: {e}") + + return "\n".join(output_lines) + + def test_vlm_mmmu_benchmark(self): + """Test VLM models against MMMU benchmark.""" + models_to_test = MODELS + + if is_in_ci(): + models_to_test = [random.choice(MODELS)] + + for model in models_to_test: + self._run_vlm_mmmu_test(model, "./logs") + 
+ +if __name__ == "__main__": + # Define and parse arguments here, before unittest.main + parser = argparse.ArgumentParser(description="Test VLM models") + parser.add_argument( + "--mem-fraction-static", + type=float, + help="Static memory fraction for the model", + default=DEFAULT_MEM_FRACTION_STATIC, + ) + + # Parse args intended for unittest + args = parser.parse_args() + + # Store the parsed args object on the class + TestVLMEncoderDP.parsed_args = args + + # Pass args to unittest + unittest.main(argv=[sys.argv[0]]) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index a83a0df20bb1..add4ea6794c6 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -220,6 +220,7 @@ TestFile("test_deepseek_r1_fp8_trtllm_backend.py", 3600), ], "nightly-4-gpu": [ + TestFile("nightly/test_encoder_dp.py", 500), TestFile("test_qwen3_next_deterministic.py", 200), ], "nightly-8-gpu": [],