feat(offload): hidden_states activation offloading + fix legacy/full-param paths#3733
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughThis PR expands activation offloading from a boolean option to multiple modes, adds a hidden-states-specific checkpoint offload implementation, updates trainer and model-loading behavior for those modes, and adds end-to-end tests for stream selection, adapter handling, and hidden-states patching. ChangesActivation offloading modes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~50 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
…param paths Add `activation_offloading: hidden_states` (ALST-style): gradient checkpointing that offloads only the per-layer input (hidden_states) to CPU, overlapped with compute via a side stream (async d2h with bounded in-flight + backward h2d prefetch). It replaces torch's reentrant CheckpointFunction, so it's framework-agnostic (validated on FSDP2 ZeRO-2/ZeRO-3) and DTensor-aware for sequence/context parallelism. Best for long-context full finetuning: memory is superior to GC and speed is competitive (Qwen3-8B full-param 16k-64k: 0.96-1.02x speed, 1.08-1.23x less reserved, widening with seq), and it reaches contexts where GC OOMs. Also fix the existing TRL-based offload paths: - decouple offload from recompute: `_apply_activation_checkpointing` is now adapter-aware — LoRA offloads instead of recomputing (pure offload is leaner and faster), full finetune keeps recompute + offloads checkpoint boundaries. - wire `activation_offloading: legacy`/`disk` (were dead: gated on `is True` and the arg type couldn't hold the string); legacy now selects the synchronous (non-streamed) TRL offloader. Validation forces use_reentrant=True for hidden_states and warns on LoRA.
72f6d0c to
394d2d3
Compare
|
📖 Documentation Preview: https://6a300eab32db846266636c23--resonant-treacle-0fd729.netlify.app Deployed on Netlify from commit d6ef5d9 |
There was a problem hiding this comment.
Actionable comments posted: 5
🧹 Nitpick comments (2)
tests/e2e/test_activation_offloading.py (1)
84-87: ⚡ Quick winAdd
'disk'to the mode-wiring parameterization.The new mode wiring test currently covers
Trueand"legacy"only; include"disk"so regressions in the third supported mode are caught.[Maintenance-only test coverage improvement]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/e2e/test_activation_offloading.py` around lines 84 - 87, The mode-wiring test parameterization in test_activation_offloading.py currently covers only True and "legacy"; update the pytest.mark.parametrize set in the test that uses offload_mode and expect_streams to also include "disk" so the third supported mode is exercised. Keep the existing expectations aligned with the test’s current wiring logic, and make sure the added case is reflected in the same parameter list used by the offloading activation test.src/axolotl/core/builders/base.py (1)
557-568: ⚡ Quick winTrim this block to short WHY-only comments per project style.
Lines 557–568 are detailed multi-line WHAT/flow narration; this file’s style rule asks for max one short line focused on non-obvious WHY.
As per coding guidelines,
src/axolotl/**/*.pycomments should explain only the WHY for non-obvious logic and be “a maximum of one short line.”🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/core/builders/base.py` around lines 557 - 568, The multi-line comments in the gradient checkpointing and activation offloading conditional blocks (lines 557-560 and 563-568) violate the project style guide by being too verbose and explaining WHAT/flow narration instead of WHY. Replace each comment block with a single short line that explains only the non-obvious reason for that branch: the first comment block above `training_args_kwargs["gradient_checkpointing"] = True` should briefly state WHY ALST-style keeps HF gradient checkpointing with reentrant enabled, and the second comment block above the `elif self.cfg.activation_offloading` should briefly state WHY the TRL offloader approach is different from HF recompute.Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/axolotl/core/training_args_base.py`:
- Around line 238-243: Update the activation_offloading help text in
TrainingArgsBase so it includes the new 'hidden_states' mode alongside the
existing True, 'legacy', 'disk', and False values. Adjust the metadata help
string on activation_offloading to match the accepted config values used by the
training-arg CLI/help output.
In `@src/axolotl/monkeypatch/activation_offload_checkpoint.py`:
- Line 226: The return statement at line 226 uses tuple concatenation with the +
operator which triggers a Ruff linting error. Replace the concatenation syntax
with tuple unpacking by using the * operator to unpack the grads tuple directly
within the tuple literal, changing from return (None, None) + grads to return
(None, None, *grads) to satisfy the Ruff lint requirements.
- Around line 55-57: Guard the CUDA stream creation in the
`_StreamOffloadManager.__init__` method to only execute on CUDA-capable systems.
Check the device type (which is inferred at line 134) before calling
`torch.cuda.Stream()` on line 57. If the device is not CUDA, either skip stream
initialization or set it to None to prevent crashes on CPU/NPU/MPS-only systems
where CUDA operations are not available.
In `@src/axolotl/utils/schemas/validation.py`:
- Around line 1444-1464: The check_hidden_states_offloading method in the after
validator forces use_reentrant=True without checking the unfrozen_parameters
incompatibility that is guarded in a before validator. Add a validation check in
the check_hidden_states_offloading method to ensure unfrozen_parameters is
compatible with the hidden_states activation offloading mode, matching the
safety logic from the before validator at line 456, so that invalid config
shapes are caught during parsing rather than failing at runtime.
In `@tests/e2e/test_activation_offloading.py`:
- Around line 104-106: The test file tests/e2e/test_activation_offloading.py has
formatting issues flagged by ruff-format in the monkeypatch.setattr call blocks.
Run ruff-format on the file to automatically fix the formatting, then commit the
resulting changes. This affects two locations: the monkeypatch.setattr call for
ActivationOffloadingMixin "training_step" around lines 104-106 and a similar
monkeypatch.setattr call around lines 169-171. Apply the ruff-format output to
both locations to ensure consistent formatting across the test file and unblock
CI.
---
Nitpick comments:
In `@src/axolotl/core/builders/base.py`:
- Around line 557-568: The multi-line comments in the gradient checkpointing and
activation offloading conditional blocks (lines 557-560 and 563-568) violate the
project style guide by being too verbose and explaining WHAT/flow narration
instead of WHY. Replace each comment block with a single short line that
explains only the non-obvious reason for that branch: the first comment block
above `training_args_kwargs["gradient_checkpointing"] = True` should briefly
state WHY ALST-style keeps HF gradient checkpointing with reentrant enabled, and
the second comment block above the `elif self.cfg.activation_offloading` should
briefly state WHY the TRL offloader approach is different from HF recompute.
In `@tests/e2e/test_activation_offloading.py`:
- Around line 84-87: The mode-wiring test parameterization in
test_activation_offloading.py currently covers only True and "legacy"; update
the pytest.mark.parametrize set in the test that uses offload_mode and
expect_streams to also include "disk" so the third supported mode is exercised.
Keep the existing expectations aligned with the test’s current wiring logic, and
make sure the added case is reflected in the same parameter list used by the
offloading activation test.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 377647c0-7ba3-4cbf-8fad-d7c4c80a23f9
📒 Files selected for processing (8)
src/axolotl/core/builders/base.pysrc/axolotl/core/trainers/mixins/activation_checkpointing.pysrc/axolotl/core/training_args_base.pysrc/axolotl/loaders/model.pysrc/axolotl/monkeypatch/activation_offload_checkpoint.pysrc/axolotl/utils/schemas/config.pysrc/axolotl/utils/schemas/validation.pytests/e2e/test_activation_offloading.py
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
…ts/test - hidden_states forces use_reentrant=True, which is incompatible with a partially frozen model (unfrozen_parameters). The before-validator guard ran before the force, so catch it in check_hidden_states_offloading. - shorten verbose branch comments in _configure_gradient_checkpointing. - cover the 'disk' offload mode in the mode-wiring test.
| @@ -30,13 +30,18 @@ class ActivationOffloadingMixin(Trainer): | |||
| def __init__(self, *args, **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| if self.args.activation_offloading: | |||
There was a problem hiding this comment.
Could we update the https://docs.axolotl.ai/docs/gradient_checkpointing.html docs as well for this new function and when to choose each?
| json_schema_extra={ | ||
| "description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'." | ||
| }, | ||
| activation_offloading: Literal["legacy", "disk", "hidden_states"] | bool | None = ( |
There was a problem hiding this comment.
I think in patch manager, we check for "offload_disk" whereas the pydantic model uses disk
| s0 = torch.cuda.current_stream() | ||
| self.s1.wait_stream(s0) # input must be produced before we copy it | ||
| with torch.cuda.stream(self.s1): | ||
| cpu_t = t.detach().to("cpu", non_blocking=True) | ||
| ev = self.s1.record_event() |
There was a problem hiding this comment.
Claude's suggestion, potential perf improvement:
torch.empty(..., pin_memory=True) + .copy_(t, non_blocking=True)
upstream TRL creates an empty pinned buffer and copies to it unlike in this case .to("cpu", non_blocking=True) where it targets a pageable memory and is sync
There was a problem hiding this comment.
pinned memory usage is way worse and seems to provide no improvement on any axis.
Address review feedback on hidden_states activation offloading: - restore(): record_stream the brought-back input on the compute stream so the allocator can't reuse its storage before the recompute consumes it (use-after-free guard, matching TRL's offloader). - training_args activation_offloading typed Literal['legacy','disk'] | bool; 'hidden_states' never reaches the trainer (handled in the model loader). - add numerical-parity test: hidden_states offload vs plain reentrant checkpointing must match loss/grads (CUDA-guarded). - document the hidden_states mode and a mode-selection table. Investigated pinned-host-buffer d2h: benchmarked vs the current pageable copy (Qwen-sized MLP stack, 8k-32k). No speed gain and higher peak GPU memory (+2/+4/+8 GB, growing with seq) since true async holds the GPU source longer — a regression for a memory-reduction feature. Keeping pageable. Disk-offload mode wiring left for a follow-up PR.
Summary
Adds
activation_offloading: hidden_states— an ALST-style activation offloader — and fixes the existing TRL-based offload paths.hidden_statesis gradient checkpointing that offloads only the per-layer input (hidden_states) to CPU and recomputes the intermediates, overlapping the transfer with compute on a side stream (async d2h with bounded in-flight + backward h2d prefetch). Unlike the existing TRL offloader (which offloads every saved tensor and is PCIe-saturated for full-parameter training), this moves one tensor/layer, so PCIe stays within budget and at long sequence the copy hides behind the recompute.It replaces torch's reentrant
CheckpointFunction, so it is framework-agnostic (no DeepSpeed dependency) and DTensor-aware (sequence/context parallel).Why
For long-context full-parameter training, offloading everything (existing
activation_offloading: true) is slower and uses more GPU memory than plain gradient checkpointing (the offload backlog exceeds PCIe bandwidth). Offloading only the checkpoint input is the right granularity.Results (Qwen3-8B full-param, single GPU, bf16)
Speed is tied with GC; the memory advantage widens with sequence length.
Also fixes the existing TRL offload paths
_apply_activation_checkpointingis now adapter-aware: LoRA offloads instead of recomputing (pure offload is leaner and faster — recompute pins the offloaded tensors and balloons memory); full finetune keeps recompute and offloads only the checkpoint boundaries.legacy/disk. These string modes were dead (gated onis True, and the trainer arg typebool | Nonecouldn't hold the string).legacynow selects the synchronous (non-streamed) TRL offloader.Validation
0.0in fp32).reshard_after_forward: true) and ZeRO-2 (false) — reentrant checkpoint + offload + FSDP2 compose. (ZeRO-1 is not an FSDP2 sharding mode.)use_reentrant=Trueforhidden_statesand warns when used with LoRA/QLoRA.test_hidden_states_offload_full_param,test_offload_mode_wiring,test_offload_is_adapter_aware+ existing offloading/leak tests (11 pass).Usage
Summary by CodeRabbit
New Features
Bug Fixes