diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index ea964a54383c..a028c668c8ab 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 +import math import mimetypes import os from tempfile import NamedTemporaryFile, TemporaryDirectory @@ -20,6 +21,8 @@ from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions, + get_load_balance_assignment, + run_dp_sharded_mrope_vision_model, run_dp_sharded_vision_model) from vllm.platforms import current_platform from vllm.utils import get_open_port, update_environment_variables @@ -425,8 +428,8 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, # Set random seed for reproducibility current_platform.seed_everything(0) - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) torch.set_default_device(device) update_environment_variables({ @@ -463,3 +466,322 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, # Check that the outputs are close (they should be identical) assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize( + "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts," + "expected_grouped_sizes_per_gpu,test_description", + [ + # Empty input + ([], 2, [], [0, 0], [0, 0], "empty input"), + + # Fewer samples than GPUs + ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0 + ], "fewer samples than GPUs"), + + # Single GPU + ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"), + + # Balanced assignment + ([100, 100, 100, 100 + ], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"), + + # Unbalanced sizes - this one is trickier since the algorithm is greedy + ([1000, 100, 200, 50], 2, [0, 2, 1, 3 + ], [1, 3], [1000, 350], "unbalanced sizes"), + ], +) +def test_get_load_balance_assignment_cases(sizes, num_gpus, + expected_shuffle_indices, + expected_gpu_sample_counts, + expected_grouped_sizes_per_gpu, + test_description): + """Test get_load_balance_assignment with various input cases.""" + result = get_load_balance_assignment(sizes, num_gpus=num_gpus) + (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result + + # Common assertions for all cases + assert len(shuffle_indices) == len(sizes) + assert len(gpu_sample_counts) == num_gpus + assert len(grouped_sizes_per_gpu) == num_gpus + assert sum(gpu_sample_counts) == len(sizes) + + assert shuffle_indices == expected_shuffle_indices + + assert gpu_sample_counts == expected_gpu_sample_counts + assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu + + +class SimpleMRopeVisionModel(torch.nn.Module): + """A simple vision model for testing mrope functionality.""" + + def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64): + super().__init__() + self.spatial_merge_size = spatial_merge_size + self.out_hidden_size = out_hidden_size + self.linear = torch.nn.Linear(768, out_hidden_size) + + def forward(self, pixel_values: torch.Tensor, + grid_thw_list: list[list[int]]): + """Simple forward pass that simulates spatial merging.""" + # Apply linear transformation + embeddings = self.linear(pixel_values) + + # Simulate spatial merging by reducing the number of patches + merge_factor = self.spatial_merge_size * self.spatial_merge_size + + # Group patches and merge spatially + merged_embeddings = [] + start_idx = 0 + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + end_idx = start_idx + num_patches + + # Get patches for this image + image_patches = embeddings[start_idx:end_idx] + + # Simulate spatial merging by averaging groups of patches + merged_patches = num_patches // merge_factor + if merged_patches > 0: + # Reshape and average to simulate merging + reshaped = image_patches[:merged_patches * merge_factor].view( + merged_patches, merge_factor, -1) + merged = reshaped.mean(dim=1) + merged_embeddings.append(merged) + + start_idx = end_idx + + if merged_embeddings: + return torch.cat(merged_embeddings, dim=0) + else: + return torch.empty((0, self.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "batch_size", + [ + 1, # Single image + 3, # Small batch + 5, # Odd batch size (for testing padding) + ], +) +def test_run_dp_sharded_mrope_vision_model(batch_size: int): + world_size = 2 + # Launch processes + mp.spawn( + run_dp_sharded_mrope_vision_model_vs_direct, + args=( + world_size, + batch_size, + get_open_port(), + ), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, + world_size: int, + batch_size: int, + master_port: int): + """ + Test that run_dp_sharded_mrope_vision_model produces the same results as + calling the model directly. + """ + # Set random seed for reproducibility + current_platform.seed_everything(0) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create test data + grid_thw_list = [] + pixel_values_list = [] + + for i in range(batch_size): + # Varying image sizes for better testing + t, h, w = 1, 4 + i, 4 + i + grid_thw_list.append([t, h, w]) + + num_patches = t * h * w + # Create random pixel values for this image + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + # Concatenate all pixel values + pixel_values = torch.cat(pixel_values_list, dim=0) + + # Create a simple mrope vision model + vision_model = SimpleMRopeVisionModel() + + # Run the model directly on the full input (only on rank 0) + if local_rank == 0: + with torch.inference_mode(): + direct_output = vision_model(pixel_values, grid_thw_list) + + # Run the model through the sharded function + with torch.inference_mode(): + sharded_output = run_dp_sharded_mrope_vision_model( + vision_model, pixel_values, grid_thw_list) + sharded_output = torch.cat(sharded_output, dim=0) + + # Check that the world size is setup correctly + assert get_tensor_model_parallel_world_size() == world_size + + # Compare outputs (only on rank 0) + if local_rank == 0: + # Check that the outputs have the same shape + assert direct_output.shape == sharded_output.shape + # Check that the outputs are close (they should be identical) + assert torch.allclose(direct_output, + sharded_output, + rtol=1e-5, + atol=1e-5) + + +@multi_gpu_test(num_gpus=2) +def test_run_dp_sharded_mrope_vision_model_empty_input(): + world_size = 2 + mp.spawn( + run_dp_sharded_mrope_vision_model_empty_input_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_empty_input_worker( + local_rank: int, world_size: int, master_port: int): + """Test run_dp_sharded_mrope_vision_model with empty input.""" + # Set up distributed environment + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create empty inputs + pixel_values = torch.empty((0, 768)) + grid_thw_list: list[list[int]] = [] + + vision_model = SimpleMRopeVisionModel() + + # Should handle empty input gracefully + with torch.inference_mode(): + output = run_dp_sharded_mrope_vision_model(vision_model, pixel_values, + grid_thw_list) + + assert len(output) == 0 + + +@multi_gpu_test(num_gpus=4) +def test_run_dp_sharded_mrope_vision_model_uneven_load(): + world_size = 4 + mp.spawn( + run_dp_sharded_mrope_vision_model_uneven_load_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_uneven_load_worker( + local_rank: int, world_size: int, master_port: int): + """Test run_dp_sharded_mrope_vision_model with uneven load distribution.""" + # Set up distributed environment + current_platform.seed_everything(123) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create images with very different sizes + grid_thw_list = [ + [1, 2, 2], # Small: 4 patches + [1, 8, 8], # Large: 64 patches + [1, 3, 3], # Medium: 9 patches + ] + + pixel_values_list = [] + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel() + + # Should handle uneven distribution without errors + with torch.inference_mode(): + output_tuple = run_dp_sharded_mrope_vision_model( + vision_model, pixel_values, grid_thw_list) + + # Verify output shape is reasonable + merge_factor = vision_model.spatial_merge_size**2 + expected_output_patches = list( + math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list) + + for i, output in enumerate(output_tuple): + assert output.shape[0] == expected_output_patches[i] + assert output.shape[1] == vision_model.out_hidden_size + + +@pytest.mark.parametrize("spatial_merge_size", [2, 4]) +def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int): + """Test SimpleMRopeVisionModel with different spatial merge sizes.""" + device = current_platform.device_type + + grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images + pixel_values_list = [] + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768, device=device) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel( + spatial_merge_size=spatial_merge_size).to(device) + + with torch.inference_mode(): + output = vision_model(pixel_values, grid_thw_list) + + # Verify output dimensions based on spatial merging + total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list) + merge_factor = spatial_merge_size**2 + expected_output_patches = total_patches // merge_factor + + assert output.shape[0] == expected_output_patches + assert output.shape[1] == vision_model.out_hidden_size diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 671ad9eed234..d3b6b2089f42 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -437,7 +437,7 @@ def weight_loader(self, shard_offset = sum(self.output_sizes[:loaded_shard_id]) shard_size = self.output_sizes[loaded_shard_id] - param[shard_offset:shard_offset + shard_size] = loaded_weight + param.data[shard_offset:shard_offset + shard_size] = loaded_weight @CustomOp.register("column_parallel_linear") diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 5bcbcc4f0e37..34eec10296b5 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -45,10 +45,14 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm +# yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, + MergedReplicatedLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) +# yapf: enable from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -57,6 +61,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig +from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -170,19 +175,25 @@ def __init__(self, bias: bool = False, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( + cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else + MergedColumnParallelLinear) + self.gate_up_proj = cls_gate_up_proj( input_size=in_features, output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] bias=bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + + cls_down_proj = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) + self.down_proj = cls_down_proj(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") self.act_fn = act_fn def forward(self, x: torch.Tensor): @@ -220,28 +231,42 @@ def __init__( projection_size: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() # Per attention head and per partition values. - self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_size = (1 if use_data_parallel else + parallel_state.get_tensor_model_parallel_world_size()) self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads) self.num_attention_heads_per_partition = dist_utils.divide( num_heads, self.tp_size) - self.qkv = QKVParallelLinear( - hidden_size=embed_dim, - head_size=self.hidden_size_per_attention_head, - total_num_heads=num_heads, - total_num_kv_heads=num_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv") - self.proj = RowParallelLinear(input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj") + if use_data_parallel: + self.qkv = ReplicatedLinear(embed_dim, + self.hidden_size_per_attention_head * + 3 * num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + + else: + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + + cls_proj = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) + self.proj = cls_proj(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj") # Detect attention implementation. self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) @@ -302,8 +327,6 @@ def forward( k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) if self.is_flash_attn_backend: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: @@ -370,23 +393,27 @@ def __init__( norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.norm1 = norm_layer(dim) self.norm2 = norm_layer(dim) - self.attn = Qwen2_5_VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Qwen2_5_VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel) self.mlp = Qwen2_5_VisionMLP(dim, mlp_hidden_dim, act_fn=act_fn, bias=True, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) def forward( self, @@ -445,24 +472,30 @@ def __init__( spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) + + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) + cls_fc2 = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) self.mlp = nn.ModuleList([ - ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.0"), + cls_fc1(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0"), nn.GELU(), - RowParallelLinear(self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.2"), + cls_fc2(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2"), ]) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -514,6 +547,7 @@ def __init__( norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -523,6 +557,8 @@ def __init__( depth = vision_config.depth self.hidden_size = vision_config.hidden_size self.num_heads = vision_config.num_heads + self.use_data_parallel = use_data_parallel + self.out_hidden_size = vision_config.out_hidden_size # args for get_window_index_thw self.window_size = vision_config.window_size @@ -550,7 +586,8 @@ def __init__( vision_config.hidden_act), norm_layer=norm_layer, quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel) for layer_idx in range(depth) ]) self.merger = Qwen2_5_VisionPatchMerger( @@ -560,6 +597,7 @@ def __init__( spatial_merge_size=self.spatial_merge_size, quant_config=quant_config, prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, ) self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) @@ -767,7 +805,6 @@ def load_weights(self, weights: Iterable[tuple[str, if weight_name not in name: continue name = name.replace(weight_name, param_name) - param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -840,6 +877,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config + self.use_data_parallel = (vllm_config.parallel_config. + enable_multimodal_encoder_data_parallel) self.config = config self.multimodal_config = multimodal_config @@ -851,6 +890,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=self._maybe_ignore_quant_config( self.quant_config), prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, ) else: self.visual = None @@ -973,7 +1013,13 @@ def _process_image_input( image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"] - image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) + + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw_list) + else: + image_embeds = self.visual(pixel_values, + grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync @@ -995,8 +1041,12 @@ def _process_video_input( video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"] - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw_list) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values_videos, grid_thw_list) + else: + video_embeds = self.visual(pixel_values_videos, + grid_thw=grid_thw_list) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 3361878a20d8..2315fe2ab92b 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -329,8 +329,6 @@ def forward( k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) if self.is_flash_attn_backend: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 99f3db25a71d..58c71d865dc7 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -3,6 +3,8 @@ import asyncio import atexit +import itertools +import math from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from itertools import groupby @@ -465,6 +467,219 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor, return vision_embeddings +def get_load_balance_assignment( + sizes: list[int], + num_gpus: int = 2, +) -> tuple[list[int], list[int], list[int]]: + """ + Generate load balancing assignment and metadata + for distributing data across GPUs. + The load is determined by the total image sizes, + not the number of images. + + Args: + sizes: The size of each image + num_gpus: Number of GPUs to balance across + + Returns: + shuffle_indices: + Indices to reorder data for balanced loading + gpu_sample_counts: + Number of samples assigned to each GPU + grouped_sizes_per_gpu: + Total size assigned to each GPU + + Example: + ``` + sizes = [1000, 100, 200, 50] + num_gpus=2 + ``` + + """ + + n_samples = len(sizes) + + # Handle edge cases + if n_samples == 0: + return [], [0] * num_gpus, [0] * num_gpus + + # Use greedy algorithm - balance by total size, not sample count + gpu_assignments = [list[int]() for _ in range(num_gpus)] + gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count + + # Sort indices by size (largest first for better load balancing) + # sizes = [1000, 100, 200, 50] + # large_to_small_indices = [0, 2, 1, 3] + large_to_small_indices = sorted(range(n_samples), + key=lambda i: sizes[i], + reverse=True) + + for idx in large_to_small_indices: + # Find GPU with minimum current load (by total size) + min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i]) + gpu_assignments[min_gpu].append(idx) + gpu_loads[min_gpu] += sizes[idx] + + # Create shuffle indices and counts + shuffle_indices = list[int]() + gpu_sample_counts = list[int]() + for gpu_id in range(num_gpus): + # GPU_0 = [1000] = [0] + # GPU_1 = [200, 100, 50] = [2, 1, 3] + # shuffle_indices = [0, 2, 1, 3] + shuffle_indices.extend(gpu_assignments[gpu_id]) + # GPU_0 = [1] + # GPU_1 = [3] + # gpu_sample_counts = [1, 3] + gpu_sample_counts.append(len(gpu_assignments[gpu_id])) + + return (shuffle_indices, gpu_sample_counts, gpu_loads) + + +def run_dp_sharded_mrope_vision_model( + vision_model: torch.nn.Module, + pixel_values: torch.Tensor, + grid_thw_list: list[list[int]], +) -> tuple[torch.Tensor, ...]: + """Run a vision model with data parallelism (DP) sharding. + The function will shard the input image tensor on the + first dimension and run the vision model. + This function is used to run the vision model with mrope. + + Args: + vision_model (torch.nn.Module): Vision model. + pixel_values (torch.Tensor): Image/Video input tensor. + grid_thw_list: List of grid dimensions for each image + Returns: + torch.Tensor: Output image embeddings + + Example: + ``` + vision_model.out_hidden_size = 64 + vision_model.spatial_merge_size = 2 + pixel_values.shape = (1350, channel) + grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]] + tp_size=2 + ``` + + """ + tp_size = get_tensor_model_parallel_world_size() + + # GPU_0 tp_rank_local = 0 + # GPU_1 tp_rank_local = 1 + tp_rank_local = get_tensor_model_parallel_rank() + + # patches_per_image = [1000, 100, 200, 50] + patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list] + # patches_per_image = [0, 1000, 1100, 1300, 1350] + cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)] + + # Get load balancing assignment with all metadata + # image_to_tp_rank = [0, 2, 1, 3] + # gpu_sample_counts = [1, 3] + # grouped_pixel_values_len = [1000, 350] + (image_to_tp_rank, gpu_sample_counts, + grouped_pixel_values_len) = get_load_balance_assignment( + patches_per_image, tp_size) + + # cu_gpu_sample_counts = [0, 1, 4] + cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)] + + # GPU_0 image_idxs_local = [0] + # GPU_1 image_idxs_local = [2, 1, 3] + image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]: + cum_gpu_sample_counts[tp_rank_local + + 1]] + + # Get the pixel values for the local images based on the image_idxs_local + if len(image_idxs_local) > 0: + pixel_values_local = torch.cat([ + pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]] + for i in image_idxs_local + ]) + else: + # Handle case where this rank has no images + pixel_values_local = torch.empty((0, pixel_values.shape[1]), + device=pixel_values.device, + dtype=pixel_values.dtype) + # embed_dim_reduction_factor = 2 * 2 + embed_dim_reduction_factor = (vision_model.spatial_merge_size * + vision_model.spatial_merge_size) + + # Find the max length across all ranks + # The output embedding of every DP rank has to be + # padded to this length for tensor_model_parallel_all_gather + # to work + max_len_per_rank = max( + grouped_pixel_values_len) // embed_dim_reduction_factor + local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] + + # Run the vision model on the local pixel_values_local + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model(pixel_values_local, + local_grid_thw_list) + else: + # Handle empty case + image_embeds_local = torch.empty((0, vision_model.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype) + + # Pad the output based on max_len_per_rank + # for tensor_model_parallel_all_gather to work + current_len = image_embeds_local.shape[0] + if current_len < max_len_per_rank: + padding_size = max_len_per_rank - current_len + padding = torch.empty((padding_size, image_embeds_local.shape[1]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device) + image_embeds_local_padded = torch.cat([image_embeds_local, padding], + dim=0) + else: + image_embeds_local_padded = image_embeds_local + + # Do all_gather to collect embeddings from all ranks + gathered_embeds = tensor_model_parallel_all_gather( + image_embeds_local_padded, dim=0) + + # Remove padding and reconstruct per-rank embeddings + rank_embeddings = list[torch.Tensor]() + for rank in range(tp_size): + start_idx = rank * max_len_per_rank + end_idx = start_idx + (grouped_pixel_values_len[rank] // + embed_dim_reduction_factor) + rank_embeddings.append(gathered_embeds[start_idx:end_idx]) + + patches_per_output_image = [(patch_size // embed_dim_reduction_factor) + for patch_size in patches_per_image] + + # Reconstruct embeddings in the original order + original_order_embeddings = [None] * len(grid_thw_list) + current_idx = 0 + for rank in range(tp_size): + count = gpu_sample_counts[rank] + if count > 0: + # Get images assigned to this rank in shuffled order + # GPU_0 = image_idxs_local [0] + # GPU_1 = image_idxs_local [2, 1, 3] + rank_images = image_to_tp_rank[current_idx:current_idx + count] + + rank_embed = rank_embeddings[rank] + # Split rank embeddings back to individual images + embed_start = 0 + for img_idx in rank_images: + img_patches = patches_per_output_image[img_idx] + original_order_embeddings[img_idx] = rank_embed[ + embed_start:embed_start + img_patches] + embed_start += img_patches + current_idx += count + + out_embeddings = tuple(embed for embed in original_order_embeddings + if embed is not None) + assert len(out_embeddings) == len( + original_order_embeddings), "Found unassigned embeddings" + return out_embeddings + + def fetch_audio( audio_url: str, audio_io_kwargs: Optional[dict[str, Any]] = None,