209 changes: 209 additions & 0 deletions src/megatron/bridge/recipes/common.py
@@ -16,9 +16,11 @@

from megatron.core.distributed import DistributedDataParallelConfig

from megatron.bridge.data.vlm_datasets.hf_provider import HFDatasetConversationProvider
from megatron.bridge.peft.lora import LoRA
from megatron.bridge.recipes.utils.finetune_utils import default_squad_config
from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
from megatron.bridge.recipes.utils.tokenizer_utils import DEFAULT_NULL_TOKENIZER_VOCAB_SIZE
from megatron.bridge.training.config import (
CheckpointConfig,
ConfigContainer,
@@ -335,3 +337,210 @@ def _peft_common() -> ConfigContainer:
)

return cfg


def _sft_common_vlm() -> ConfigContainer:
"""Create a base SFT ConfigContainer with common defaults for Vision-Language Models.

    This function builds on `_sft_common()` and overrides VLM-specific settings.
The caller MUST set `cfg.model` and `cfg.dataset.hf_processor_path` before use.

Key differences from LLM SFT (`_sft_common`):
- Uses HFDatasetConversationProvider with HuggingFace datasets (e.g., CORD-v2)
- Uses NullTokenizer (VLMs use processor instead of tokenizer)
- DDP config optimized for VLM training (no grad/param overlap)
- Supports freeze options for language_model, vision_model, vision_projection
- Different training defaults (train_iters=300000, GBS=32, MBS=2)
- Different RNG seed (1234)

Returns:
ConfigContainer: Base configuration template for VLM full SFT.
"""
# Start from the LLM SFT common config
cfg = _sft_common()

# Default output directories
base_output_dir = os.path.join(os.getcwd(), "nemo_experiments")
run_output_dir = os.path.join(base_output_dir, "default")
checkpoint_dir = os.path.join(run_output_dir, "checkpoints")
tensorboard_dir = os.path.join(run_output_dir, "tb_logs")

# Default sequence length for VLM
seq_length = 4096

# VLM-specific training config - longer training with different batch sizes
cfg.train.train_iters = 300000
cfg.train.global_batch_size = 32
cfg.train.micro_batch_size = 2
cfg.train.manual_gc = True
cfg.train.manual_gc_interval = 100
cfg.train.manual_gc_eval = 100

# VLM-specific validation config
cfg.validation.eval_interval = 500
cfg.validation.eval_iters = 32

# VLM-specific optimizer settings - higher LR for VLM training
opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing(
lr_warmup_iters=500,
        lr_decay_iters=None,  # None defaults to train_iters during config validation
max_lr=3e-4,
min_lr=3e-5,
)
cfg.optimizer = opt_cfg
cfg.scheduler = scheduler_cfg

# VLM-specific DDP config - no overlap for VLMs
cfg.ddp = DistributedDataParallelConfig(
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
overlap_grad_reduce=False,
overlap_param_gather=False,
average_in_collective=True,
data_parallel_sharding_strategy="optim_grads_params",
use_distributed_optimizer=True,
)

# VLM-specific dataset - uses HuggingFace dataset provider
# hf_processor_path must be set by model-specific config
cfg.dataset = HFDatasetConversationProvider(
seq_length=seq_length,
hf_processor_path=None, # Must be set by model-specific config
maker_name="make_cord_v2_dataset",
num_workers=2,
dataloader_type="single",
data_sharding=True,
pin_memory=True,
persistent_workers=False,
pack_sequences_in_batch=True,
)

# VLM uses NullTokenizer - actual tokenization is handled by the processor
cfg.tokenizer = TokenizerConfig(
tokenizer_type="NullTokenizer",
vocab_size=DEFAULT_NULL_TOKENIZER_VOCAB_SIZE,
)

# VLM-specific logger config
cfg.logger = LoggerConfig(
log_interval=10,
tensorboard_dir=tensorboard_dir,
log_timers_to_tensorboard=True,
)

# VLM-specific checkpoint config
cfg.checkpoint.save_interval = 500
cfg.checkpoint.save = checkpoint_dir
cfg.checkpoint.load = checkpoint_dir
cfg.checkpoint.ckpt_format = "torch_dist"
cfg.checkpoint.fully_parallel_save = True

# VLM uses different RNG seed
cfg.rng = RNGConfig(seed=1234)

return cfg
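

# Illustrative usage (editor's sketch, not part of this change): a model-specific
# recipe is expected to start from _sft_common_vlm() and fill in the two fields the
# docstring requires before the config is used. The parameter names below are
# placeholders for a concrete model provider and HF processor id, not APIs defined
# in this repository.
def _example_vlm_sft_config(model_provider, hf_processor_path: str) -> ConfigContainer:
    """Sketch of a model-specific VLM SFT recipe built on the common template."""
    cfg = _sft_common_vlm()
    cfg.model = model_provider  # e.g. a Gemma3-VL model provider instance
    cfg.dataset.hf_processor_path = hf_processor_path  # HF processor matching the model
    return cfg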


def _peft_common_vlm() -> ConfigContainer:
"""Create a base PEFT ConfigContainer with LoRA defaults for Vision-Language Models.

    This function builds on `_peft_common()` and overrides VLM-specific settings.
The caller MUST set `cfg.model` and `cfg.dataset.hf_processor_path` before use.

Key differences from LLM PEFT (`_peft_common`):
- Uses HFDatasetConversationProvider with HuggingFace datasets (e.g., CORD-v2)
- Uses NullTokenizer (VLMs use processor instead of tokenizer)
- DDP config optimized for VLM training (no grad/param overlap)
- Supports freeze options for language_model, vision_model, vision_projection
- Different training defaults (train_iters=300000, GBS=32, MBS=2)
- Different RNG seed (1234)
- Higher LR (1e-4) for adapter training

Returns:
ConfigContainer: Base configuration template for VLM PEFT with LoRA.
"""
# Start from the LLM PEFT common config
cfg = _peft_common()

# Default output directories
base_output_dir = os.path.join(os.getcwd(), "nemo_experiments")
run_output_dir = os.path.join(base_output_dir, "default")
checkpoint_dir = os.path.join(run_output_dir, "checkpoints")
tensorboard_dir = os.path.join(run_output_dir, "tb_logs")

# Default sequence length for VLM
seq_length = 4096

# VLM-specific training config - longer training with different batch sizes
cfg.train.train_iters = 300000
cfg.train.global_batch_size = 32
cfg.train.micro_batch_size = 2
cfg.train.manual_gc = True
cfg.train.manual_gc_interval = 100
cfg.train.manual_gc_eval = 100

# VLM-specific validation config
cfg.validation.eval_interval = 500
cfg.validation.eval_iters = 32

# VLM-specific optimizer settings - higher LR for PEFT
opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing(
lr_warmup_iters=500,
        lr_decay_iters=None,  # None defaults to train_iters during config validation
max_lr=1e-4, # Higher LR for adapter training
min_lr=1e-5,
)
cfg.optimizer = opt_cfg
cfg.scheduler = scheduler_cfg

# VLM-specific DDP config - no overlap for VLMs
cfg.ddp = DistributedDataParallelConfig(
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
overlap_grad_reduce=False,
overlap_param_gather=False,
average_in_collective=True,
data_parallel_sharding_strategy="optim_grads_params",
use_distributed_optimizer=True,
)

# VLM-specific dataset - uses HuggingFace dataset provider
# hf_processor_path must be set by model-specific config
cfg.dataset = HFDatasetConversationProvider(
seq_length=seq_length,
hf_processor_path=None, # Must be set by model-specific config
maker_name="make_cord_v2_dataset",
num_workers=2,
dataloader_type="single",
data_sharding=True,
pin_memory=True,
persistent_workers=False,
pack_sequences_in_batch=True,
)

# VLM uses NullTokenizer - actual tokenization is handled by the processor
cfg.tokenizer = TokenizerConfig(
tokenizer_type="NullTokenizer",
vocab_size=DEFAULT_NULL_TOKENIZER_VOCAB_SIZE,
)

# VLM-specific logger config
cfg.logger = LoggerConfig(
log_interval=10,
tensorboard_dir=tensorboard_dir,
log_timers_to_tensorboard=True,
)

# VLM-specific checkpoint config
cfg.checkpoint.save_interval = 500
cfg.checkpoint.save = checkpoint_dir
cfg.checkpoint.load = checkpoint_dir
cfg.checkpoint.ckpt_format = "torch_dist"
cfg.checkpoint.fully_parallel_save = True

# VLM uses different RNG seed
cfg.rng = RNGConfig(seed=1234)

# Keep LoRA config from _peft_common() - it's already set with standard defaults

return cfg
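

# Editor's sketch for the PEFT path (not part of this change): start from
# _peft_common_vlm(), set the same two required fields, and optionally shorten the
# long default schedule; the LoRA settings inherited from _peft_common() are kept.
def _example_vlm_peft_config(model_provider, hf_processor_path: str) -> ConfigContainer:
    """Sketch of a model-specific VLM LoRA recipe built on the common template."""
    cfg = _peft_common_vlm()
    cfg.model = model_provider
    cfg.dataset.hf_processor_path = hf_processor_path
    cfg.train.train_iters = 1000  # e.g. shorten the 300k-iteration default for an adapter run
    return cfg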
18 changes: 12 additions & 6 deletions src/megatron/bridge/recipes/gemma3_vl/__init__.py
@@ -13,14 +13,20 @@
# limitations under the License.

from megatron.bridge.recipes.gemma3_vl.gemma3_vl import (
gemma3_vl_4b_finetune_config,
gemma3_vl_12b_finetune_config,
gemma3_vl_27b_finetune_config,
gemma3_vl_4b_peft_config,
gemma3_vl_4b_sft_config,
gemma3_vl_12b_peft_config,
gemma3_vl_12b_sft_config,
gemma3_vl_27b_peft_config,
gemma3_vl_27b_sft_config,
)


__all__ = [
"gemma3_vl_4b_finetune_config",
"gemma3_vl_12b_finetune_config",
"gemma3_vl_27b_finetune_config",
"gemma3_vl_4b_sft_config",
"gemma3_vl_12b_sft_config",
"gemma3_vl_27b_sft_config",
"gemma3_vl_4b_peft_config",
"gemma3_vl_12b_peft_config",
"gemma3_vl_27b_peft_config",
]