Skip to content
Merged
6 changes: 4 additions & 2 deletions examples/models/vlm/qwen3_vl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,15 @@ Before training, ensure the following environment variables are set:

### Supervised Fine-Tuning (SFT)

See the [sft.sh](sft.sh) script for full parameter fine-tuning with configurable model parallelisms.
See the [sft_seq_unpacked.sh](sft_seq_unpacked.sh) script for full parameter fine-tuning with configurable model parallelisms, using unpacked sequences.
See the [sft_seq_packed.sh](sft_seq_packed.sh) script for full parameter fine-tuning with sequence packing.

W&B report coming soon.

### Parameter-Efficient Fine-Tuning (PEFT) with LoRA

See the [peft.sh](peft.sh) script for LoRA fine-tuning with configurable tensor and pipeline parallelism.
See the [peft_seq_unpacked.sh](peft_seq_unpacked.sh) script for LoRA fine-tuning with configurable tensor and pipeline parallelism, using unpacked sequences.
See the [peft_seq_packed.sh](peft_seq_packed.sh) script for LoRA fine-tuning with sequence packing.

W&B report coming soon.

Expand Down
130 changes: 130 additions & 0 deletions examples/models/vlm/qwen3_vl/peft_seq_packed.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#!/usr/bin/env bash
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Sweep of LoRA fine-tuning runs for Qwen3-VL with and without sequence
# packing, across several parallelism layouts.
#
# Abort on the first failed run, on use of an unset variable, and on a
# failure anywhere in a pipeline, so a broken configuration is not
# silently skipped mid-sweep.
set -euo pipefail

# Workspace directory for checkpoints and results
WORKSPACE=${WORKSPACE:-/workspace}

# Before training, make sure to set WANDB_API_KEY or disable wandb logging
# export WANDB_API_KEY=<your_wandb_api_key>
# export WANDB_MODE=disabled

# Test Seq Packing configurations for LoRA finetuning on the dense model
PRETRAINED_CHECKPOINT=${WORKSPACE}/models/Qwen3-VL-8B-Instruct
MODEL_NAME=qwen3_vl_8b
DATASET_NAME=cord_v2
SEQ_LENGTH=4096
TRAIN_ITERS=50
GLOBAL_BATCH_SIZE=32
MICRO_BATCH_SIZE=2
EVAL_ITERS=10
LR=0.00005
MIN_LR=0.000005
LR_WARMUP_ITERS=10
LOG_INTERVAL=1
WANDB_PROJECT=megatron-bridge-${DATASET_NAME}

# Run every parallelism layout both with and without sequence packing.
SEQ_PACKING_CONFIGS=(True False)

# EP/TP/PP/CP combinations: "EP,TP,PP,CP" configurations
PARALLELISM_CONFIGS=("1,1,1,1" "1,1,1,2" "1,1,1,4")

for pack_config in "${SEQ_PACKING_CONFIGS[@]}"; do
  for par_config in "${PARALLELISM_CONFIGS[@]}"; do
    # Split "EP,TP,PP,CP" into its four component sizes.
    IFS=',' read -r EP TP PP CP <<< "$par_config"
    echo "Running LoRA finetuning pack_sequences_in_batch=$pack_config with EP=$EP TP=$TP PP=$PP CP=$CP"
    uv run python -m torch.distributed.run --nproc_per_node=8 scripts/training/run_recipe.py \
      --recipe "${MODEL_NAME}_finetune_config" \
      --step_func qwen3_vl_step \
      --peft_scheme lora \
      checkpoint.pretrained_checkpoint="$PRETRAINED_CHECKPOINT" \
      model.seq_length="$SEQ_LENGTH" \
      train.train_iters="$TRAIN_ITERS" \
      train.global_batch_size="$GLOBAL_BATCH_SIZE" \
      train.micro_batch_size="$MICRO_BATCH_SIZE" \
      train.eval_iters="$EVAL_ITERS" \
      optimizer.lr="$LR" \
      optimizer.min_lr="$MIN_LR" \
      scheduler.lr_warmup_iters="$LR_WARMUP_ITERS" \
      checkpoint.save="${WORKSPACE}/results/${MODEL_NAME}_lora_seq_pack_${pack_config}_cp${CP}" \
      logger.log_interval="$LOG_INTERVAL" \
      logger.wandb_project="$WANDB_PROJECT" \
      logger.wandb_exp_name="${MODEL_NAME}_${DATASET_NAME}_lora_seq_pack_${pack_config}_cp${CP}" \
      dataset.maker_name="make_${DATASET_NAME}_dataset" \
      dataset.seq_length="$SEQ_LENGTH" \
      dataset.pack_sequences_in_batch="$pack_config" \
      model.expert_model_parallel_size="$EP" \
      model.tensor_model_parallel_size="$TP" \
      model.pipeline_model_parallel_size="$PP" \
      model.context_parallel_size="$CP" \
      model.calculate_per_token_loss=True \
      ddp.average_in_collective=False \
      ddp.grad_reduce_in_fp32=True
  done
done


# Test Seq Packing configurations for LoRA finetuning on the MoE model.
# Same sweep as the dense section above, but with expert parallelism (EP)
# enabled, so run names also carry the EP size.
PRETRAINED_CHECKPOINT=${WORKSPACE}/models/Qwen3-VL-30B-A3B-Instruct
MODEL_NAME=qwen3_vl_30b_a3b
DATASET_NAME=cord_v2
SEQ_LENGTH=4096
TRAIN_ITERS=50
GLOBAL_BATCH_SIZE=32
MICRO_BATCH_SIZE=2
EVAL_ITERS=10
LR=0.00005
MIN_LR=0.000005
LR_WARMUP_ITERS=10
LOG_INTERVAL=1
WANDB_PROJECT=megatron-bridge-${DATASET_NAME}

# Run every parallelism layout both with and without sequence packing.
SEQ_PACKING_CONFIGS=(True False)

# EP/TP/PP/CP combinations: "EP,TP,PP,CP" configurations
PARALLELISM_CONFIGS=("8,1,1,1" "4,1,1,2" "2,1,1,4")

for pack_config in "${SEQ_PACKING_CONFIGS[@]}"; do
  for par_config in "${PARALLELISM_CONFIGS[@]}"; do
    # Split "EP,TP,PP,CP" into its four component sizes.
    IFS=',' read -r EP TP PP CP <<< "$par_config"
    echo "Running LoRA finetuning pack_sequences_in_batch=$pack_config with EP=$EP TP=$TP PP=$PP CP=$CP"
    uv run python -m torch.distributed.run --nproc_per_node=8 scripts/training/run_recipe.py \
      --recipe "${MODEL_NAME}_finetune_config" \
      --step_func qwen3_vl_step \
      --peft_scheme lora \
      checkpoint.pretrained_checkpoint="$PRETRAINED_CHECKPOINT" \
      model.seq_length="$SEQ_LENGTH" \
      train.train_iters="$TRAIN_ITERS" \
      train.global_batch_size="$GLOBAL_BATCH_SIZE" \
      train.micro_batch_size="$MICRO_BATCH_SIZE" \
      train.eval_iters="$EVAL_ITERS" \
      optimizer.lr="$LR" \
      optimizer.min_lr="$MIN_LR" \
      scheduler.lr_warmup_iters="$LR_WARMUP_ITERS" \
      checkpoint.save="${WORKSPACE}/results/${MODEL_NAME}_lora_seq_pack_${pack_config}_ep${EP}_cp${CP}" \
      logger.log_interval="$LOG_INTERVAL" \
      logger.wandb_project="$WANDB_PROJECT" \
      logger.wandb_exp_name="${MODEL_NAME}_${DATASET_NAME}_lora_seq_pack_${pack_config}_ep${EP}_cp${CP}" \
      dataset.maker_name="make_${DATASET_NAME}_dataset" \
      dataset.seq_length="$SEQ_LENGTH" \
      dataset.pack_sequences_in_batch="$pack_config" \
      model.expert_model_parallel_size="$EP" \
      model.tensor_model_parallel_size="$TP" \
      model.pipeline_model_parallel_size="$PP" \
      model.context_parallel_size="$CP" \
      model.calculate_per_token_loss=True \
      ddp.average_in_collective=False \
      ddp.grad_reduce_in_fp32=True
  done
