Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions docs/models/vlm/qwen3-vl.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,35 @@ Note:
- For dataset formats and additional information, refer to the [Qwen2.5-VL documentation]
- See the full script with examples at [`examples/models/vlm/qwen_vl/finetune_qwen_vl.py`](../../../examples/models/vlm/qwen_vl/finetune_qwen_vl.py)

### PEFT (Parameter-Efficient Fine-Tuning)

Qwen3-VL supports PEFT methods including LoRA and DoRA for memory-efficient training. PEFT trains only adapter parameters (~1-2% of model), significantly reducing memory requirements and enabling faster training.

**LoRA with 8B Dense Model (1 GPU):**
```bash
torchrun --nproc-per-node=1 examples/models/vlm/qwen_vl/finetune_qwen_vl.py \
--pretrained-checkpoint $MEGATRON_MODEL_PATH \
--recipe qwen3_vl_8b_finetune_config \
--dataset-type hf \
--peft lora \
checkpoint.save=$SAVE_DIR/<experiment name>
```

**LoRA with 30B MoE Model (8 GPUs with Expert Parallelism):**
```bash
torchrun --nproc-per-node=8 examples/models/vlm/qwen_vl/finetune_qwen_vl.py \
--pretrained-checkpoint $MEGATRON_MODEL_PATH \
--recipe qwen3_vl_30b_a3b_finetune_config \
--dataset-type hf \
--peft lora \
checkpoint.save=$SAVE_DIR/<experiment name>
```

**DoRA Training:**
```bash
--peft dora
```
Comment on lines +89 to +116
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add implementation details for PEFT setup.

This section has motivation and examples, but it doesn’t yet explain how PEFT is wired in the recipe (configuration keys, defaults, and override behavior). The docs requirement for new key features explicitly calls for implementation details.

📌 Suggested doc addition
 Qwen3-VL supports PEFT methods including LoRA and DoRA for memory-efficient training. PEFT trains only adapter parameters (~1-2% of model), significantly reducing memory requirements and enabling faster training.

+Implementation details:
+- Enable PEFT via `+peft=lora` or `+peft=dora`, which maps to the recipe’s default PEFT configuration.
+- The recipe uses `finetune_lr` as the PEFT max LR; override it via CLI/YAML if needed.
+- For custom LoRA/DoRA hyperparameters, pass a PEFT object or extend the recipe defaults accordingly.

As per coding guidelines: All new key features (enabling a new model, enabling a new parallelism strategy) must include documentation update explaining motivation, technical approach, usage examples, and implementation details.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
### PEFT (Parameter-Efficient Fine-Tuning)
Qwen3-VL supports PEFT methods including LoRA and DoRA for memory-efficient training. PEFT trains only adapter parameters (~1-2% of model), significantly reducing memory requirements and enabling faster training.
**LoRA with 8B Dense Model (1 GPU):**
```bash
torchrun --nproc-per-node=1 examples/models/vlm/qwen_vl/finetune_qwen_vl.py \
--pretrained-checkpoint $MEGATRON_MODEL_PATH \
--recipe qwen3_vl_8b_finetune_config \
--dataset-type hf \
+peft=lora \
checkpoint.save=$SAVE_DIR/<experiment name>
```
**LoRA with 30B MoE Model (8 GPUs with Expert Parallelism):**
```bash
torchrun --nproc-per-node=8 examples/models/vlm/qwen_vl/finetune_qwen_vl.py \
--pretrained-checkpoint $MEGATRON_MODEL_PATH \
--recipe qwen3_vl_30b_a3b_finetune_config \
--dataset-type hf \
+peft=lora \
checkpoint.save=$SAVE_DIR/<experiment name>
```
**DoRA Training:**
```bash
+peft=dora
```
### PEFT (Parameter-Efficient Fine-Tuning)
Qwen3-VL supports PEFT methods including LoRA and DoRA for memory-efficient training. PEFT trains only adapter parameters (~1-2% of model), significantly reducing memory requirements and enabling faster training.
Implementation details:
- Enable PEFT via `+peft=lora` or `+peft=dora`, which maps to the recipe's default PEFT configuration.
- The recipe uses `finetune_lr` as the PEFT max LR; override it via CLI/YAML if needed.
- For custom LoRA/DoRA hyperparameters, pass a PEFT object or extend the recipe defaults accordingly.
**LoRA with 8B Dense Model (1 GPU):**
🤖 Prompt for AI Agents
In `@docs/models/vlm/qwen3-vl.md` around lines 89 - 116, Add explicit
implementation details for PEFT wiring: describe which config keys (e.g.,
peft.type, peft.lora.rank, peft.dora.params, peft.enabled) are read by the
finetune entrypoint finetune_qwen_vl.py and which recipe files (e.g.,
qwen3_vl_8b_finetune_config, qwen3_vl_30b_a3b_finetune_config) should set them;
explain defaults (peft.type=null, peft.enabled=false, sensible LoRA default
rank), how CLI overrides like +peft=lora or +peft=dora map to peft.type and
merged into the Hydra config, and how the training pipeline initializes adapters
(calls into the PEFT factory to attach LoRA/DoRA modules, freeze base params,
register adapter params for optimizer and checkpoint.save behavior). Include
guidance on checkpointing and parameter counting (only adapter params
saved/loaded when peft.enabled=true) and note where in the codebase to update if
the mapping changes (finetune_qwen_vl.py and the recipe files named above).


