From 310926f2c75881d101672d88d647f396b1c6cd93 Mon Sep 17 00:00:00 2001 From: yaoyu-33 Date: Wed, 17 Dec 2025 14:47:00 -0800 Subject: [PATCH 1/4] Fix adapter-only recompute hook --- src/megatron/bridge/peft/recompute.py | 122 ++++++++++++++++++++++++ src/megatron/bridge/training/setup.py | 2 + tests/unit_tests/peft/test_recompute.py | 91 ++++++++++++++++++ 3 files changed, 215 insertions(+) create mode 100644 src/megatron/bridge/peft/recompute.py create mode 100644 tests/unit_tests/peft/test_recompute.py diff --git a/src/megatron/bridge/peft/recompute.py b/src/megatron/bridge/peft/recompute.py new file mode 100644 index 0000000000..ca76ab2310 --- /dev/null +++ b/src/megatron/bridge/peft/recompute.py @@ -0,0 +1,122 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for PEFT-specific activation recompute fixes.""" + +from __future__ import annotations + +from functools import wraps +from typing import Iterable, Set + +import torch +from megatron.core.utils import unwrap_model + +PEFT_RECOMPUTE_PATCHED: Set[int] = set() + + +def _iter_unwrapped_models(model) -> Iterable[torch.nn.Module]: + """Yield unwrapped Megatron modules regardless of list/list-like inputs.""" + unwrapped = unwrap_model(model) + if isinstance(unwrapped, list): + for module in unwrapped: + if module is not None: + yield module + else: + if unwrapped is not None: + yield unwrapped + + +def maybe_enable_recompute_inputs_grad(model, peft_recompute_patched: Set[int] | None = None) -> Set[int]: + """Enable grad on TransformerBlock inputs when only adapters are trainable. + + Root cause analysis: + + - Megatron's CheckpointFunction.backward() is only invoked by PyTorch autograd + when at least one input tensor requires grad. + - With PP>1, received tensors from other stages have requires_grad=True, so + checkpoint backward is always called. + - With PP=1 and frozen base model, embedding outputs have requires_grad=False. + This means CheckpointFunction.backward() is never called, and LoRA gradients + inside the checkpoint are never computed. + + Solution: Hook TransformerBlock.forward to ensure hidden_states.requires_grad=True + before it enters checkpointed computation. This doesn't unfreeze any parameters; + it just ensures the autograd machinery calls checkpoint's backward. + + Borrowed (with modifications) from + https://github.com/HollowMan6/verl/blob/4285f0601028aee7ddcb9ec5a15198ebfc69bba3/verl/utils/megatron_peft_utils.py + """ + + from megatron.core.transformer.transformer_block import TransformerBlock + + patched_registry = peft_recompute_patched or PEFT_RECOMPUTE_PATCHED + + try: + for unwrapped_model in _iter_unwrapped_models(model): + cfg = getattr(unwrapped_model, "config", None) + if cfg is None or getattr(cfg, "recompute_method", None) is None: + continue + + if id(unwrapped_model) in patched_registry: + continue + + params = list(unwrapped_model.named_parameters()) + trainable_adapter = any(p.requires_grad and ".adapter." in n.lower() for n, p in params) + trainable_base = any( + p.requires_grad and (".to_wrap." not in n.lower() and ".adapter." not in n.lower()) for n, p in params + ) + + if not (trainable_adapter and not trainable_base): + continue # Not adapter-only training, no fix needed + + def _patch_transformer_block(module: torch.nn.Module) -> bool: + if isinstance(module, TransformerBlock): + original_forward = module.forward + + @wraps(original_forward) + def patched_forward(hidden_states, *args, _original_forward=original_forward, **kwargs): + # Ensure hidden_states requires grad so checkpoint backward is called + if ( + torch.is_tensor(hidden_states) + and not hidden_states.requires_grad + and hidden_states.is_floating_point() + ): + hidden_states = hidden_states.detach().requires_grad_(True) + return _original_forward(hidden_states, *args, **kwargs) + + module.forward = patched_forward + return True + return False + + patched = False + for module in unwrapped_model.modules(): + if _patch_transformer_block(module): + patched = True + if patched: + patched_registry.add(id(unwrapped_model)) + print( + "[PEFT+Recompute] Patched TransformerBlock.forward to enable grad on " + "hidden_states input. This ensures checkpoint backward is called when " + "only adapters are trainable (PP=1 with frozen base model).", + flush=True, + ) + except Exception as exc: # pragma: no cover - best effort logging + # Log but don't fail - user will see grad_norm=0 and can debug + print(f"[PEFT+Recompute] Warning: Failed to patch TransformerBlock: {exc}", flush=True) + + return patched_registry + + +__all__ = ["maybe_enable_recompute_inputs_grad", "PEFT_RECOMPUTE_PATCHED"] + diff --git a/src/megatron/bridge/training/setup.py b/src/megatron/bridge/training/setup.py index 6332f6793f..6b89713f98 100644 --- a/src/megatron/bridge/training/setup.py +++ b/src/megatron/bridge/training/setup.py @@ -49,6 +49,7 @@ finalize_tensor_inspect_post_model_initialization, initialize_tensor_inspect_pre_model_initialization, ) +from megatron.bridge.peft.recompute import maybe_enable_recompute_inputs_grad @@ -417,6 +418,7 @@ def _apply_peft_transformation(peft, base_model: list[MegatronModule]) -> list[M """ print_rank_0("Applying PEFT transformation...") transformed_model = peft(base_model, training=True) + maybe_enable_recompute_inputs_grad(transformed_model) peft.set_params_to_save(transformed_model) # Log PEFT statistics diff --git a/tests/unit_tests/peft/test_recompute.py b/tests/unit_tests/peft/test_recompute.py new file mode 100644 index 0000000000..d8497fa144 --- /dev/null +++ b/tests/unit_tests/peft/test_recompute.py @@ -0,0 +1,91 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for PEFT-specific recompute helpers.""" + +from types import SimpleNamespace + +import torch + +from megatron.bridge.peft import recompute as recompute_mod +from megatron.bridge.peft.recompute import maybe_enable_recompute_inputs_grad + + +class DummyAdapter(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(1)) + + +class DummyTransformerBlock(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.last_input_requires_grad = None + + def forward(self, hidden_states, *args, **kwargs): + self.last_input_requires_grad = hidden_states.requires_grad + return hidden_states + + +class DummyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.config = SimpleNamespace(recompute_method="uniform") + self.block = DummyTransformerBlock() + + # Frozen base parameter (not trainable) + self.base = torch.nn.Linear(1, 1, bias=False) + self.base.weight.requires_grad = False + + # Trainable adapter parameter whose name contains ".adapter." + self.adapter = DummyAdapter() + + def modules(self): + for module in super().modules(): + yield module + + +def _patch_transformer_block(monkeypatch): + import megatron.core.transformer.transformer_block as transformer_block + + monkeypatch.setattr( + transformer_block, + "TransformerBlock", + DummyTransformerBlock, + raising=False, + ) + + +def test_maybe_enable_recompute_inputs_grad_patches_block(monkeypatch): + _patch_transformer_block(monkeypatch) + recompute_mod.PEFT_RECOMPUTE_PATCHED.clear() + + model = DummyModel() + patched_registry = maybe_enable_recompute_inputs_grad(model, set()) + + assert id(model) in patched_registry + + patched_forward = model.block.forward + + input_tensor = torch.zeros(2, 2) + assert input_tensor.requires_grad is False + + model.block(input_tensor) + assert model.block.last_input_requires_grad is True + + # Second invocation should be a no-op (no duplicate patch) + maybe_enable_recompute_inputs_grad(model, patched_registry) + assert model.block.forward is patched_forward + + From 50a4c2a9a055e62926338c74ddb538b8f9df09f4 Mon Sep 17 00:00:00 2001 From: yaoyu-33 Date: Wed, 17 Dec 2025 14:53:09 -0800 Subject: [PATCH 2/4] lint Signed-off-by: yaoyu-33 --- src/megatron/bridge/peft/recompute.py | 9 +++++---- tests/unit_tests/peft/test_recompute.py | 2 -- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/megatron/bridge/peft/recompute.py b/src/megatron/bridge/peft/recompute.py index ca76ab2310..82aea96884 100644 --- a/src/megatron/bridge/peft/recompute.py +++ b/src/megatron/bridge/peft/recompute.py @@ -22,6 +22,9 @@ import torch from megatron.core.utils import unwrap_model +from megatron.bridge.utils.common_utils import print_rank_0 + + PEFT_RECOMPUTE_PATCHED: Set[int] = set() @@ -105,18 +108,16 @@ def patched_forward(hidden_states, *args, _original_forward=original_forward, ** patched = True if patched: patched_registry.add(id(unwrapped_model)) - print( + print_rank_0( "[PEFT+Recompute] Patched TransformerBlock.forward to enable grad on " "hidden_states input. This ensures checkpoint backward is called when " "only adapters are trainable (PP=1 with frozen base model).", - flush=True, ) except Exception as exc: # pragma: no cover - best effort logging # Log but don't fail - user will see grad_norm=0 and can debug - print(f"[PEFT+Recompute] Warning: Failed to patch TransformerBlock: {exc}", flush=True) + print_rank_0(f"[PEFT+Recompute] Warning: Failed to patch TransformerBlock: {exc}") return patched_registry __all__ = ["maybe_enable_recompute_inputs_grad", "PEFT_RECOMPUTE_PATCHED"] - diff --git a/tests/unit_tests/peft/test_recompute.py b/tests/unit_tests/peft/test_recompute.py index d8497fa144..ced33f3bc2 100644 --- a/tests/unit_tests/peft/test_recompute.py +++ b/tests/unit_tests/peft/test_recompute.py @@ -87,5 +87,3 @@ def test_maybe_enable_recompute_inputs_grad_patches_block(monkeypatch): # Second invocation should be a no-op (no duplicate patch) maybe_enable_recompute_inputs_grad(model, patched_registry) assert model.block.forward is patched_forward - - From b90af9e712265dd515a23c222f04986ac6d4a139 Mon Sep 17 00:00:00 2001 From: yaoyu-33 Date: Thu, 18 Dec 2025 12:19:14 -0800 Subject: [PATCH 3/4] update comments Signed-off-by: yaoyu-33 --- src/megatron/bridge/peft/base.py | 4 ++++ src/megatron/bridge/training/setup.py | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/megatron/bridge/peft/base.py b/src/megatron/bridge/peft/base.py index f7d18e441f..2379f69f56 100644 --- a/src/megatron/bridge/peft/base.py +++ b/src/megatron/bridge/peft/base.py @@ -21,6 +21,7 @@ import torch.nn as nn from megatron.core.transformer.module import MegatronModule +from megatron.bridge.peft.recompute import maybe_enable_recompute_inputs_grad from megatron.bridge.peft.walk_utils import walk @@ -107,6 +108,9 @@ def __call__(self, model: ModelType, training: bool = True) -> ModelType: model_to_walk = model walk(model_to_walk, self.transform) + if training: + maybe_enable_recompute_inputs_grad(model) + if not training: self.freeze_model(model, training=training) diff --git a/src/megatron/bridge/training/setup.py b/src/megatron/bridge/training/setup.py index 6b89713f98..6332f6793f 100644 --- a/src/megatron/bridge/training/setup.py +++ b/src/megatron/bridge/training/setup.py @@ -49,7 +49,6 @@ finalize_tensor_inspect_post_model_initialization, initialize_tensor_inspect_pre_model_initialization, ) -from megatron.bridge.peft.recompute import maybe_enable_recompute_inputs_grad @@ -418,7 +417,6 @@ def _apply_peft_transformation(peft, base_model: list[MegatronModule]) -> list[M """ print_rank_0("Applying PEFT transformation...") transformed_model = peft(base_model, training=True) - maybe_enable_recompute_inputs_grad(transformed_model) peft.set_params_to_save(transformed_model) # Log PEFT statistics From 8eb93ef9843679e94ae2cdc135a6bc2aeb13b622 Mon Sep 17 00:00:00 2001 From: yaoyu-33 Date: Fri, 19 Dec 2025 10:23:38 -0800 Subject: [PATCH 4/4] update Signed-off-by: yaoyu-33 --- tests/unit_tests/peft/test_recompute.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/peft/test_recompute.py b/tests/unit_tests/peft/test_recompute.py index ced33f3bc2..3742e36ee7 100644 --- a/tests/unit_tests/peft/test_recompute.py +++ b/tests/unit_tests/peft/test_recompute.py @@ -49,7 +49,10 @@ def __init__(self) -> None: self.base.weight.requires_grad = False # Trainable adapter parameter whose name contains ".adapter." - self.adapter = DummyAdapter() + # Use a ModuleDict with key "adapter" so that the full parameter + # name includes the expected substring (".adapter.") used by + # maybe_enable_recompute_inputs_grad. + self.adapter = torch.nn.ModuleDict({"adapter": DummyAdapter()}) def modules(self): for module in super().modules():