Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/e2e_ppo_trainer_megatron_vllm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ jobs:
- name: clean up and install Megatron-Bridge
run: |
rm -rf checkpoints
pip3 install git+https://github.com/NVIDIA-NeMo/Megatron-Bridge.git@6259ae8 --no-deps --no-build-isolation
pip3 install git+https://github.com/NVIDIA/Megatron-LM.git@7ca9dc5 --no-deps --no-build-isolation
pip3 install "nvidia-modelopt[torch]>=0.37.0" transformers==4.57.1
- name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron, use Megatron-Bridge LoRA e2e to pre-load and save (Deepseek)
run: |
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/e2e_ppo_trainer_megatron_vllm_2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ jobs:
run: |
pip3 install -r requirements-test.txt
pip3 install --no-deps --force-reinstall .
pip3 install megatron-bridge --no-deps
pip3 install git+https://github.com/NVIDIA-NeMo/Megatron-Bridge.git@6259ae8 --no-deps --no-build-isolation
pip3 install git+https://github.com/NVIDIA/Megatron-LM.git@7ca9dc5 --no-deps --no-build-isolation
pip3 install "nvidia-modelopt[torch]>=0.37.0" transformers==4.57.1
- name: Prepare GSM8K dataset
run: |
Expand Down
2 changes: 1 addition & 1 deletion docs/advance/ppo_lora.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Megatron Backend Usage Guide

You need to install and enable Megatron-Bridge for Megatron LoRA support.

Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `this commit <https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/83a7c1134c562d8c6decd10a1f0a6e6a7a8a3a44>`_ or later for proper support, and use the following settings to enable Megatron-Bridge:
Make sure you use a Megatron-Bridge version newer than 0.2.0. We recommend using `this commit <https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/6259ae83c735c4412796fc5cfb4c9607b949ae29>`_ or later for proper support, and use the following settings to enable Megatron-Bridge:

- ``actor_rollout_ref.actor.megatron.use_mbridge=True``
- ``actor_rollout_ref.actor.megatron.vanilla_mbridge=False``
Expand Down
2 changes: 1 addition & 1 deletion examples/grpo_trainer/run_qwen2-7b_math_megatron_lora.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ set -xeuo pipefail

# Need to install Megatron-Bridge
# NOTE: Make sure you use Megatron-Bridge later than 0.2.0
# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/83a7c1134c562d8c6decd10a1f0a6e6a7a8a3a44 or later)
# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/6259ae83c735c4412796fc5cfb4c9607b949ae29 or later)
# for proper MoE LoRA support.

# For Megatron communication/computation overlapping
Expand Down
2 changes: 1 addition & 1 deletion examples/grpo_trainer/run_qwen3moe-30b_megatron_lora.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ set -xeuo pipefail

# Need to install Megatron-Bridge
# NOTE: Make sure you use Megatron-Bridge later than 0.2.0
# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/83a7c1134c562d8c6decd10a1f0a6e6a7a8a3a44 or later)
# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/6259ae83c735c4412796fc5cfb4c9607b949ae29 or later)
# for proper MoE LoRA support.

# For Megatron communication/computation overlapping
Expand Down
151 changes: 85 additions & 66 deletions verl/utils/checkpoint/megatron_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,23 +259,25 @@ def generate_state_dict(
state_dict = {}
base_metadata = metadata or self._build_sharded_state_dict_metadata()

# Should always generate model state dict
# All ranks Save Model to reduce memory pressure
# Get sharded state dict, notice that state_dict will collect among dp groups, causing memory pressure
for vpp_rank, model in enumerate(self.model):
if len(self.model) > 1:
mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)
key = f"model{vpp_rank}" if len(self.model) > 1 else "model"
else:
key = "model"
if hasattr(model, "module"):
model = model.module
should_generate_model_sections = generate_model or generate_optimizer

# GPTModel's sharded_state_dict function when having mtp requires metadata['dp_cp_group']
model_metadata = dict(base_metadata)
model_metadata["dp_cp_group"] = mpu.get_data_parallel_group(with_context_parallel=True)
kwargs = {"metadata": model_metadata}
state_dict[key] = model.sharded_state_dict(**kwargs)
# All ranks save model state dict when it is needed for either model checkpointing
# or optimizer sharded_state_dict generation.
if should_generate_model_sections:
for vpp_rank, model in enumerate(self.model):
if len(self.model) > 1:
mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)
key = f"model{vpp_rank}" if len(self.model) > 1 else "model"
else:
key = "model"
if hasattr(model, "module"):
model = model.module

# GPTModel's sharded_state_dict function when having mtp requires metadata['dp_cp_group']
model_metadata = dict(base_metadata)
model_metadata["dp_cp_group"] = mpu.get_data_parallel_group(with_context_parallel=True)
kwargs = {"metadata": model_metadata}
state_dict[key] = model.sharded_state_dict(**kwargs)

# Optimizer State Dict
if generate_optimizer:
Expand All @@ -293,7 +295,9 @@ def generate_state_dict(
state_dict["lr_scheduler"] = lr_state_dict

if not generate_model:
state_dict.pop("model", None)
for key in list(state_dict.keys()):
if self._is_model_state_key(key):
state_dict.pop(key)

# RNG States State Dict
if generate_extra:
Expand Down Expand Up @@ -341,6 +345,43 @@ def _build_sharded_state_dict_metadata(self) -> dict:
metadata["chained_optim_avoid_prefix"] = True
return metadata

@staticmethod
def _is_model_state_key(key: str) -> bool:
return key == "model" or (key.startswith("model") and key[5:].isdigit())

@staticmethod
def _has_checkpoint_files(path: str) -> bool:
return os.path.isdir(path) and any(os.scandir(path))

def _raise_for_unsupported_peft_checkpoint_layout(self, local_path: str, dist_checkpoint_path: str):
if self.peft_cls is None or not self.should_load_model or self._has_checkpoint_files(dist_checkpoint_path):
return

legacy_adapter_ckpt_path = os.path.join(local_path, "adapter_checkpoint")
hf_adapter_ckpt_path = os.path.join(local_path, "huggingface", "adapter")

if os.path.isdir(legacy_adapter_ckpt_path):
raise RuntimeError(
f"Found legacy PEFT checkpoint at {legacy_adapter_ckpt_path}, but checkpoint resume now expects "
f"adapter weights in {dist_checkpoint_path}. Resave/convert the checkpoint or load the adapter via "
"`lora.adapter_path`."
)

if os.path.isfile(os.path.join(hf_adapter_ckpt_path, "adapter_config.json")):
raise RuntimeError(
f"Found exported HF PEFT adapter at {hf_adapter_ckpt_path}, but `load_checkpoint()` resumes from "
f"{dist_checkpoint_path}. HF adapter exports are not used for trainer resume; keep the distributed "
"checkpoint or load the adapter separately via `lora.adapter_path`."
)

def _maybe_filter_peft_state_dict(self, state_dict: dict):
if self.peft_cls is None:
return state_dict

from megatron.bridge.training.checkpointing import apply_peft_adapter_filter_to_state_dict

return apply_peft_adapter_filter_to_state_dict(state_dict, self.peft_cls)

def load_rng_states(self, rng_states, data_parallel_random_init=False, use_dist_ckpt=True):
# access rng_state for data parallel rank
if data_parallel_random_init:
Expand Down Expand Up @@ -373,6 +414,7 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
pass

dist_checkpoint_path = get_dist_checkpoint_path(local_path)
self._raise_for_unsupported_peft_checkpoint_layout(local_path, dist_checkpoint_path)

load_content_metadata = getattr(dist_checkpointing, "load_content_metadata", None)
if load_content_metadata is None:
Expand All @@ -392,13 +434,15 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
sharded_sd_metadata = self._build_sharded_state_dict_metadata()

# Get State Dict for loading
should_load_dist_model = self.should_load_model and (self.use_dist_checkpointing or self.peft_cls is not None)
sharded_state_dict = self.generate_state_dict(
self.should_load_model and self.use_dist_checkpointing,
should_load_dist_model,
self.should_load_optimizer,
self.should_load_extra,
is_loading=True,
metadata=sharded_sd_metadata,
)
sharded_state_dict = self._maybe_filter_peft_state_dict(sharded_state_dict)
log_with_rank(f"Generated state dict for loading: {sharded_state_dict.keys()}", rank=self.rank, logger=logger)

# Load Dist Checkpointing
Expand All @@ -407,7 +451,7 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
ckpt_dir=dist_checkpoint_path,
)

if self.should_load_model and self.use_dist_checkpointing:
if should_load_dist_model:
assert "model" in state_dict or any(
f"model{vpp_rank}" in state_dict for vpp_rank in range(len(self.model))
), f"Model state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}."
Expand All @@ -418,8 +462,13 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
assert f"model{vpp_rank}" in state_dict, f"model{vpp_rank} not found in state_dict"
model_state_dict = state_dict[f"model{vpp_rank}"]
mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)
self.model[vpp_rank].load_state_dict(model_state_dict)
log_with_rank(f"Loaded sharded model checkpoint from {local_path}", rank=self.rank, logger=logger)
self.model[vpp_rank].load_state_dict(model_state_dict, strict=self.peft_cls is None)
if self.peft_cls is not None:
log_with_rank(
f"Loaded PEFT adapter checkpoint from {dist_checkpoint_path}", rank=self.rank, logger=logger
)
else:
log_with_rank(f"Loaded sharded model checkpoint from {local_path}", rank=self.rank, logger=logger)