## Hugging Face Model Cards
- Qwen3-VL-8B: `https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct`
- Qwen3-VL-30B-A3B (MoE): `https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct`
Expand Down
30 changes: 22 additions & 8 deletions examples/models/vlm/qwen_vl/finetune_qwen_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]:
action="store_true",
help="Use preloaded dataset provider (enabled automatically when --data-path is set).",
)
parser.add_argument(
"--peft",
type=str,
default=None,
choices=["lora", "dora", "none"],
help="Type of PEFT to use: 'lora', 'dora', or 'none' (full SFT). If not set, uses full SFT.",
)
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
args, cli_dotlist_overrides = parser.parse_known_args()
return args, cli_dotlist_overrides
Expand Down Expand Up @@ -216,14 +223,21 @@ def main() -> None:
use_preloaded_flag = bool(args.data_path) or bool(getattr(args, "use_preloaded", False))
dataset_type = args.dataset_type or ("preloaded" if use_preloaded_flag else "mock")

cfg: ConfigContainer = pretrain_config(
dataset_type=dataset_type,
train_data_path=args.data_path,
valid_data_path=None,
test_data_path=None,
image_folder=args.image_folder,
pretrained_checkpoint=args.pretrained_checkpoint,
)
# Build recipe kwargs
recipe_kwargs = {
"dataset_type": dataset_type,
"train_data_path": args.data_path,
"valid_data_path": None,
"test_data_path": None,
"image_folder": args.image_folder,
"pretrained_checkpoint": args.pretrained_checkpoint,
}

# Add peft parameter if specified via --peft flag
if args.peft is not None:
recipe_kwargs["peft"] = args.peft

cfg: ConfigContainer = pretrain_config(**recipe_kwargs)
logger.info("Loaded base configuration")

