diff --git a/examples/models/vlm/qwen3_vl/README.md b/examples/models/vlm/qwen3_vl/README.md
index 470b0d78f8..fd4c02d3a7 100644
--- a/examples/models/vlm/qwen3_vl/README.md
+++ b/examples/models/vlm/qwen3_vl/README.md
@@ -117,6 +117,20 @@ W&B report coming soon.
 
 **Note:** LoRA/DoRA significantly reduces memory requirements, allowing for larger batch sizes and fewer GPUs.
 
+## Finetuning with Energon Dataset
+
+Follow the instructions [here](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/multimodal#pretraining) to prepare the `LLaVA-Pretrain` dataset in Energon format. Then change the `.nv-meta/dataset.yaml` file in the prepared dataset to the following:
+
+```yaml
+__module__: megatron.bridge.recipes.qwen_vl.data.energon.task_encoder
+__class__: ChatMLWebdataset
+field_map:
+  imgs: jpg
+  conversation: json
+```
+
+Then, update the dataset path (`dataset.path=/path/to/energon/dataset`) in [energon_test.sh](energon_test.sh) and run the script.
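+
+For reference, the command that the script runs for each configuration looks roughly like the following (a sketch based on [energon_test.sh](energon_test.sh); the process count, checkpoint path, and override values are illustrative):
+
+```bash
+uv run python -m torch.distributed.run --nproc_per_node=8 scripts/training/run_recipe.py \
+    --recipe qwen3_vl_8b_finetune_config \
+    --step_func qwen3_vl_step \
+    --peft_scheme lora \
+    --dataset_type energon \
+    checkpoint.pretrained_checkpoint=/workspace/models/Qwen3-VL-8B-Instruct \
+    dataset.path=/path/to/energon/dataset \
+    dataset.seq_length=4096 \
+    model.seq_length=4096
+```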
+
 ## Evaluation
 
 Coming soon.
diff --git a/examples/models/vlm/qwen3_vl/energon_test.sh b/examples/models/vlm/qwen3_vl/energon_test.sh
new file mode 100755
index 0000000000..10188ac86d
--- /dev/null
+++ b/examples/models/vlm/qwen3_vl/energon_test.sh
@@ -0,0 +1,79 @@
+#!/usr/bin/env bash
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Workspace directory for checkpoints and results
+WORKSPACE=${WORKSPACE:-/workspace}
+
+# Before training, make sure to set WANDB_API_KEY or disable wandb logging
+# export WANDB_API_KEY=
+# export WANDB_MODE=disabled
+
+# Test sequence packing configurations for LoRA finetuning on the dense model
+PRETRAINED_CHECKPOINT=${WORKSPACE}/models/Qwen3-VL-8B-Instruct
+MODEL_NAME=qwen3_vl_8b
+DATASET_NAME=energon
+SEQ_LENGTH=4096
+TRAIN_ITERS=50
+GLOBAL_BATCH_SIZE=32
+MICRO_BATCH_SIZE=2
+EVAL_ITERS=10
+LR=0.00005
+MIN_LR=0.000005
+LR_WARMUP_ITERS=10
+LOG_INTERVAL=1
+WANDB_PROJECT=megatron-bridge-${DATASET_NAME}
+
+SEQ_PACKING_CONFIGS=(False True)
+
+# Parallelism configurations, each given as "EP,TP,PP,CP,N_PROC"
+# N_PROC is the total number of processes (GPUs) used for training
+# N_PROC is used to control the DP size so that the loss curves are comparable
+PARALLELISM_CONFIGS=("1,1,1,4,8" "1,1,1,2,4" "1,1,1,1,2")
+
+for pack_config in "${SEQ_PACKING_CONFIGS[@]}"; do
+    for par_config in "${PARALLELISM_CONFIGS[@]}"; do
+        IFS=',' read -r EP TP PP CP N_PROC <<< "$par_config"
+        echo "Running LoRA finetuning pack_sequences_in_batch=$pack_config with EP=$EP TP=$TP PP=$PP CP=$CP N_PROC=$N_PROC"
+        uv run python -m torch.distributed.run --nproc_per_node=$N_PROC scripts/training/run_recipe.py \
+            --recipe ${MODEL_NAME}_finetune_config \
+            --step_func qwen3_vl_step \
+            --peft_scheme lora \
+            --dataset_type energon \
+            checkpoint.pretrained_checkpoint=$PRETRAINED_CHECKPOINT \
+            model.seq_length=$SEQ_LENGTH \
+            train.train_iters=$TRAIN_ITERS \
+            train.global_batch_size=$GLOBAL_BATCH_SIZE \
+            train.micro_batch_size=$MICRO_BATCH_SIZE \
+            train.eval_iters=$EVAL_ITERS \
+            optimizer.lr=$LR \
+            optimizer.min_lr=$MIN_LR \
+            scheduler.lr_warmup_iters=$LR_WARMUP_ITERS \
+            checkpoint.save=${WORKSPACE}/results/${MODEL_NAME}_lora_seq_pack_${pack_config}_cp${CP} \
+            logger.log_interval=$LOG_INTERVAL \
+            logger.wandb_project=$WANDB_PROJECT \
+            logger.wandb_exp_name=${MODEL_NAME}_${DATASET_NAME}_lora_seq_pack_${pack_config}_cp${CP} \
+            dataset.seq_length=$SEQ_LENGTH \
+            dataset.path=/path/to/energon/dataset \
+            dataset.pack_sequences_in_batch=$pack_config \
+            model.expert_model_parallel_size=$EP \
+            model.tensor_model_parallel_size=$TP \
+            model.pipeline_model_parallel_size=$PP \
+            model.context_parallel_size=$CP \
+            model.calculate_per_token_loss=True \
+            ddp.average_in_collective=False \
+            ddp.grad_reduce_in_fp32=True
+    done
+done
+
diff --git a/examples/models/vlm/qwen3_vl/peft_seq_unpacked.sh b/examples/models/vlm/qwen3_vl/peft_seq_unpacked.sh
old mode 100644
new mode 100755
diff --git a/scripts/training/run_recipe.py b/scripts/training/run_recipe.py
index 5b2524e0be..bb4a676678 100755
--- a/scripts/training/run_recipe.py
+++ b/scripts/training/run_recipe.py
@@ -132,6 +132,12 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]:
         default=None,
         help="Sequence length for training",
     )
+    parser.add_argument(
+        "--dataset_type",
+        type=str,
+        default=None,
+        help="Dataset type for VLM recipes (e.g., 'energon', 'mock', 'hf', 'preloaded').",
+    )
 
     args, cli_overrides = parser.parse_known_args()
     return args, cli_overrides
@@ -141,6 +147,7 @@ def load_recipe(
     peft_scheme: str | None,
     packed_sequence: bool = False,
     seq_length: int | None = None,
+    dataset_type: str | None = None,
 ) -> ConfigContainer:
     """
     Load recipe by name from megatron.bridge.recipes.
@@ -150,6 +157,7 @@ def load_recipe(
         peft_scheme: PEFT scheme to use ('lora', 'dora', or None)
         packed_sequence: Enable packed sequence training (default: False)
         seq_length: Sequence length for training (optional)
+        dataset_type: Dataset type for VLM recipes (e.g., 'energon', 'mock', 'hf', 'preloaded')
 
     Returns:
         ConfigContainer from calling the recipe
@@ -175,11 +183,13 @@ def load_recipe(
         accepts_peft = "peft" in params or has_var_keyword
         accepts_packed_sequence = "packed_sequence" in params or has_var_keyword
         accepts_seq_length = "seq_length" in params or has_var_keyword
+        accepts_dataset_type = "dataset_type" in params or has_var_keyword
     except (ValueError, TypeError):
         # If signature inspection fails, fallback conservatively
         accepts_peft = True  # peft is widely supported, try passing it
         accepts_packed_sequence = False  # new parameter, don't pass if unsure
         accepts_seq_length = False  # new parameter, don't pass if unsure
+        accepts_dataset_type = False  # VLM-specific, don't pass if unsure
 
     # Build kwargs dynamically based on what the recipe accepts
     kwargs = {}
@@ -189,6 +199,8 @@
         kwargs["packed_sequence"] = packed_sequence
     if accepts_seq_length and seq_length is not None:
         kwargs["seq_length"] = seq_length
+    if accepts_dataset_type and dataset_type is not None:
+        kwargs["dataset_type"] = dataset_type
 
     try:
         return config_builder(**kwargs)
@@ -224,6 +236,7 @@ def main() -> None:
         args.peft_scheme,
         args.packed_sequence,
         args.seq_length,
+        args.dataset_type,
     )
 
     config = process_config_with_overrides(
diff --git a/src/megatron/bridge/data/energon/base_energon_datamodule.py b/src/megatron/bridge/data/energon/base_energon_datamodule.py
index c46970c9d5..37d691ad44 100644
--- a/src/megatron/bridge/data/energon/base_energon_datamodule.py
+++ b/src/megatron/bridge/data/energon/base_energon_datamodule.py
@@ -15,7 +15,7 @@
 import logging
 from typing import Any, Literal, Optional
 
-from megatron.core import parallel_state
+from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset
 
 
@@ -64,6 +64,7 @@ def __init__(
         decoder_seq_length: Optional[int] = None,
         packing_buffer_size: Optional[int] = None,
         validation_task_encoder: Optional[Any] = None,
+        pg_collection: Optional[ProcessGroupCollection] = None,
         **kwargs,
     ) -> None:
         """
@@ -89,6 +90,8 @@
             packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None.
             validation_task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching
                 samples for validation. Defaults to None and will be the same as task_encoder.
+            pg_collection (ProcessGroupCollection, optional): Process group collection for distributed training.
+                If provided, used instead of the global parallel_state. Defaults to None.
             **kwargs: Additional keyword arguments. Will be passed to get_train_dataset() of Energon
 
         """
@@ -112,8 +115,49 @@
         self.packing_buffer_size = packing_buffer_size
         self.validation_task_encoder = validation_task_encoder or self.task_encoder
         self.num_val_workers = num_val_workers or self.num_workers
+        self.pg_collection = pg_collection
         self.kwargs = kwargs
 
+    def _build_worker_config(self, num_workers: int, split: str = "train") -> WorkerConfig:
+        """Build a WorkerConfig using pg_collection, falling back to default_worker_config.
+
+        NOTE: We intentionally use the pure DP rank (pg_collection.dp)
+        rather than the combined DP-CP rank. With Megatron's rank ordering
+        (default "tp-cp-ep-dp-pp"), all CP ranks within the same DP replica
+        already share the same pure DP rank. This ensures that CP ranks
+        processing different sequence portions of the same batch receive
+        identical data from the dataloader.
+        Using dp_cp would be INCORRECT here — it would assign each CP rank
+        a unique rank, causing them to read different data shards.
+        """
+        if self.pg_collection is None or self.pg_collection.dp is None:
+            logger.info(
+                f"Multimodal {split} data loader pg_collection is not available, "
+                f"using default worker config with num_workers {num_workers}"
+            )
+            return WorkerConfig.default_worker_config(num_workers)
+
+        rank = self.pg_collection.dp.rank()
+        world_size = self.pg_collection.dp.size()
+        data_parallel_group = self.pg_collection.dp
+        cp_rank = self.pg_collection.cp.rank() if self.pg_collection.cp is not None else 0
+        cp_size = self.pg_collection.cp.size() if self.pg_collection.cp is not None else 1
+
+        logger.info(
+            f"Multimodal {split} dataloader initializing with "
+            f"dp_rank {rank} dp_world_size {world_size} "
+            f"cp_rank {cp_rank} cp_size {cp_size} "
+            f"data_parallel_group {data_parallel_group}"
+        )
+        return WorkerConfig(
+            rank=rank,
+            world_size=world_size,
+            num_workers=num_workers,
+            data_parallel_group=data_parallel_group,
+            worker_debug_path=None,
+            worker_log_level=0,
+        )
+
     def datasets_provider(self, worker_config, split: Literal["train", "val"] = "val"):
         """
         Provide the dataset for training or validation.
@@ -165,28 +209,7 @@ def train_dataloader(self) -> Any:
         logger.info(f"Multimodal train dataloader initializing with init_global_step {self.init_global_step}")
         if self.train_dataloader_object:
             return self.train_dataloader_object
-        if not parallel_state.is_initialized():
-            logger.info(
-                f"Muiltimodal data loader parallel state is not initialized,"
-                f"using default worker config with no_workers {self.num_workers}"
-            )
-            worker_config = WorkerConfig.default_worker_config(self.num_workers)
-        else:
-            rank = parallel_state.get_data_parallel_rank()
-            world_size = parallel_state.get_data_parallel_world_size()
-            data_parallel_group = parallel_state.get_data_parallel_group()
-            logger.info(
-                f" Multimodal train dataloader initializing with"
-                f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group} ****** "
-            )
-            worker_config = WorkerConfig(
-                rank=rank,
-                world_size=world_size,
-                num_workers=self.num_workers,
-                data_parallel_group=data_parallel_group,
-                worker_debug_path=None,
-                worker_log_level=0,
-            )
+        worker_config = self._build_worker_config(self.num_workers, split="train")
         train_dataset = self.datasets_provider(worker_config, split="train")
         energon_dataloader = get_savable_loader(train_dataset, worker_config=worker_config)
         self.train_dataloader_object = energon_dataloader
@@ -204,27 +227,7 @@ def val_dataloader(self):
         """
         if self.val_dataloader_object:
             return self.val_dataloader_object
-
-        if not parallel_state.is_initialized():
-            logger.info(
-                f"Muiltimodal val data loader parallel state is not initialized,"
-                f"using default worker config with no_workers {self.num_workers}"
-            )
-            worker_config = WorkerConfig.default_worker_config(self.num_val_workers)
-        else:
-            rank = parallel_state.get_data_parallel_rank()
-            world_size = parallel_state.get_data_parallel_world_size()
-            data_parallel_group = parallel_state.get_data_parallel_group()
-
-            logger.info(f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group}")
-            worker_config = WorkerConfig(
-                rank=rank,
-                world_size=world_size,
-                num_workers=self.num_workers,
-                data_parallel_group=data_parallel_group,
-                worker_debug_path=None,
-                worker_log_level=0,
-            )
+        worker_config = self._build_worker_config(self.num_val_workers, split="val")
         val_dataset = self.datasets_provider(worker_config, split="val")
         energon_loader = get_savable_loader(val_dataset, worker_config=worker_config)
         self.val_dataloader_object = energon_loader
diff --git a/src/megatron/bridge/data/energon/energon_provider.py b/src/megatron/bridge/data/energon/energon_provider.py
index c128fc5c89..ff7cbdd22b 100644
--- a/src/megatron/bridge/data/energon/energon_provider.py
+++ b/src/megatron/bridge/data/energon/energon_provider.py
@@ -33,6 +33,8 @@ class EnergonProvider(DatasetProvider):
     num_workers: int
     dataloader_type: str = "external"
     task_encoder: Optional[Any] = None
+    # Enable batch-level online sequence packing
+    pack_sequences_in_batch: bool = False
 
     def build_datasets(self, context: DatasetBuildContext):
         dataset = EnergonMultiModalDataModule(
@@ -44,6 +46,7 @@
             micro_batch_size=self.micro_batch_size,
             global_batch_size=self.global_batch_size,
             num_workers=self.num_workers,
+            pg_collection=context.pg_collection,
         )
         return (
             iter(dataset.train_dataloader()),
diff --git a/src/megatron/bridge/data/utils.py b/src/megatron/bridge/data/utils.py
index 0c258c1de9..a5695f03cc 100644
--- a/src/megatron/bridge/data/utils.py
+++ b/src/megatron/bridge/data/utils.py
@@ -189,6 +189,7 @@ def protocol_adapter(
     train_val_test_num_samples: list[int],
     config: DatasetProvider,
     tokenizer: Optional[MegatronTokenizer] = None,
+    pg_collection: Optional[ProcessGroupCollection] = None,
 ) -> tuple[Optional[Any], Optional[Any], Optional[Any]]:
     """Adapter function that bridges the protocol interface with the legacy interface."""
     context = DatasetBuildContext(
@@ -196,6 +197,7 @@
         valid_samples=train_val_test_num_samples[1],
         test_samples=train_val_test_num_samples[2],
         tokenizer=tokenizer,
+        pg_collection=pg_collection,
     )
     return config.build_datasets(context)
diff --git a/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py b/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py
index 05c163a24a..b6301ef5c5 100644
--- a/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py
+++ b/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py
@@ -19,14 +19,16 @@
 import re
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import numpy as np
 import torch
 from megatron.energon import Batch, DefaultTaskEncoder
+from megatron.energon.epathlib.epath import EPath
 from megatron.energon.flavors.base_dataset import Sample
-from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys
-from PIL import Image
+from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
+from transformers import BatchEncoding
+from webdataset.autodecode import Decoder, imagehandler
 
 from megatron.bridge.training.utils.visual_inputs import Qwen2_5_VLVisualInputs
 
@@ -114,14 +116,14 @@ def process_vision(
             kwargs["min_pixels"] = min_pixels
         if max_pixels is not None:
             kwargs["max_pixels"] = max_pixels
-        image_inputs = processor(images=images, videos=None, return_tensors="pt", **kwargs)
+        image_inputs = processor(images=images, text="", videos=None, return_tensors="pt", **kwargs)
         image_grid_thw = image_inputs.get("image_grid_thw", None)
     else:
         image_inputs = {}
         image_grid_thw = None
 
     if videos is not None:
-        videos_inputs = processor(images=None, videos=videos, return_tensors="pt")
+        videos_inputs = processor(images=None, text="", videos=videos, return_tensors="pt")
         video_grid_thw = videos_inputs.get("video_grid_thw", None)
     else:
         videos_inputs = {}
@@ -152,15 +154,100 @@ def _get(token_str: str, default_id: int) -> int:
     return image_id, video_id
 
 
+def _tensor_to_pil(t):
+    """Convert a [C,H,W] float tensor in [0,1] to a PIL Image (uint8 [0,255])."""
+    from PIL import Image
+
+    img_np = (t.permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)
+    return Image.fromarray(img_np)
+
+
+def _images_to_pil(imgs):
+    """Convert WDS tensor images to PIL to match HF flow input format.
+
+    WDS imagehandler decodes JPEG to float tensors in [0,1]. The HF flow passes
+    PIL images (uint8 [0,255]) to the processor. Converting to PIL here ensures
+    the processor applies identical rescaling and normalization in both flows.
+    """
+    if isinstance(imgs, torch.Tensor):
+        if imgs.dim() == 3:
+            return [_tensor_to_pil(imgs)]
+        elif imgs.dim() == 4:
+            return [_tensor_to_pil(img) for img in imgs]
+    elif isinstance(imgs, list):
+        return [_tensor_to_pil(img) if isinstance(img, torch.Tensor) else img for img in imgs]
+    return imgs
+
+
+def _videos_to_pil(videos):
+    """Convert WDS video frame tensors to PIL to match HF flow input format."""
+    if videos is None:
+        return None
+    result = []
+    for video in videos:
+        if isinstance(video, list):
+            result.append([_tensor_to_pil(f) if isinstance(f, torch.Tensor) else f for f in video])
+        elif isinstance(video, torch.Tensor):
+            if video.dim() == 4:
+                result.append([_tensor_to_pil(f) for f in video])
+            elif video.dim() == 3:
+                result.append([_tensor_to_pil(video)])
+            else:
+                result.append([video])
+        else:
+            result.append(video)
+    return result
+
+
 @dataclass
 class ChatMLSample(Sample):
-    """Intermediate Sample Format"""
+    """Multi-turn conversation sample with optional images and videos."""
 
-    # __key__: str
-    # __subflavors__: Dict
-    imgs: List[Image.Image]
-    videos: List[torch.Tensor | list[Image.Image]]
     conversation: str  # JSON string of GPT-format conversations
+    imgs: Optional[List[torch.Tensor]] = None
+    videos: Optional[List[List[torch.Tensor]]] = None
+
+
+class videohandler:
+    """Create a video handler."""
+
+    def __init__(self, imagespec):
+        self.extensions = ["jpgs", "mp4s", "videos"]
+        self.extensions_mapping = {"jpgs": "jpg", "mp4s": "jpg", "videos": "jpg"}
+        self.image_handler = imagehandler(imagespec)
+
+    def __call__(self, key, data):
+        """Perform nested image decoding."""
+        extension = re.sub(r".*[.]", "", key)
+        if extension.lower() not in self.extensions:
+            return None
+        data = pickle.loads(data)
+        key = self.extensions_mapping[extension]
+        if extension.lower() == "jpgs":
+            data = [self.image_handler(key, d) for d in data]
+        else:
+            data = [[self.image_handler(key, d) for d in video] for video in data]
+        return data
+
+
+class ChatMLWebdataset(DefaultDecoderWebdatasetFactory[ChatMLSample]):
+    """Webdataset factory for multi-turn ChatML samples with multimodal support.
+
+    Extends DefaultDecoderWebdatasetFactory to decode webdataset shards into
+    ChatMLSample instances, using custom handlers for image and video fields.
+    """
+
+    __sample_type__ = ChatMLSample
+
+    def __init__(self, path: EPath, *, auto_decode: bool = True, **kwargs):
+        super().__init__(path, auto_decode=auto_decode, **kwargs)
+        if auto_decode:
+            self._decoder = Decoder(
+                [
+                    imagehandler(self.image_decode),
+                    videohandler(self.image_decode),
+                ]
+            )
 
 
 @dataclass
@@ -231,51 +318,9 @@ def convert_to_qwenvl_content(user_input: str, image_pattern: str = "<image>", v
     return contents
 
 
-def cook_chatml_sample(sample: dict) -> ChatMLSample:
-    """
-    Convert crude sampel to ChatMLSample.
-
-    Args:
-        sample: Crude sample in pickle serialized format
-
-    Returns:
-        sample in ChatMLSample format
-    """
-    imgs = sample.get("jpgs", None)
-    if imgs:
-        imgs = pickle.loads(imgs)
-        if isinstance(imgs, list) and len(imgs) > 0:
-            imgs = [Image.fromarray(d) for d in imgs]
-        else:
-            imgs = None
-    videos = sample.get("videos", None)
-    if videos:
-        videos = pickle.loads(videos)
-        if isinstance(videos, list) and len(videos) > 0:
-            videos = [[d for d in video] for video in videos]
-        else:
-            videos = None
-    if "<image>" in sample["json"] and imgs is None:
-        logging.warning("<image> in conversation text but no image data")
-    if "