done
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ for config in "${PARALLELISM_CONFIGS[@]}"; do
echo "Running LoRA finetuning with TP=$TP, PP=$PP"
uv run python -m torch.distributed.run --nproc_per_node=2 scripts/training/run_recipe.py \
--recipe ${MODEL_NAME}_finetune_config \
--step_func vlm_step \
--step_func qwen3_vl_step \
--peft_scheme lora \
checkpoint.pretrained_checkpoint=$PRETRAINED_CHECKPOINT \
model.seq_length=$SEQ_LENGTH \
Expand Down Expand Up @@ -82,15 +82,15 @@ LOG_INTERVAL=1
WANDB_PROJECT=megatron-bridge-${DATASET_NAME}

# EP/TP/PP combinations: "EP,TP,PP" configurations
PARALLELISM_CONFIGS=("8,1,1" "1,4,2")
PARALLELISM_CONFIGS=("8,1,1" "4,1,1" "2,1,1")

for config in "${PARALLELISM_CONFIGS[@]}"; do
IFS=',' read -r EP TP PP <<< "$config"

echo "Running LoRA finetuning with EP=$EP, TP=$TP, PP=$PP"
uv run python -m torch.distributed.run --nproc_per_node=8 scripts/training/run_recipe.py \
--recipe ${MODEL_NAME}_finetune_config \
--step_func vlm_step \
--step_func qwen3_vl_step \
--peft_scheme lora \
checkpoint.pretrained_checkpoint=$PRETRAINED_CHECKPOINT \
model.seq_length=$SEQ_LENGTH \
Expand All @@ -111,4 +111,3 @@ for config in "${PARALLELISM_CONFIGS[@]}"; do
model.tensor_model_parallel_size=$TP \
model.pipeline_model_parallel_size=$PP
done

128 changes: 128 additions & 0 deletions examples/models/vlm/qwen3_vl/sft_seq_packed.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#!/usr/bin/env bash
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Sweep of full (SFT) fine-tuning runs for Qwen3-VL with and without
# sequence packing, across several parallelism layouts.
#
# Abort on the first failed run, on use of an unset variable, and on a
# failure anywhere in a pipeline, so a broken configuration is not
# silently skipped mid-sweep.
set -euo pipefail

# Workspace directory for checkpoints and results
WORKSPACE=${WORKSPACE:-/workspace}

# Before training, make sure to set WANDB_API_KEY or disable wandb logging
# export WANDB_API_KEY=<your_wandb_api_key>
# export WANDB_MODE=disabled

# Test Seq Packing configurations for full finetuning on the dense model
PRETRAINED_CHECKPOINT=${WORKSPACE}/models/Qwen3-VL-8B-Instruct
MODEL_NAME=qwen3_vl_8b
DATASET_NAME=cord_v2
SEQ_LENGTH=4096
TRAIN_ITERS=50
GLOBAL_BATCH_SIZE=32
MICRO_BATCH_SIZE=2
EVAL_ITERS=10
LR=0.00005
MIN_LR=0.000005
LR_WARMUP_ITERS=10
LOG_INTERVAL=1
WANDB_PROJECT=megatron-bridge-${DATASET_NAME}

# Run every parallelism layout both with and without sequence packing.
SEQ_PACKING_CONFIGS=(True False)

# EP/TP/PP/CP combinations: "EP,TP,PP,CP" configurations
PARALLELISM_CONFIGS=("1,1,1,1" "1,1,1,2" "1,1,1,4")

