verl/utils/megatron_peft_utils.py (+84 -0)
@@ -224,10 +224,94 @@ def print_adapter_info(model):
print(f"{'=' * 60}\n")


def maybe_enable_recompute_inputs_grad(model, peft_recompute_patched: set) -> set:
"""Enable grad on TransformerBlock inputs when only adapters are trainable.

Root cause analysis:
- Megatron's CheckpointFunction.backward() is only invoked by PyTorch autograd
when at least one input tensor requires grad.
- With PP>1, received tensors from other stages have requires_grad=True, so
checkpoint backward is always called.
- With PP=1 and frozen base model, embedding outputs have requires_grad=False.
This means CheckpointFunction.backward() is never called, and LoRA gradients
inside the checkpoint are never computed.

Solution: Hook TransformerBlock.forward to ensure hidden_states.requires_grad=True
before it enters checkpointed computation. This doesn't unfreeze any parameters;
it just ensures the autograd machinery calls checkpoint's backward.
"""
    import torch
    from megatron.core.transformer.transformer_block import TransformerBlock
    from megatron.core.utils import unwrap_model

try:
unwrapped_model = unwrap_model(model)

cfg = getattr(unwrapped_model, "config", None)

# Only needed when activation recompute is enabled
if cfg is None or getattr(cfg, "recompute_method", None) is None:
return peft_recompute_patched

# Avoid duplicate patches
if id(unwrapped_model) in peft_recompute_patched:
return peft_recompute_patched

# Check if only adapters are trainable (frozen base)
params = list(unwrapped_model.named_parameters())
trainable_adapter = any(p.requires_grad and ".adapter." in n.lower() for n, p in params)
trainable_base = any(
p.requires_grad and (".to_wrap." not in n.lower() and ".adapter." not in n.lower()) for n, p in params
)

if not (trainable_adapter and not trainable_base):
return peft_recompute_patched # Not adapter-only training, no fix needed

# Find TransformerBlock(s) in the model and patch their forward
def _patch_transformer_block(module):
if isinstance(module, TransformerBlock):
original_forward = module.forward

def patched_forward(hidden_states, *args, **kwargs):
                    # Ensure hidden_states requires grad so autograd invokes the
                    # checkpoint's backward. detach() severs nothing here: the
                    # tensor has requires_grad=False, so there is no upstream
                    # graph to break.
if (
torch.is_tensor(hidden_states)
and not hidden_states.requires_grad
and hidden_states.is_floating_point()
):
hidden_states = hidden_states.detach().requires_grad_(True)
return original_forward(hidden_states, *args, **kwargs)

module.forward = patched_forward
return True
return False

patched = False
for module in unwrapped_model.modules():
if _patch_transformer_block(module):
patched = True

if patched:
peft_recompute_patched.add(id(unwrapped_model))
print(
"[PEFT+Recompute] Patched TransformerBlock.forward to enable grad on "
"hidden_states input. This ensures checkpoint backward is called when "
"only adapters are trainable (PP=1 with frozen base model).",
flush=True,
)
except Exception as e:
# Log but don't fail - user will see grad_norm=0 and can debug
print(f"[PEFT+Recompute] Warning: Failed to patch TransformerBlock: {e}", flush=True)

return peft_recompute_patched


__all__ = [
"get_adapter_state_dict",
"save_adapter_checkpoint",
"load_adapter_checkpoint",
"count_adapter_parameters",
"print_adapter_info",
"maybe_enable_recompute_inputs_grad",
]
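
Note: the failure mode is reproducible with stock PyTorch checkpointing, which
relies on the same autograd.Function mechanism as Megatron's CheckpointFunction.
A minimal sketch (torch.utils.checkpoint stands in for Megatron's implementation,
and the Linear stands in for a checkpointed block whose only trainable weights
live inside it):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Linear(4, 4)  # stands in for a checkpointed transformer layer

# PP=1 with a frozen base: the embedding output does not require grad.
x = torch.randn(2, 4)
y = checkpoint(block, x, use_reentrant=True)
print(y.requires_grad)  # False: autograd will never call the checkpoint's
                        # backward, so weights inside it never receive grads

# The patch's effect: the block input requires grad before checkpointing.
x = torch.randn(2, 4).requires_grad_(True)
y = checkpoint(block, x, use_reentrant=True)
y.sum().backward()
print(block.weight.grad is not None)  # True: backward reached the block

Megatron's CheckpointFunction behaves the same way, which is why the patch
flips requires_grad on hidden_states rather than touching any parameters.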
verl/workers/actor/megatron_actor.py (+9 -0)
@@ -49,6 +49,7 @@
set_router_replay_data,
)
from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits
from verl.utils.megatron_peft_utils import maybe_enable_recompute_inputs_grad
from verl.utils.megatron_utils import get_model_config, unwrap_model
from verl.utils.profiler import GPUMemoryLogger
from verl.utils.profiler.profile import Profiler
@@ -167,6 +168,9 @@ def __init__(
print(config)
config.finalize_model_grads_func = finalize_model_grads

# Track models patched for PEFT + recompute compatibility
        self._peft_recompute_patched: set[int] = set()

def _validate_config(self, config) -> None:
"""Validate config options not implemented for Megatron backend"""
assert config.get("ulysses_sequence_parallel_size", 1) == 1
@@ -738,6 +742,11 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict:

"""
metrics = {}

        # Apply the PEFT + recompute compatibility patch; the tracking set makes repeated calls no-ops
for actor_model in self.actor_module:
self._peft_recompute_patched = maybe_enable_recompute_inputs_grad(actor_model, self._peft_recompute_patched)

if self.use_torch_profiler and self.prof and self.prof.enable:
self.prof.start()
for data in dataloader:
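
Note: a quick way to confirm the patch took effect (an illustrative sketch;
actor_model and the ".adapter." naming follow the utility above) is to check
that adapter parameters carry gradients after the first backward pass instead
of yielding grad_norm=0:

# Illustrative sanity check after one backward pass. Under Megatron's DDP the
# gradient may live in param.main_grad rather than param.grad.
for name, param in unwrap_model(actor_model).named_parameters():
    if ".adapter." in name.lower() and param.requires_grad:
        grad = getattr(param, "main_grad", None)
        if grad is None:
            grad = param.grad
        assert grad is not None, f"no grad for {name}"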
verl/workers/critic/megatron_critic.py (+10 -0)
@@ -33,6 +33,7 @@
from verl.trainer.ppo import core_algos
from verl.utils.device import get_device_id, get_torch_device
from verl.utils.megatron.pipeline_parallel import make_batch_generator
from verl.utils.megatron_peft_utils import maybe_enable_recompute_inputs_grad
from verl.utils.profiler import GPUMemoryLogger
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
@@ -78,6 +79,9 @@ def __init__(
}
)

# Track models patched for PEFT + recompute compatibility
        self._peft_recompute_patched: set[int] = set()

def _validate_config(self, config) -> None:
"""Validate config options not implemented for Megatron backend"""
assert config.get("ulysses_sequence_parallel_size", 1) == 1
@@ -296,6 +300,12 @@ def forward_step(batch_iter, model):
def update_critic(self, dataloader: Iterable[DataProto]):
metrics = {}

        # Apply the PEFT + recompute compatibility patch; the tracking set makes repeated calls no-ops
for critic_model in self.critic_module:
self._peft_recompute_patched = maybe_enable_recompute_inputs_grad(
critic_model, self._peft_recompute_patched
)

for data in dataloader:
self.critic_optimizer.zero_grad()
# use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm