diff --git a/src/megatron/bridge/peft/base.py b/src/megatron/bridge/peft/base.py
index f7d18e441f..2379f69f56 100644
--- a/src/megatron/bridge/peft/base.py
+++ b/src/megatron/bridge/peft/base.py
@@ -21,6 +21,7 @@
 import torch.nn as nn

 from megatron.core.transformer.module import MegatronModule
+from megatron.bridge.peft.recompute import maybe_enable_recompute_inputs_grad
 from megatron.bridge.peft.walk_utils import walk

@@ -107,6 +108,9 @@ def __call__(self, model: ModelType, training: bool = True) -> ModelType:
         model_to_walk = model
         walk(model_to_walk, self.transform)

+        if training:
+            maybe_enable_recompute_inputs_grad(model)
+
         if not training:
             self.freeze_model(model, training=training)

diff --git a/src/megatron/bridge/peft/recompute.py b/src/megatron/bridge/peft/recompute.py
new file mode 100644
index 0000000000..82aea96884
--- /dev/null
+++ b/src/megatron/bridge/peft/recompute.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helpers for PEFT-specific activation recompute fixes."""
+
+from __future__ import annotations
+
+from functools import wraps
+from typing import Iterable, Set
+
+import torch
+from megatron.core.utils import unwrap_model
+
+from megatron.bridge.utils.common_utils import print_rank_0
+
+
+PEFT_RECOMPUTE_PATCHED: Set[int] = set()
+
+
+def _iter_unwrapped_models(model) -> Iterable[torch.nn.Module]:
+    """Yield unwrapped Megatron modules regardless of list/list-like inputs."""
+    unwrapped = unwrap_model(model)
+    if isinstance(unwrapped, list):
+        for module in unwrapped:
+            if module is not None:
+                yield module
+    else:
+        if unwrapped is not None:
+            yield unwrapped
+
+
+def maybe_enable_recompute_inputs_grad(model, peft_recompute_patched: Set[int] | None = None) -> Set[int]:
+    """Enable grad on TransformerBlock inputs when only adapters are trainable.
+
+    Root cause analysis:
+
+    - Megatron's CheckpointFunction.backward() is only invoked by PyTorch autograd
+      when at least one input tensor requires grad.
+    - With PP>1, tensors received from other stages have requires_grad=True, so
+      checkpoint backward is always called.
+    - With PP=1 and a frozen base model, embedding outputs have requires_grad=False,
+      so CheckpointFunction.backward() is never called and LoRA gradients inside
+      the checkpoint are never computed.
+
+    Solution: hook TransformerBlock.forward to ensure hidden_states.requires_grad=True
+    before it enters checkpointed computation. This doesn't unfreeze any parameters;
+    it just ensures the autograd machinery calls checkpoint's backward.
+
+    Borrowed (with modifications) from
+    https://github.com/HollowMan6/verl/blob/4285f0601028aee7ddcb9ec5a15198ebfc69bba3/verl/utils/megatron_peft_utils.py
+    """
+    from megatron.core.transformer.transformer_block import TransformerBlock
+
+    patched_registry = peft_recompute_patched if peft_recompute_patched is not None else PEFT_RECOMPUTE_PATCHED
+
+    try:
+        for unwrapped_model in _iter_unwrapped_models(model):
+            cfg = getattr(unwrapped_model, "config", None)
+            if cfg is None or getattr(cfg, "recompute_method", None) is None:
+                continue
+
+            if id(unwrapped_model) in patched_registry:
+                continue
+
+            params = list(unwrapped_model.named_parameters())
+            trainable_adapter = any(p.requires_grad and ".adapter." in n.lower() for n, p in params)
+            trainable_base = any(
+                p.requires_grad and (".to_wrap." not in n.lower() and ".adapter." not in n.lower())
+                for n, p in params
+            )
+
+            if not (trainable_adapter and not trainable_base):
+                continue  # Not adapter-only training, no fix needed
+
+            def _patch_transformer_block(module: torch.nn.Module) -> bool:
+                if isinstance(module, TransformerBlock):
+                    original_forward = module.forward
+
+                    @wraps(original_forward)
+                    def patched_forward(hidden_states, *args, _original_forward=original_forward, **kwargs):
+                        # Ensure hidden_states requires grad so checkpoint backward is called
+                        if (
+                            torch.is_tensor(hidden_states)
+                            and not hidden_states.requires_grad
+                            and hidden_states.is_floating_point()
+                        ):
+                            hidden_states = hidden_states.detach().requires_grad_(True)
+                        return _original_forward(hidden_states, *args, **kwargs)
+
+                    module.forward = patched_forward
+                    return True
+                return False
+
+            patched = False
+            for module in unwrapped_model.modules():
+                if _patch_transformer_block(module):
+                    patched = True
+            if patched:
+                patched_registry.add(id(unwrapped_model))
+                print_rank_0(
+                    "[PEFT+Recompute] Patched TransformerBlock.forward to enable grad on "
+                    "hidden_states input. This ensures checkpoint backward is called when "
+                    "only adapters are trainable (PP=1 with frozen base model)."
+                )
+    except Exception as exc:  # pragma: no cover - best effort logging
+        # Log but don't fail - user will see grad_norm=0 and can debug
+        print_rank_0(f"[PEFT+Recompute] Warning: Failed to patch TransformerBlock: {exc}")

+    return patched_registry
+
+
+__all__ = ["maybe_enable_recompute_inputs_grad", "PEFT_RECOMPUTE_PATCHED"]
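Reviewer note: a minimal standalone sketch (not part of the patch) of the failure mode the docstring describes, using `torch.utils.checkpoint` in reentrant mode as a stand-in for Megatron's `CheckpointFunction`; the `adapter` and `block` names below are illustrative:

```python
import torch
from torch.utils.checkpoint import checkpoint

adapter = torch.nn.Linear(4, 4)  # stands in for a trainable LoRA adapter

def block(hidden_states):
    # Frozen base weights omitted; only the adapter is trainable here.
    return adapter(hidden_states)

x = torch.randn(2, 4)  # frozen-base activations: requires_grad=False

out = checkpoint(block, x, use_reentrant=True)
# Recent PyTorch versions even warn here: "None of the inputs have requires_grad=True".
print(out.requires_grad)  # False -> autograd will never call the checkpoint's
                          # backward, so adapter gradients stay uncomputed.

# The tensor-level fix the patched TransformerBlock.forward applies:
x = x.detach().requires_grad_(True)
out = checkpoint(block, x, use_reentrant=True)
out.sum().backward()
print(adapter.weight.grad is not None)  # True -> adapter gradients flow again
```

The sketch mirrors the two states the docstring contrasts: a no-grad input silences the checkpoint's backward entirely, while a single `detach().requires_grad_(True)` upgrade restores the adapter's gradient path.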
diff --git a/tests/unit_tests/peft/test_recompute.py b/tests/unit_tests/peft/test_recompute.py
new file mode 100644
index 0000000000..3742e36ee7
--- /dev/null
+++ b/tests/unit_tests/peft/test_recompute.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Unit tests for PEFT-specific recompute helpers.""" + +from types import SimpleNamespace + +import torch + +from megatron.bridge.peft import recompute as recompute_mod +from megatron.bridge.peft.recompute import maybe_enable_recompute_inputs_grad + + +class DummyAdapter(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(1)) + + +class DummyTransformerBlock(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.last_input_requires_grad = None + + def forward(self, hidden_states, *args, **kwargs): + self.last_input_requires_grad = hidden_states.requires_grad + return hidden_states + + +class DummyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.config = SimpleNamespace(recompute_method="uniform") + self.block = DummyTransformerBlock() + + # Frozen base parameter (not trainable) + self.base = torch.nn.Linear(1, 1, bias=False) + self.base.weight.requires_grad = False + + # Trainable adapter parameter whose name contains ".adapter." + # Use a ModuleDict with key "adapter" so that the full parameter + # name includes the expected substring (".adapter.") used by + # maybe_enable_recompute_inputs_grad. + self.adapter = torch.nn.ModuleDict({"adapter": DummyAdapter()}) + + def modules(self): + for module in super().modules(): + yield module + + +def _patch_transformer_block(monkeypatch): + import megatron.core.transformer.transformer_block as transformer_block + + monkeypatch.setattr( + transformer_block, + "TransformerBlock", + DummyTransformerBlock, + raising=False, + ) + + +def test_maybe_enable_recompute_inputs_grad_patches_block(monkeypatch): + _patch_transformer_block(monkeypatch) + recompute_mod.PEFT_RECOMPUTE_PATCHED.clear() + + model = DummyModel() + patched_registry = maybe_enable_recompute_inputs_grad(model, set()) + + assert id(model) in patched_registry + + patched_forward = model.block.forward + + input_tensor = torch.zeros(2, 2) + assert input_tensor.requires_grad is False + + model.block(input_tensor) + assert model.block.last_input_requires_grad is True + + # Second invocation should be a no-op (no duplicate patch) + maybe_enable_recompute_inputs_grad(model, patched_registry) + assert model.block.forward is patched_forward