for pack_config in "${SEQ_PACKING_CONFIGS[@]}"; do
  for par_config in "${PARALLELISM_CONFIGS[@]}"; do
    # Split "EP,TP,PP,CP" into its four component sizes.
    IFS=',' read -r EP TP PP CP <<< "$par_config"
    echo "Running full finetuning pack_sequences_in_batch=$pack_config with EP=$EP TP=$TP PP=$PP CP=$CP"
    uv run python -m torch.distributed.run --nproc_per_node=8 scripts/training/run_recipe.py \
      --recipe "${MODEL_NAME}_finetune_config" \
      --step_func qwen3_vl_step \
      checkpoint.pretrained_checkpoint="$PRETRAINED_CHECKPOINT" \
      model.seq_length="$SEQ_LENGTH" \
      train.train_iters="$TRAIN_ITERS" \
      train.global_batch_size="$GLOBAL_BATCH_SIZE" \
      train.micro_batch_size="$MICRO_BATCH_SIZE" \
      train.eval_iters="$EVAL_ITERS" \
      optimizer.lr="$LR" \
      optimizer.min_lr="$MIN_LR" \
      scheduler.lr_warmup_iters="$LR_WARMUP_ITERS" \
      checkpoint.save="${WORKSPACE}/results/${MODEL_NAME}_sft_seq_pack_${pack_config}_cp${CP}" \
      logger.log_interval="$LOG_INTERVAL" \
      logger.wandb_project="$WANDB_PROJECT" \
      logger.wandb_exp_name="${MODEL_NAME}_${DATASET_NAME}_sft_seq_pack_${pack_config}_cp${CP}" \
      dataset.maker_name="make_${DATASET_NAME}_dataset" \
      dataset.seq_length="$SEQ_LENGTH" \
      dataset.pack_sequences_in_batch="$pack_config" \
      model.expert_model_parallel_size="$EP" \
      model.tensor_model_parallel_size="$TP" \
      model.pipeline_model_parallel_size="$PP" \
      model.context_parallel_size="$CP" \
      model.calculate_per_token_loss=True \
      ddp.average_in_collective=False \
      ddp.grad_reduce_in_fp32=True
  done
done


# Test Seq Packing configurations for full finetuning on the MoE model.
# Same sweep as the dense section above, but with expert parallelism (EP)
# enabled, so run names also carry the EP size.
PRETRAINED_CHECKPOINT=${WORKSPACE}/models/Qwen3-VL-30B-A3B-Instruct
MODEL_NAME=qwen3_vl_30b_a3b
DATASET_NAME=cord_v2
SEQ_LENGTH=4096
TRAIN_ITERS=50
GLOBAL_BATCH_SIZE=32
MICRO_BATCH_SIZE=2
EVAL_ITERS=10
LR=0.00005
MIN_LR=0.000005
LR_WARMUP_ITERS=10
LOG_INTERVAL=1
WANDB_PROJECT=megatron-bridge-${DATASET_NAME}

# Run every parallelism layout both with and without sequence packing.
SEQ_PACKING_CONFIGS=(True False)

# EP/TP/PP/CP combinations: "EP,TP,PP,CP" configurations
PARALLELISM_CONFIGS=("8,1,1,1" "4,1,1,2" "2,1,1,4")

