diff --git a/src/megatron/bridge/peft/lora.py b/src/megatron/bridge/peft/lora.py index 3a8b3496fd..72000d2938 100644 --- a/src/megatron/bridge/peft/lora.py +++ b/src/megatron/bridge/peft/lora.py @@ -152,37 +152,3 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio ) return LoRALinear(module, adapter) return module - - -class LoRAMerge(PEFT): - """ - Implements the LoRA weight merge for parameter-efficient fine-tuning. - """ - - @torch.no_grad() - def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None) -> nn.Module: - """ - Merges the LoRA adapter with the base model weights. - - Args: - m (nn.Module): The module to apply LoRA merge to. - name (str, optional): Name of the module to merge. Defaults to None. - prefix (str, optional): Prefix for the module name. Defaults to None. - - Returns: - nn.Module: The modified module with the LoRA adapter merged into the base model weights. - """ - - if not isinstance(module, LoRALinear): - return module - logging.info(f"merging {(prefix if prefix else '') + '.' + (name if name else '')}") - base_weight = module.to_wrap.weight - lora_weight = ( - module.adapter.alpha - / module.adapter.dim - * module.adapter.linear_out.weight.to(base_weight.device) - @ module.adapter.linear_in.weight.to(base_weight.device) - ) - merged_weight = base_weight + lora_weight - module.to_wrap.weight.data = merged_weight - return module diff --git a/src/megatron/bridge/peft/lora_merge.py b/src/megatron/bridge/peft/lora_merge.py new file mode 100644 index 0000000000..b5e7d9c1fe --- /dev/null +++ b/src/megatron/bridge/peft/lora_merge.py @@ -0,0 +1,337 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import replace +from typing import Optional + +import torch +import torch.nn as nn +from megatron.core.transformer.module import MegatronModule + +from megatron.bridge.peft.base import PEFT +from megatron.bridge.peft.lora_layers import LoRALinear +from megatron.bridge.training.checkpointing import save_checkpoint +from megatron.bridge.training.config import CheckpointConfig, ConfigContainer +from megatron.bridge.training.model_load_save import ( + load_megatron_model, + temporary_distributed_context, +) + + +class LoRAMerge(PEFT): + """ + Implements the LoRA weight merge for parameter-efficient fine-tuning. + """ + + @torch.no_grad() + def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None) -> nn.Module: + """ + Merges the LoRA adapter with the base model weights. + + Args: + module (nn.Module): The module to apply LoRA merge to. + name (str, optional): Name of the module to merge. Defaults to None. + prefix (str, optional): Prefix for the module name. Defaults to None. + + Returns: + nn.Module: The modified module with the LoRA adapter merged into the base model weights. + """ + if not isinstance(module, LoRALinear): + return module + + logging.info(f"merging {(prefix if prefix else '') + '.' 
+ (name if name else '')}") + lora_weight = ( + module.adapter.alpha + / module.adapter.dim + * module.adapter.linear_out.weight + @ module.adapter.linear_in.weight + ) + if hasattr(module.to_wrap, "weight"): + base_weight = module.to_wrap.weight + merged_weight = base_weight + lora_weight.to(base_weight.device) + module.to_wrap.weight.data = merged_weight + else: # TE Grouped Linear + for i in range(module.to_wrap.num_gemms): + base_weight = getattr(module.to_wrap, f"weight{i}") + merged_weight = base_weight + lora_weight.to(base_weight.device) + getattr(module.to_wrap, f"weight{i}").data = merged_weight + return module + + +def merge_lora(lora_checkpoint_path: str, output_path: str) -> None: + """ + Merge LoRA adapter weights into base model weights, preserving all metadata. + + This function loads a LoRA checkpoint, extracts the base model and adapter weights, + merges them into a single model, and saves the result as a Megatron-Bridge checkpoint + that can be used for inference or further training without PEFT. 
+ + Args: + lora_checkpoint_path: Path to LoRA checkpoint (specific iteration directory like iter_0000100) + containing run_config.yaml, adapter weights and metadata + output_path: Path to save merged checkpoint directory + + Example: + >>> from megatron.bridge.peft.lora_merge import merge_lora + >>> merge_lora("/path/to/lora_checkpoint", "/path/to/merged_checkpoint") + + Note: + - Uses CPU initialization and Gloo backend for memory efficiency + - Preserves all original model metadata and configuration + - Removes PEFT configuration from merged checkpoint + - Output checkpoint is in torch_dist format + - Automatically resolves base pretrained directories to latest iteration + """ + + def _merge_lora_internal(): + config_from_checkpoint = _load_full_config_container_from_checkpoint(lora_checkpoint_path) + + # Extract pretrained checkpoint path and PEFT config + pretrained_path = config_from_checkpoint.checkpoint.pretrained_checkpoint + peft_config = config_from_checkpoint.peft + + if not pretrained_path: + raise ValueError(f"No pretrained_checkpoint found in {lora_checkpoint_path}") + if not peft_config: + raise ValueError(f"No PEFT configuration found in {lora_checkpoint_path}") + + # Resolve pretrained path to specific iteration if it's a base directory + resolved_pretrained_path = _resolve_checkpoint_path(pretrained_path) + base_model = load_megatron_model( + checkpoint_path=resolved_pretrained_path, + use_cpu_init=True, + return_state_dict=False, + skip_temp_dist_context=True, + ) + + # Ensure base_model is a list for consistent handling + if not isinstance(base_model, list): + base_model = [base_model] + + # Apply PEFT transformation and load adapter weights + _apply_adapters_to_model_with_full_config(base_model, config_from_checkpoint) + + # Apply LoRAMerge transformation + lora_merge = LoRAMerge() + merged_model = lora_merge(base_model, training=False) + + merged_config = _prepare_merged_config(config_from_checkpoint, output_path) + + 
_save_merged_model_with_metadata(merged_model, merged_config, output_path) + + # Check if distributed is already initialized + skip_temp_context = torch.distributed.is_available() and torch.distributed.is_initialized() + + if skip_temp_context: + # Already in distributed context, but ensure Megatron parallel state is initialized + _ensure_megatron_parallel_state_initialized() + _merge_lora_internal() + else: + # Use temporary distributed context for CPU-based merging + with temporary_distributed_context(backend="gloo"): + _merge_lora_internal() + + +def _ensure_megatron_parallel_state_initialized() -> None: + """Ensure Megatron model parallel state is initialized for merge operations.""" + from megatron.core import parallel_state + + if not parallel_state.model_parallel_is_initialized(): + parallel_state.initialize_model_parallel() + + +def _load_full_config_container_from_checkpoint(lora_checkpoint_path: str) -> ConfigContainer: + """Load the complete ConfigContainer from PEFT checkpoint, preserving all metadata.""" + from megatron.bridge.training.config import ConfigContainer + from megatron.bridge.training.utils.checkpoint_utils import ( + file_exists, + get_checkpoint_run_config_filename, + read_run_config, + ) + + # Read run config from LoRA checkpoint + run_config_filename = get_checkpoint_run_config_filename(lora_checkpoint_path) + if not file_exists(run_config_filename): + raise ValueError( + f"Run config not found in {lora_checkpoint_path}. Expected Megatron-Bridge checkpoint format." 
+ ) + + run_config = read_run_config(run_config_filename) + + from megatron.bridge.utils.instantiate_utils import InstantiationMode + + config_container = ConfigContainer.from_dict(run_config, mode=InstantiationMode.LENIENT) + + # Modify checkpoint config for merging + config_container.checkpoint = _modify_checkpoint_config_for_merge( + config_container.checkpoint, lora_checkpoint_path + ) + + return config_container + + +def _modify_checkpoint_config_for_merge( + original_checkpoint_config: CheckpointConfig, lora_checkpoint_path: str +) -> CheckpointConfig: + """Modify checkpoint config for merge operation while preserving metadata.""" + + # Create a modified checkpoint config for merging + # - load from the LoRA checkpoint path (to get adapter weights) + # - don't load optimizer/rng (not needed for merging) + modified_config = replace( + original_checkpoint_config, + load=lora_checkpoint_path, # Load from LoRA checkpoint + load_optim=False, # Don't need optimizer + load_rng=False, # Don't need RNG state + ) + + return modified_config + + +def _prepare_merged_config(original_config: ConfigContainer, output_path: str) -> ConfigContainer: + """Prepare config for merged checkpoint, removing PEFT but preserving everything else.""" + + # Update checkpoint config for the merged model + merged_checkpoint_config = replace( + original_config.checkpoint, + save=output_path, # New save location + pretrained_checkpoint=None, # Merged model is now self-contained + save_optim=False, # Don't save optimizer in merged checkpoint + save_rng=False, # Don't save RNG in merged checkpoint + ckpt_format="torch_dist", + ) + + # Create merged config (same as original but without PEFT) + merged_config = replace( + original_config, + checkpoint=merged_checkpoint_config, + peft=None, # Remove PEFT - merged model doesn't need adapters anymore + ) + + return merged_config + + +def _save_merged_model_with_metadata(model: list, config: ConfigContainer, output_path: str) -> None: + """Save merged 
model using the full config to preserve all metadata.""" + from megatron.bridge.training.checkpointing import init_checkpointing_context + from megatron.bridge.training.state import GlobalState, TrainState + + # Create GlobalState with config and train state initialized + state = GlobalState() + state.cfg = config + state.train_state = TrainState() + + # Initialize checkpointing context + checkpointing_context = init_checkpointing_context(config.checkpoint) + + # Use the same save_checkpoint logic as training to preserve metadata + save_checkpoint( + state=state, + model=model, + optimizer=None, # No optimizer for merged model + opt_param_scheduler=None, # No scheduler for merged model + num_floating_point_operations_so_far=0, + checkpointing_context=checkpointing_context, + ) + + +def _apply_adapters_to_model_with_full_config(model: list[MegatronModule], full_config: ConfigContainer) -> None: + """Apply PEFT transformation and load adapter weights using full config.""" + peft_config = full_config.peft + lora_checkpoint_path = full_config.checkpoint.load + + # Apply original PEFT transformation from LoRA checkpoint config + transformed_model = peft_config(model, training=False) + + # Load only adapter weights + _load_adapter_weights_only(transformed_model, lora_checkpoint_path, peft_config) + + +def _load_adapter_weights_only(model: list[MegatronModule], lora_checkpoint_path: str, peft_config: PEFT) -> None: + """Load only adapter weights directly using distributed checkpoint loading.""" + from megatron.core import dist_checkpointing + + from megatron.bridge.training.checkpointing import ( + _generate_model_state_dict, + _load_state_dict_into_model_list, + apply_peft_adapter_filter_to_state_dict, + get_default_load_sharded_strategy, + ) + + # Ensure model is a list for consistent handling + if not isinstance(model, list): + model = [model] + + # Generate model state dict template from transformed model (with adapters) + complete_sharded_state_dict = 
_generate_model_state_dict(model) + + # Filter sharded state dict keys to only load adapter states + filtered_sharded_state_dict = apply_peft_adapter_filter_to_state_dict(complete_sharded_state_dict, peft_config) + + # Load adapter weights directly using distributed checkpoint + load_strategy = get_default_load_sharded_strategy(lora_checkpoint_path) + loaded_state_dict = dist_checkpointing.load( + filtered_sharded_state_dict, + lora_checkpoint_path, + load_strategy, + strict=dist_checkpointing.validation.StrictHandling.LOG_UNEXPECTED, + ) + + # Load adapter weights into model using shared utility + _load_state_dict_into_model_list(model, loaded_state_dict, strict=False) + + +def _resolve_checkpoint_path(checkpoint_path: str) -> str: + """ + Resolve checkpoint path to specific iteration directory. + + If checkpoint_path is a base directory, resolve to the latest iteration. + If checkpoint_path is already a specific iteration directory, return as-is. + + Args: + checkpoint_path: Either base checkpoint directory or specific iteration directory + + Returns: + Path to specific iteration directory containing weights and config + """ + import os + + from megatron.bridge.training.checkpointing import TRACKER_PREFIX + from megatron.bridge.training.utils.checkpoint_utils import ( + CONFIG_FILE, + file_exists, + get_checkpoint_name, + get_checkpoint_train_state_filename, + read_train_state, + ) + + # Check if this is already a specific iteration directory (contains run config) + run_config_file = os.path.join(checkpoint_path, CONFIG_FILE) + if file_exists(run_config_file): + # This is already a specific iteration directory + return checkpoint_path + + tracker_filename = get_checkpoint_train_state_filename(checkpoint_path, prefix=TRACKER_PREFIX) + if not file_exists(tracker_filename): + raise ValueError( + f"Cannot resolve checkpoint path {checkpoint_path}. 
" + f"Expected either a specific iteration directory (containing run_config.yaml) " + f"or a base directory with tracker file (latest_train_state.pt)" + ) + + train_state = read_train_state(tracker_filename) + iteration = train_state.step + resolved_path = get_checkpoint_name(checkpoint_path, iteration, release=False) + return resolved_path diff --git a/src/megatron/bridge/training/checkpointing.py b/src/megatron/bridge/training/checkpointing.py index 92a3a0aa61..0adc74a752 100644 --- a/src/megatron/bridge/training/checkpointing.py +++ b/src/megatron/bridge/training/checkpointing.py @@ -87,7 +87,7 @@ except Exception: has_nvidia_modelopt = False -TRACKER_PREFIX = "latest" +TRACKER_PREFIX: str = "latest" _CHECKPOINT_VERSION = None logger = getLogger(__name__) @@ -928,6 +928,20 @@ def _load_model_state_dict(module: torch.nn.Module, state_dict: dict[str, Any], print(f"load_return: {load_return}") +def _load_state_dict_into_model_list(model: list[MegatronModule], state_dict: dict[str, Any], strict: bool) -> None: + """Load state dict into a list of model modules, handling both single and pipeline parallel cases.""" + if len(model) == 1: + _load_model_state_dict(model[0], state_dict["model"], strict) + else: + for i in range(len(model)): + # If there is no corresponding model in the state_dict, it will be ignored. + # It means that this is an empty stage. + model_key = "model%d" % i + if model_key not in state_dict: + continue + _load_model_state_dict(model[i], state_dict[model_key], strict) + + def _load_checkpoint_from_path( load_dir: str, state: GlobalState, @@ -1143,16 +1157,7 @@ def _load_checkpoint_from_path( # Model. if not skip_load_to_model_and_opt: load_strict = False if is_peft_resume else strict - if len(model) == 1: - _load_model_state_dict(model[0], state_dict["model"], load_strict) - else: - for i in range(len(model)): - # If there is no corresponding model in the state_dict, it will be ignored. - # It means that this is an empty stage. 
- model_key = "model%d" % i - if model_key not in state_dict: - continue - _load_model_state_dict(model[i], state_dict[model_key], load_strict) + _load_state_dict_into_model_list(model, state_dict, strict=load_strict) # Fix up query/key/value matrix ordering if needed. checkpoint_version = get_checkpoint_version() diff --git a/src/megatron/bridge/training/model_load_save.py b/src/megatron/bridge/training/model_load_save.py index 807e5a2eee..7abe4e4438 100644 --- a/src/megatron/bridge/training/model_load_save.py +++ b/src/megatron/bridge/training/model_load_save.py @@ -244,6 +244,9 @@ def build_and_load_model( def _call_model_provider(model_cfg): """Handles provider call for both MBridge and MLM providers.""" if isinstance(model_cfg, ModelProviderMixin): + if mbridge_ckpt: + # Skip initializing the weights as they will be overwritten when loading the checkpoint + model_cfg.perform_initialization = False return model_cfg.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=use_cpu_init) else: assert model_type in ("gpt", "mamba"), f"model type {model_type} not supported." 
diff --git a/src/megatron/bridge/training/utils/checkpoint_utils.py b/src/megatron/bridge/training/utils/checkpoint_utils.py index 6dcfe87598..f50cf512c4 100644 --- a/src/megatron/bridge/training/utils/checkpoint_utils.py +++ b/src/megatron/bridge/training/utils/checkpoint_utils.py @@ -26,9 +26,9 @@ from megatron.bridge.utils.common_utils import get_rank_safe, get_world_size_safe, print_rank_0 -TRAIN_STATE_FILE = "train_state.pt" -TRACKER_PREFIX = "latest" -CONFIG_FILE = "run_config.yaml" +TRAIN_STATE_FILE: str = "train_state.pt" +TRACKER_PREFIX: str = "latest" +CONFIG_FILE: str = "run_config.yaml" logger = logging.getLogger(__name__) diff --git a/src/megatron/bridge/utils/instantiate_utils.py b/src/megatron/bridge/utils/instantiate_utils.py index 75efde252e..f8d8a65f10 100644 --- a/src/megatron/bridge/utils/instantiate_utils.py +++ b/src/megatron/bridge/utils/instantiate_utils.py @@ -331,6 +331,10 @@ def _call_target( ) -> Any: """Call target (type) with args and kwargs.""" args, kwargs = _extract_pos_args(args, kwargs) + + # Filter out init=False fields for dataclass constructors + kwargs = _filter_init_false_fields(_target_, kwargs) + if _partial_: try: return functools.partial(_target_, *args, **kwargs) @@ -349,6 +353,23 @@ def _call_target( raise InstantiationException(msg) from e +def _filter_init_false_fields(target: Callable[..., Any], kwargs: dict[str, Any]) -> dict[str, Any]: + """Filter out fields marked with init=False from kwargs for dataclass constructors.""" + from dataclasses import fields, is_dataclass + + # Only filter for dataclasses + if not is_dataclass(target): + return kwargs + + # Get fields that should not be passed to constructor (init=False) + init_false_fields = {field.name for field in fields(target) if not field.init} + + # Filter out init=False fields from kwargs + filtered_kwargs = {key: value for key, value in kwargs.items() if key not in init_false_fields} + + return filtered_kwargs + + def _convert_target_to_string(t: Any) -> Any: 
if callable(t): return f"{t.__module__}.{t.__qualname__}" diff --git a/tests/functional_tests/training/test_lora_merge.py b/tests/functional_tests/training/test_lora_merge.py new file mode 100644 index 0000000000..0f8ad69e33 --- /dev/null +++ b/tests/functional_tests/training/test_lora_merge.py @@ -0,0 +1,401 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass + +import pytest +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.optimizer import OptimizerConfig + +from megatron.bridge.data.builders.hf_dataset import HFDatasetConfig +from megatron.bridge.data.hf_processors.squad import process_squad_example +from megatron.bridge.models.llama import Llama3ModelProvider +from megatron.bridge.peft.lora import LoRA +from megatron.bridge.peft.lora_merge import merge_lora +from megatron.bridge.training.checkpointing import ( + get_checkpoint_run_config_filename, + read_run_config, +) +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + LoggerConfig, + MockGPTDatasetConfig, + RNGConfig, + SchedulerConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.finetune import finetune +from megatron.bridge.training.gpt_step import forward_step +from megatron.bridge.training.initialize import destroy_global_state +from megatron.bridge.training.model_load_save import 
load_megatron_model +from megatron.bridge.training.pretrain import pretrain +from megatron.bridge.training.utils.checkpoint_utils import file_exists +from megatron.bridge.utils.instantiate_utils import instantiate +from tests.functional_tests.utils import ( + broadcast_path, + clear_directories, + get_directory_size, + initialize_distributed, + verify_checkpoint_files, +) + + +@dataclass +class Llama3ModelProvider145M(Llama3ModelProvider): + """Smaller Llama3 model for testing.""" + + rotary_base: int = 500_000 + num_layers: int = 2 + hidden_size: int = 768 + ffn_hidden_size: int = 2688 + num_attention_heads: int = 16 + make_vocab_size_divisible_by: int = 128 + + +class TestLoRAMerge: + """ + Test LoRA merge functionality with comprehensive end-to-end verification. + + Tests the complete pipeline: pretrain -> LoRA finetune -> merge -> verify + including checkpoint loading, config preservation, and size validation. + """ + + @pytest.mark.run_only_on("GPU") + def test_lora_merge_end_to_end(self, tmp_path): + """Test complete LoRA merge pipeline: pretrain -> finetune -> merge -> verify all aspects.""" + initialize_distributed() + shared_base_dir = broadcast_path(tmp_path) + + ( + pretrain_checkpoint_dir, + pretrain_tensorboard_dir, + lora_checkpoint_dir, + lora_tensorboard_dir, + merged_checkpoint_dir, + ) = self._setup_directories(shared_base_dir) + + torch.distributed.barrier() + + try: + seq_length = 512 + pretrain_iters = 10 + lora_iters = 5 + + # Create pretrain config and run + pretrain_cfg = self._create_pretrain_config( + pretrain_iters, pretrain_checkpoint_dir, pretrain_tensorboard_dir, seq_length + ) + pretrain(pretrain_cfg, forward_step) + verify_checkpoint_files(pretrain_checkpoint_dir, pretrain_iters) + + # Create LoRA config and run finetuning + lora_cfg = self._create_lora_config( + lora_iters, lora_checkpoint_dir, lora_tensorboard_dir, pretrain_checkpoint_dir, seq_length + ) + finetune(lora_cfg, forward_step) + 
verify_checkpoint_files(lora_checkpoint_dir, lora_iters) + + # Merge LoRA checkpoint + lora_final_checkpoint = os.path.join(lora_checkpoint_dir, f"iter_{lora_iters:07d}") + merge_lora(lora_final_checkpoint, merged_checkpoint_dir) + + pretrain_final_checkpoint = os.path.join(pretrain_checkpoint_dir, f"iter_{pretrain_iters:07d}") + + # Merged checkpoint is saved as iteration 0 + merged_final_checkpoint = os.path.join(merged_checkpoint_dir, "iter_0000000") + + self._verify_merged_checkpoint_loading(merged_final_checkpoint) + self._verify_config_preservation(pretrain_final_checkpoint, lora_final_checkpoint, merged_final_checkpoint) + self._verify_checkpoint_sizes(pretrain_final_checkpoint, lora_final_checkpoint, merged_final_checkpoint) + + finally: + clear_directories(shared_base_dir) + destroy_global_state() + + def _verify_merged_checkpoint_loading(self, merged_checkpoint_dir: str) -> None: + """Verify that the merged checkpoint can be loaded successfully.""" + merged_model = load_megatron_model( + checkpoint_path=merged_checkpoint_dir, use_cpu_init=True, return_state_dict=False + ) + assert merged_model is not None, "Failed to load merged checkpoint" + + def _verify_config_preservation( + self, original_pretrain_dir: str, lora_checkpoint_dir: str, merged_checkpoint_dir: str + ) -> None: + """Verify that all configuration is preserved except PEFT and model architecture matches original.""" + + # Load all configs + lora_run_config = read_run_config(get_checkpoint_run_config_filename(lora_checkpoint_dir)) + merged_run_config = read_run_config(get_checkpoint_run_config_filename(merged_checkpoint_dir)) + + # Merged checkpoint config has no PEFT + assert file_exists(get_checkpoint_run_config_filename(merged_checkpoint_dir)), ( + "Merged checkpoint missing run_config.yaml" + ) + assert "peft" not in merged_run_config or merged_run_config["peft"] is None, ( + "Merged checkpoint should not contain PEFT configuration" + ) + + # All non-PEFT, non-checkpoint configs should be 
preserved from LoRA checkpoint + preserve_configs = ( + "model", + "train", + "optimizer", + "scheduler", + "dataset", + "logger", + "tokenizer", + "ddp", + "dist", + "rng", + ) + + for config_key in preserve_configs: + if config_key in lora_run_config: + assert config_key in merged_run_config, f"Missing preserved config: {config_key}" + # For complex comparison, just check they can be instantiated the same way + try: + lora_cfg_obj = instantiate(lora_run_config[config_key]) + merged_cfg_obj = instantiate(merged_run_config[config_key]) + assert type(lora_cfg_obj) == type(merged_cfg_obj), f"Config type mismatch for {config_key}" + except Exception as e: + pytest.fail(f"Failed to instantiate preserved config {config_key}: {e}") + + # Checkpoint config should be updated appropriately + merged_ckpt_cfg = merged_run_config["checkpoint"] + assert merged_ckpt_cfg.get("pretrained_checkpoint") is None, ( + "pretrained_checkpoint should be None in merged config" + ) + + def _verify_checkpoint_sizes( + self, original_pretrain_dir: str, lora_checkpoint_dir: str, merged_checkpoint_dir: str + ) -> None: + """Verify checkpoint sizes are as expected after merge.""" + + # Get checkpoint sizes + pretrain_size = get_directory_size(original_pretrain_dir) + lora_size = get_directory_size(lora_checkpoint_dir) + merged_size = get_directory_size(merged_checkpoint_dir) + + # LoRA checkpoint should be significantly smaller than pretrain (adapters vs full model) + lora_to_pretrain_ratio = lora_size / pretrain_size + assert lora_to_pretrain_ratio < 0.2, ( + f"LoRA checkpoint should be <20% of pretrain size, got {lora_to_pretrain_ratio:.1%}" + ) + + # Merged checkpoint should be similar size to pretrain (both contain only model weights) + size_ratio = merged_size / pretrain_size + assert 0.95 <= size_ratio <= 1.05, ( + f"Merged checkpoint should be 95-105% of pretrain size (both model weights only), got {size_ratio:.1%}" + ) + + def _create_model_provider(self, seq_length=512, 
tensor_parallel_size=1, pipeline_parallel_size=1): + """Create a model provider with specified configuration.""" + return Llama3ModelProvider145M( + seq_length=seq_length, + tensor_model_parallel_size=tensor_parallel_size, + pipeline_model_parallel_size=pipeline_parallel_size, + pipeline_dtype=torch.bfloat16, + sequence_parallel=(tensor_parallel_size > 1), + ) + + def _create_training_config(self, train_iters, global_batch_size=8, micro_batch_size=1): + """Create a training configuration.""" + return TrainingConfig( + train_iters=train_iters, + eval_interval=5, + eval_iters=0, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + ) + + def _create_optimizer_config(self, lr=3e-3): + """Create an optimizer configuration.""" + return OptimizerConfig( + optimizer="adam", + bf16=True, + fp16=False, + adam_beta1=0.9, + adam_beta2=0.95, + adam_eps=1e-5, + use_distributed_optimizer=True, + clip_grad=1.0, + lr=lr, + weight_decay=0.01, + min_lr=1e-6 if lr > 1e-4 else 1e-7, + ) + + def _create_scheduler_config(self, total_iters): + """Create a scheduler configuration.""" + return SchedulerConfig( + start_weight_decay=0.033, + end_weight_decay=0.033, + weight_decay_incr_style="constant", + lr_decay_style="cosine", + lr_warmup_iters=2 if total_iters >= 10 else 1, + lr_warmup_init=0.0, + lr_decay_iters=total_iters, + ) + + def _create_ddp_config(self): + """Create a DDP configuration.""" + return DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + use_distributed_optimizer=True, + ) + + def _create_mock_dataset_config(self, seq_length, seed=1234): + """Create a mock dataset configuration.""" + return MockGPTDatasetConfig( + random_seed=seed, + reset_attention_mask=False, + reset_position_ids=False, + eod_mask_loss=False, + sequence_length=seq_length, + num_dataset_builder_threads=1, + data_sharding=True, + dataloader_type="single", 
+ num_workers=1, + ) + + def _create_squad_dataset_config(self, seq_length, seed=5678): + """Create a SQuAD dataset configuration.""" + return HFDatasetConfig( + dataset_name="squad", + process_example_fn=process_squad_example, + seq_length=seq_length, + seed=seed, + dataloader_type="single", + num_workers=1, + do_validation=False, + do_test=False, + val_proportion=None, + rewrite=False, + ) + + def _create_pretrain_tokenizer_config(self): + """Create a tokenizer configuration for pretraining.""" + return TokenizerConfig( + tokenizer_type="NullTokenizer", + vocab_size=9999, + ) + + def _create_finetune_tokenizer_config(self): + """Create a tokenizer configuration for finetuning.""" + return TokenizerConfig( + tokenizer_type="HuggingFaceTokenizer", + tokenizer_model="gpt2", + ) + + def _create_logger_config(self, tensorboard_dir): + """Create a logger configuration.""" + return LoggerConfig( + log_interval=5, + tensorboard_dir=tensorboard_dir, + ) + + def _create_checkpoint_config(self, save_interval, save_dir, pretrained_checkpoint=None, load_dir=None): + """Create a checkpoint configuration.""" + return CheckpointConfig( + save_interval=save_interval, + save=save_dir, + pretrained_checkpoint=pretrained_checkpoint, + load=load_dir, + ckpt_format="torch_dist", + fully_parallel_save=True, + save_optim=False, + save_rng=False, + ) + + def _create_rng_config(self, seed=1234): + """Create an RNG configuration.""" + return RNGConfig(seed=seed) + + def _create_lora_peft(self, dim=16, alpha=32, dropout=0.1): + """Create a LoRA PEFT configuration.""" + return LoRA( + target_modules=["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"], + dim=dim, + alpha=alpha, + dropout=dropout, + ) + + def _create_pretrain_config(self, train_iters, checkpoint_dir, tensorboard_dir, seq_length=512): + """Create complete pretrain configuration with model.""" + model = self._create_model_provider(seq_length) + + return ConfigContainer( + model=model, + 
train=self._create_training_config(train_iters), + optimizer=self._create_optimizer_config(), + scheduler=self._create_scheduler_config(train_iters), + ddp=self._create_ddp_config(), + dataset=self._create_mock_dataset_config(seq_length), + logger=self._create_logger_config(tensorboard_dir), + tokenizer=self._create_pretrain_tokenizer_config(), + checkpoint=self._create_checkpoint_config(train_iters, checkpoint_dir), + rng=self._create_rng_config(), + ) + + def _create_lora_config( + self, train_iters, checkpoint_dir, tensorboard_dir, pretrained_checkpoint_dir, seq_length=512 + ): + """Create complete LoRA finetuning configuration with model and PEFT.""" + model = self._create_model_provider(seq_length) + lora_peft = self._create_lora_peft() + + return ConfigContainer( + model=model, + train=self._create_training_config(train_iters), + optimizer=self._create_optimizer_config(lr=1e-4), # Lower LR for finetuning + scheduler=self._create_scheduler_config(train_iters), + ddp=self._create_ddp_config(), + dataset=self._create_squad_dataset_config(seq_length), + logger=self._create_logger_config(tensorboard_dir), + tokenizer=self._create_finetune_tokenizer_config(), + checkpoint=self._create_checkpoint_config(train_iters, checkpoint_dir, pretrained_checkpoint_dir), + rng=self._create_rng_config(seed=5678), + peft=lora_peft, + ) + + def _setup_directories(self, base_dir, suffix=""): + """Setup test directories.""" + pretrain_checkpoint_dir = os.path.join(base_dir, f"pretrain_checkpoints{suffix}") + pretrain_tensorboard_dir = os.path.join(base_dir, f"pretrain_tensorboard{suffix}") + lora_checkpoint_dir = os.path.join(base_dir, f"lora_checkpoints{suffix}") + lora_tensorboard_dir = os.path.join(base_dir, f"lora_tensorboard{suffix}") + merged_checkpoint_dir = os.path.join(base_dir, f"merged_checkpoint{suffix}") + + if torch.distributed.get_rank() == 0: + os.makedirs(pretrain_checkpoint_dir, exist_ok=True) + os.makedirs(pretrain_tensorboard_dir, exist_ok=True) + 
+            os.makedirs(lora_checkpoint_dir, exist_ok=True)
+            os.makedirs(lora_tensorboard_dir, exist_ok=True)
+            os.makedirs(merged_checkpoint_dir, exist_ok=True)
+
+        return (
+            pretrain_checkpoint_dir,
+            pretrain_tensorboard_dir,
+            lora_checkpoint_dir,
+            lora_tensorboard_dir,
+            merged_checkpoint_dir,
+        )
diff --git a/tests/unit_tests/peft/test_lora.py b/tests/unit_tests/peft/test_lora.py
index 4c5c3e3652..01ff1b73c8 100644
--- a/tests/unit_tests/peft/test_lora.py
+++ b/tests/unit_tests/peft/test_lora.py
@@ -24,8 +24,9 @@
 from megatron.core.transformer.module import MegatronModule
 
 from megatron.bridge.models.gpt_provider import GPTModelProvider
-from megatron.bridge.peft.lora import LoRA, LoRAMerge
+from megatron.bridge.peft.lora import LoRA
 from megatron.bridge.peft.lora_layers import LinearAdapter, LoRALinear
+from megatron.bridge.peft.lora_merge import LoRAMerge
 
 
 class SimpleModel(nn.Module):
diff --git a/tests/unit_tests/peft/test_lora_merge.py b/tests/unit_tests/peft/test_lora_merge.py
new file mode 100644
index 0000000000..7ec98d5de4
--- /dev/null
+++ b/tests/unit_tests/peft/test_lora_merge.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import unittest.mock as mock + +import pytest + +from megatron.bridge.peft.lora_merge import ( + _load_full_config_container_from_checkpoint, + _modify_checkpoint_config_for_merge, + _prepare_merged_config, +) +from megatron.bridge.training.config import CheckpointConfig, ConfigContainer + + +class TestLoRAMerge: + """Unit tests for LoRA merge functionality.""" + + def test_modify_checkpoint_config_for_merge(self): + """Test checkpoint config modification for merge operation.""" + original_config = CheckpointConfig( + save="/original/save/path", + load="/original/load/path", + pretrained_checkpoint="/pretrained/path", + finetune=True, + load_optim=True, + load_rng=True, + ckpt_format="torch_dist", + async_save=True, + ) + + lora_checkpoint_path = "/path/to/lora/checkpoint" + + modified_config = _modify_checkpoint_config_for_merge(original_config, lora_checkpoint_path) + + assert modified_config.load_optim is False + assert modified_config.load_rng is False + + def test_prepare_merged_config(self): + """Test preparation of config for merged checkpoint.""" + # Create a mock original config + original_checkpoint_config = CheckpointConfig( + save="/original/save", + load="/lora/checkpoint", + pretrained_checkpoint="/base/model", + finetune=False, + load_optim=False, + load_rng=False, + ) + + # Create a simple mock config container + original_config = mock.MagicMock(spec=ConfigContainer) + original_config.checkpoint = original_checkpoint_config + + output_path = "/merged/output" + + with mock.patch("megatron.bridge.peft.lora_merge.replace") as mock_replace: + # Mock the replace calls + mock_replace.side_effect = lambda obj, **kwargs: mock.MagicMock( + **{**{attr: getattr(obj, attr) for attr in dir(obj) if not attr.startswith("_")}, **kwargs} + ) + + _ = _prepare_merged_config(original_config, output_path) + + # Verify replace was called correctly for checkpoint config + checkpoint_calls = [ + call + for call in mock_replace.call_args_list + if "save" in call.kwargs 
or "pretrained_checkpoint" in call.kwargs + ] + assert len(checkpoint_calls) >= 1 + + # Check that expected modifications were requested + found_checkpoint_update = False + for call in checkpoint_calls: + if call.kwargs.get("save") == output_path: + found_checkpoint_update = True + assert call.kwargs.get("pretrained_checkpoint") is None + assert call.kwargs.get("ckpt_format") == "torch_dist" + + assert found_checkpoint_update, "Checkpoint config not updated correctly" + + @mock.patch("megatron.bridge.training.utils.checkpoint_utils.get_checkpoint_run_config_filename") + @mock.patch("megatron.bridge.training.utils.checkpoint_utils.file_exists") + @mock.patch("megatron.bridge.training.utils.checkpoint_utils.read_run_config") + @mock.patch("megatron.bridge.training.config.ConfigContainer.from_dict") + def test_load_full_config_container_from_checkpoint( + self, mock_from_dict, mock_read_config, mock_file_exists, mock_get_filename + ): + """Test loading full config container from checkpoint.""" + # Setup mocks + lora_checkpoint_path = "/path/to/lora/checkpoint" + config_filename = "/path/to/lora/checkpoint/run_config.yaml" + + mock_get_filename.return_value = config_filename + mock_file_exists.return_value = True + + # Mock run config data + mock_run_config = { + "_target_": "megatron.bridge.training.config.ConfigContainer", + "train": {"_target_": "TrainingConfig", "train_iters": 100}, + "model": {"_target_": "GPTModelProvider", "num_layers": 12}, + "checkpoint": {"_target_": "CheckpointConfig", "save_interval": 100}, + "peft": {"_target_": "LoRA", "dim": 16}, + } + mock_read_config.return_value = mock_run_config + + # Mock ConfigContainer.from_dict to return a mock object with real CheckpointConfig + mock_config_container = mock.MagicMock() + mock_config_container.checkpoint = CheckpointConfig( + save="/original/save", + load="/original/load", + save_optim=True, + load_optim=True, + load_rng=True, + ) + mock_from_dict.return_value = mock_config_container + + # Call 
the function + result = _load_full_config_container_from_checkpoint(lora_checkpoint_path) + + # Verify file operations + mock_get_filename.assert_called_once_with(lora_checkpoint_path) + mock_file_exists.assert_called_once_with(config_filename) + mock_read_config.assert_called_once_with(config_filename) + + # Verify ConfigContainer.from_dict was called with LENIENT mode + mock_from_dict.assert_called_once() + + # Verify result is what we expect + assert result is mock_config_container + + @mock.patch("megatron.bridge.training.utils.checkpoint_utils.file_exists") + def test_load_config_missing_file(self, mock_file_exists): + """Test error handling when run_config.yaml is missing.""" + mock_file_exists.return_value = False + + with pytest.raises(ValueError, match="Run config not found"): + _load_full_config_container_from_checkpoint("/nonexistent/checkpoint")