if get_rank_safe() == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def forward(
position_ids: torch.Tensor = None, # can set at dataset
attention_mask: torch.Tensor = None,
labels: torch.Tensor = None,
loss_mask: torch.Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
Expand Down Expand Up @@ -319,6 +320,7 @@ def forward(
attention_mask=attention_mask, # None in encoder
decoder_input=combined_embeddings, # only not None in the first decoder PP stage
labels=labels, # only not None in the last decoder PP stage
loss_mask=loss_mask, # Added for THD training compatibility
inference_params=inference_params, # currently always None
packed_seq_params=packed_seq_params, # currently always None
visual_pos_masks=visual_pos_masks,
Expand Down
11 changes: 10 additions & 1 deletion src/megatron/bridge/recipes/qwen_vl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,22 @@

# Qwen3 models
from .qwen3_vl import (
qwen3_vl_8b_finetune_config,
qwen3_vl_8b_pretrain_config,
qwen3_vl_30b_a3b_finetune_config,
qwen3_vl_30b_a3b_pretrain_config,
qwen3_vl_235b_a22b_finetune_config,
qwen3_vl_235b_a22b_pretrain_config,
)


__all__ = [
# Qwen3-VL models
# Qwen3-VL pretrain configs
"qwen3_vl_8b_pretrain_config",
"qwen3_vl_30b_a3b_pretrain_config",
"qwen3_vl_235b_a22b_pretrain_config",
# Qwen3-VL finetune configs (with PEFT support)
"qwen3_vl_8b_finetune_config",
"qwen3_vl_30b_a3b_finetune_config",
"qwen3_vl_235b_a22b_finetune_config",
]
69 changes: 61 additions & 8 deletions src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
MockVLMConversationProvider,
PreloadedVLMConversationProvider,
)
from megatron.bridge.peft.base import PEFT
from megatron.bridge.recipes.qwen_vl.data.energon.task_encoder import QwenVLTaskEncoder
from megatron.bridge.recipes.utils.finetune_utils import default_peft_config as _default_peft_config
from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
from megatron.bridge.recipes.utils.tokenizer_utils import DEFAULT_NULL_TOKENIZER_VOCAB_SIZE
from megatron.bridge.training.comm_overlap import CommOverlapConfig
Expand Down Expand Up @@ -97,10 +99,14 @@ class Qwen3VLCommonKwargs(TypedDict, total=False):
dataset_type: Optional[str]
image_folder: Optional[str]
tokenizer_model: Optional[str]
# PEFT options
peft: Optional[Union[str, PEFT]]
finetune_lr: float


def qwen3_vl_8b_pretrain_config(**user_kwargs: Unpack[Qwen3VLCommonKwargs]) -> ConfigContainer:
"""Return a pre-training config for Qwen3-VL 8B Instruct.

See `_qwen3_vl_common` for the full list of parameters.
"""
recommended_kwargs: Qwen3VLCommonKwargs = {
Expand Down Expand Up @@ -160,14 +166,25 @@ def qwen3_vl_235b_a22b_pretrain_config(**user_kwargs: Unpack[Qwen3VLCommonKwargs

def qwen3_vl_8b_finetune_config(**user_kwargs: Unpack[Qwen3VLCommonKwargs]) -> ConfigContainer:
"""Return a fine-tuning config for Qwen3-VL 8B Instruct.

Default configuration: 1 node, 8 GPUs
- LoRA/DoRA: TP=1, PP=1, LR=1e-4
- Full SFT: TP=4, PP=1, LR=1e-5

See `_qwen3_vl_common` for the full list of parameters.
"""
# Check if user is doing full SFT or PEFT
peft_value = user_kwargs.get("peft", None)
is_full_sft = peft_value is None or (isinstance(peft_value, str) and peft_value.lower() == "none")

recommended_kwargs: Qwen3VLCommonKwargs = {
"hf_path": "Qwen/Qwen3-VL-8B-Instruct",
"tensor_model_parallel_size": 4,
"tensor_model_parallel_size": 4 if is_full_sft else 1,
"pipeline_model_parallel_size": 1,
"pipeline_dtype": torch.bfloat16,
"expert_model_parallel_size": 1,
"peft": peft_value,
"finetune_lr": 1e-5 if is_full_sft else 1e-4,
"freeze_language_model": True,
"freeze_vision_model": True,
"freeze_vision_projection": False,
Expand All @@ -184,14 +201,27 @@ def qwen3_vl_8b_finetune_config(**user_kwargs: Unpack[Qwen3VLCommonKwargs]) -> C
def qwen3_vl_30b_a3b_finetune_config(**user_kwargs: Unpack[Qwen3VLCommonKwargs]) -> ConfigContainer:
"""Return a fine-tuning config for Qwen3-VL-30B-A3B-Instruct.

This is a Mixture-of-Experts model with 128 experts and top-8 routing.
Recommended to use with expert parallelism (EP) for efficient training.

Default configuration: 1 node, 8 GPUs
- LoRA/DoRA: TP=1, PP=1, EP=8, LR=2e-4
- Full SFT: TP=1, PP=1, EP=8, LR=2e-5

See `_qwen3_vl_common` for the full list of parameters.
"""
# Check if user is doing full SFT or PEFT
peft_value = user_kwargs.get("peft", None)
is_full_sft = peft_value is None or (isinstance(peft_value, str) and peft_value.lower() == "none")

recommended_kwargs: Qwen3VLCommonKwargs = {
"hf_path": "Qwen/Qwen3-VL-30B-A3B-Instruct",
"tensor_model_parallel_size": 1,
"pipeline_model_parallel_size": 1,
"pipeline_dtype": torch.bfloat16,
"expert_model_parallel_size": 8,
"peft": peft_value,
"finetune_lr": 2e-5 if is_full_sft else 2e-4,
"freeze_language_model": True,
"freeze_vision_model": True,
"freeze_vision_projection": True,
Expand All @@ -207,18 +237,30 @@ def qwen3_vl_30b_a3b_finetune_config(**user_kwargs: Unpack[Qwen3VLCommonKwargs])


def qwen3_vl_235b_a22b_finetune_config(**user_kwargs: Unpack[Qwen3VLCommonKwargs]) -> ConfigContainer:
"""Return a fine-tuning config for Qwen3-VL-30B-A3B-Instruct.
"""Return a fine-tuning config for Qwen3-VL-235B-A22B-Instruct.

This is a Mixture-of-Experts model with 128 experts and top-8 routing.
Recommended to use with expert parallelism (EP) for efficient training.

Default configuration: 4 nodes, 32 GPUs total
- LoRA/DoRA: TP=1, PP=1, EP=8, LR=2e-4
- Full SFT: TP=4, PP=1, EP=8, LR=2e-5

See `_qwen3_vl_common` for the full list of parameters.
"""
# Check if user is doing full SFT or PEFT
peft_value = user_kwargs.get("peft", None)
is_full_sft = peft_value is None or (isinstance(peft_value, str) and peft_value.lower() == "none")

recommended_kwargs: Qwen3VLCommonKwargs = {
"hf_path": "Qwen/Qwen3-VL-235B-A22B-Instruct",
"tensor_model_parallel_size": 1,
"pipeline_model_parallel_size": 8,
"tensor_model_parallel_size": 4 if is_full_sft else 1,
"pipeline_model_parallel_size": 1,
"pipeline_dtype": torch.bfloat16,
"account_for_embedding_in_pipeline_split": True,
"account_for_loss_in_pipeline_split": True,
"expert_model_parallel_size": 8,
"expert_tensor_parallel_size": 1,
"peft": peft_value,
"finetune_lr": 2e-5 if is_full_sft else 2e-4,
"freeze_language_model": True,
"freeze_vision_model": True,
"freeze_vision_projection": False,
Expand Down Expand Up @@ -282,6 +324,9 @@ def _qwen3_vl_common(
dataset_type: Optional[str] = None,
image_folder: Optional[str] = None,
tokenizer_model: Optional[str] = None,
# PEFT options
peft: Optional[Union[str, PEFT]] = None,
finetune_lr: Optional[float] = None,
) -> ConfigContainer:
"""
Create a pre-training configuration for Qwen3 MoE models using a given HuggingFace path.
Expand Down Expand Up @@ -327,6 +372,8 @@ def _qwen3_vl_common(
dataset_type (Optional[str]): Type of dataset to use.
image_folder (Optional[str]): Path to image folder.
tokenizer_model (Optional[str]): Path to tokenizer model.
peft (Optional[Union[str, PEFT]]): PEFT configuration (e.g., "lora", "dora", or PEFT object).
finetune_lr (Optional[float]): Learning rate override for fine-tuning.
Returns:
ConfigContainer: Configuration for pre-training.
"""
Expand Down Expand Up @@ -369,13 +416,18 @@ def _qwen3_vl_common(
model_cfg.seq_length = seq_length
model_cfg.cross_entropy_fusion_impl = "te"

# Optimizer and scheduler - use finetune_lr if provided, otherwise use lr
effective_lr = finetune_lr if finetune_lr is not None else lr
opt_config, scheduler = distributed_fused_adam_with_cosine_annealing(
lr_warmup_iters=lr_warmup_iters,
lr_decay_iters=lr_decay_iters,
max_lr=lr,
lr_decay_iters=lr_decay_iters if lr_decay_iters is not None else train_iters,
max_lr=effective_lr,
min_lr=min_lr,
)

# PEFT config
peft_config = _default_peft_config(peft)

# Determine dataset selection strategy.
_processor_model = tokenizer_model or hf_path
_dataset_choice = dataset_type or ("mock" if mock else "hf")
Expand Down Expand Up @@ -488,6 +540,7 @@ def _qwen3_vl_common(
fully_parallel_save=True,
),
rng=RNGConfig(seed=1234),
peft=peft_config,
comm_overlap=comm_overlap_config,
mixed_precision=precision_config,
)
Expand Down
7 changes: 4 additions & 3 deletions tests/functional_tests/L2_Launch_recipes_qwen_vl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ set -xeuo pipefail # Exit immediately if a command exits with a non-zero status

export CUDA_VISIBLE_DEVICES="0,1"

# Run Qwen2.5-VL recipe functional tests on 2 GPUs
# This script tests Qwen2.5-VL finetune recipe configurations with their default
# Run Qwen VL recipe functional tests on 2 GPUs
# This script tests Qwen2.5-VL and Qwen3-VL recipe configurations with their default
# settings to ensure they can run basic training without crashes.
uv run python -m torch.distributed.run --nproc_per_node=2 --nnodes=1 \
-m coverage run --data-file=/opt/Megatron-Bridge/.coverage \
--source=/opt/Megatron-Bridge/ --parallel-mode \
-m pytest -o log_cli=true -o log_cli_level=INFO -v -s -x \
-m "not pleasefixme" --tb=short -rA \
tests/functional_tests/recipes/test_qwen_vl_recipes_finetune.py
tests/functional_tests/recipes/test_qwen_vl_recipes_finetune.py \
tests/functional_tests/recipes/test_qwen3_vl_recipes_finetune.py

coverage combine -q

Expand Down
77 changes: 77 additions & 0 deletions tests/functional_tests/recipes/test_qwen3_vl_recipes_finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functional smoke tests for Qwen3-VL finetuning recipes.

This test ensures that:
1. Qwen3-VL model forward pass works with all required parameters (including loss_mask)
2. Training loop completes without errors
3. Checkpoints are saved correctly

This catches regressions like missing parameters in the forward pass signature.

Run with:
torchrun --nproc_per_node=2 -m pytest tests/functional_tests/recipes/test_qwen3_vl_recipes_peft.py -v
"""

import pytest

from megatron.bridge.recipes.qwen_vl.qwen3_vl import qwen3_vl_8b_finetune_config
from tests.functional_tests.recipes.utils import run_pretrain_vl_recipe_test


QWEN3_VL_FINETUNE_RECIPES = [
# (config_func, recipe_name, parallelism_overrides, model_overrides)
# Qwen3-VL 8B finetune - uses TP=2 for 2-GPU CI
# Note: deepstack_visual_indexes must have len <= num_layers
(
qwen3_vl_8b_finetune_config,
"qwen3_vl_8b_finetune",
{"tensor_model_parallel_size": 2, "pipeline_model_parallel_size": 1},
{"num_layers": 4, "deepstack_visual_indexes": [0, 1, 2]},
),
]


class TestQwen3VLFinetuneRecipes:
"""Test class for Qwen3-VL finetune recipe functional tests."""

@pytest.mark.run_only_on("GPU")
@pytest.mark.parametrize(
"config_func,recipe_name,parallelism_overrides,model_overrides",
QWEN3_VL_FINETUNE_RECIPES,
)
def test_qwen3_vl_finetune_recipes(
self,
config_func,
recipe_name,
parallelism_overrides,
model_overrides,
tmp_path,
):
"""Functional test for Qwen3-VL finetune recipes.

This test runs a minimal training session to verify that:
1. The config loads correctly
2. Model forward pass accepts all required parameters (loss_mask, etc.)
3. Training completes without errors
4. Checkpoints are created
"""
run_pretrain_vl_recipe_test(
config_func,
recipe_name,
tmp_path,
model_overrides=model_overrides,
**parallelism_overrides,
)
Loading