for pack_config in "${SEQ_PACKING_CONFIGS[@]}"; do
  for par_config in "${PARALLELISM_CONFIGS[@]}"; do
    # Split "EP,TP,PP,CP" into its four component sizes.
    IFS=',' read -r EP TP PP CP <<< "$par_config"
    echo "Running full finetuning pack_sequences_in_batch=$pack_config with EP=$EP TP=$TP PP=$PP CP=$CP"
    uv run python -m torch.distributed.run --nproc_per_node=8 scripts/training/run_recipe.py \
      --recipe "${MODEL_NAME}_finetune_config" \
      --step_func qwen3_vl_step \
      checkpoint.pretrained_checkpoint="$PRETRAINED_CHECKPOINT" \
      model.seq_length="$SEQ_LENGTH" \
      train.train_iters="$TRAIN_ITERS" \
      train.global_batch_size="$GLOBAL_BATCH_SIZE" \
      train.micro_batch_size="$MICRO_BATCH_SIZE" \
      train.eval_iters="$EVAL_ITERS" \
      optimizer.lr="$LR" \
      optimizer.min_lr="$MIN_LR" \
      scheduler.lr_warmup_iters="$LR_WARMUP_ITERS" \
      checkpoint.save="${WORKSPACE}/results/${MODEL_NAME}_sft_seq_pack_${pack_config}_ep${EP}_cp${CP}" \
      logger.log_interval="$LOG_INTERVAL" \
      logger.wandb_project="$WANDB_PROJECT" \
      logger.wandb_exp_name="${MODEL_NAME}_${DATASET_NAME}_sft_seq_pack_${pack_config}_ep${EP}_cp${CP}" \
      dataset.maker_name="make_${DATASET_NAME}_dataset" \
      dataset.seq_length="$SEQ_LENGTH" \
      dataset.pack_sequences_in_batch="$pack_config" \
      model.expert_model_parallel_size="$EP" \
      model.tensor_model_parallel_size="$TP" \
      model.pipeline_model_parallel_size="$PP" \
      model.context_parallel_size="$CP" \
      model.calculate_per_token_loss=True \
      ddp.average_in_collective=False \
      ddp.grad_reduce_in_fp32=True
  done
done
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ for config in "${PARALLELISM_CONFIGS[@]}"; do
echo "Running full finetuning with TP=$TP, PP=$PP"
uv run python -m torch.distributed.run --nproc_per_node=2 scripts/training/run_recipe.py \
--recipe ${MODEL_NAME}_finetune_config \
--step_func vlm_step \
--step_func qwen3_vl_step \
checkpoint.pretrained_checkpoint=$PRETRAINED_CHECKPOINT \
model.seq_length=$SEQ_LENGTH \
train.train_iters=$TRAIN_ITERS \
Expand Down Expand Up @@ -89,7 +89,7 @@ for config in "${PARALLELISM_CONFIGS[@]}"; do
echo "Running full finetuning with EP=$EP, TP=$TP, PP=$PP, SP=$SP"
uv run python -m torch.distributed.run --nproc_per_node=8 scripts/training/run_recipe.py \
--recipe ${MODEL_NAME}_finetune_config \
--step_func vlm_step \
--step_func qwen3_vl_step \
checkpoint.pretrained_checkpoint=$PRETRAINED_CHECKPOINT \
model.seq_length=$SEQ_LENGTH \
train.train_iters=$TRAIN_ITERS \
Expand Down
2 changes: 2 additions & 0 deletions scripts/training/run_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from typing import Callable

import megatron.bridge.recipes as recipes
from megatron.bridge.models.qwen_vl.qwen3_vl_step import forward_step as qwen3_vl_forward_step
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.finetune import finetune
from megatron.bridge.training.gpt_step import forward_step as gpt_forward_step
Expand All @@ -64,6 +65,7 @@
STEP_FUNCTIONS: dict[str, Callable] = {
"gpt_step": gpt_forward_step,
"vlm_step": vlm_forward_step,
"qwen3_vl_step": qwen3_vl_forward_step,
"llava_step": llava_forward_step,
}

Expand Down
2 changes: 2 additions & 0 deletions src/megatron/bridge/models/qwen_vl/qwen3_vl_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ def forward_step(
original_tokens.shape[0], original_tokens.shape[1], dtype=torch.bool, device=original_tokens.device
)
forward_args["attention_mask"] = attention_mask
if forward_args["loss_mask"] is not None:
forward_args["loss_mask"] = forward_args["loss_mask"].reshape(1, -1)
# qwen3vl need the original input_ids and position_ids
# use split attention mask for calculate loss
forward_args["packed_seq_params"] = packed_seq_params
Expand Down