# Skip HF checkpoint loading if PEFT is used
elif self.should_load_model and self.use_hf_checkpoint and self.peft_cls is None:
Expand All @@ -429,29 +478,6 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
else:
self.bridge.load_hf_weights(self.model, hf_model_path)
log_with_rank(f"Loaded HF model checkpoint from {hf_model_path} with bridge", rank=self.rank, logger=logger)
# Load PEFT adapter checkpoint if available
if self.should_load_model and self.peft_cls is not None:
adapter_ckpt_path = os.path.join(local_path, "adapter_checkpoint")
if os.path.exists(adapter_ckpt_path):
from verl.utils.megatron_peft_utils import load_adapter_checkpoint

# TODO: a better format for adapter checkpoint, waiting megatron-bridge support

load_adapter_checkpoint(
self.model,
adapter_ckpt_path,
)
log_with_rank(
f"Loaded adapter checkpoint from {adapter_ckpt_path}",
rank=self.rank,
logger=logger,
)
else:
log_with_rank(
f"PEFT config is set but no adapter checkpoint found at {adapter_ckpt_path}",
rank=self.rank,
logger=logger,
)

if self.should_load_optimizer:
assert "optimizer" in state_dict, (
Expand Down Expand Up @@ -509,6 +535,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
self.should_save_extra,
metadata=sharded_sd_metadata,
)
state_dict = self._maybe_filter_peft_state_dict(state_dict)
log_with_rank(f"Generated state dict for saving: {state_dict.keys()}", rank=self.rank, logger=logger)
for vpp_rank, model in enumerate(self.model):
if len(self.model) > 1:
Expand All @@ -535,11 +562,12 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
# Generate optimizer and exra state dicts
sharded_sd_metadata = self._build_sharded_state_dict_metadata()
state_dict = self.generate_state_dict(
generate_model=False,
generate_model=self.should_save_model and self.peft_cls is not None,
generate_optimizer=self.should_save_optimizer,
generate_extra=self.should_save_extra,
metadata=sharded_sd_metadata,
)
state_dict = self._maybe_filter_peft_state_dict(state_dict)
# Save optimizer and extra states to local path
# Start Async save if enabled
async_save_request = save_dist_checkpointing(
Expand All @@ -555,26 +583,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
torch.distributed.barrier()

if self.should_save_model:
# Save adapter-only checkpoint if PEFT is enabled
if self.peft_cls is not None:
from verl.utils.megatron_peft_utils import save_adapter_checkpoint

adapter_ckpt_path = os.path.join(local_path, "adapter_checkpoint")

# Save adapter weights only (much smaller than full model)
save_adapter_checkpoint(
self.model,
adapter_ckpt_path,
self.rank,
)

log_with_rank(
f"Saved adapter-only checkpoint to {adapter_ckpt_path}",
rank=self.rank,
logger=logger,
log_only_rank_0=True,
)
elif self.use_hf_checkpoint:
if self.use_hf_checkpoint:
# Use mbridge to save HF model checkpoint
log_with_rank(f"Saving HF model checkpoint to {local_path} with bridge", rank=self.rank, logger=logger)
hf_ckpt_path = get_hf_model_checkpoint_path(local_path)
Expand All @@ -588,7 +597,17 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
extended_args[sig] = mbridge_config[sig]
self.bridge.save_weights(self.model, hf_ckpt_path, **extended_args)
else:
self.bridge.save_hf_weights(self.model, hf_ckpt_path)
if self.peft_cls is not None:
hf_adapter_ckpt_path = os.path.join(hf_ckpt_path, "adapter")
self.bridge.save_hf_adapter(self.model, hf_adapter_ckpt_path, self.peft_cls)
log_with_rank(
f"Saved HF PEFT adapter checkpoint to {hf_adapter_ckpt_path}",
rank=self.rank,
logger=logger,
log_only_rank_0=True,
)
else:
self.bridge.save_hf_weights(self.model, hf_ckpt_path)

log_with_rank(f"Saved bridge checkpoint to {hf_ckpt_path}", rank=self.rank, logger=logger)

Expand Down
Loading
Loading