Skip to content

feat(offload): hidden_states activation offloading + fix legacy/full-param paths#3733

Merged
winglian merged 3 commits into
mainfrom
feat/activation-offloading-hidden-states
Jun 15, 2026
Merged

feat(offload): hidden_states activation offloading + fix legacy/full-param paths#3733
winglian merged 3 commits into
mainfrom
feat/activation-offloading-hidden-states

Conversation

@winglian

@winglian winglian commented Jun 13, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds activation_offloading: hidden_states — an ALST-style activation offloader — and fixes the existing TRL-based offload paths.

hidden_states is gradient checkpointing that offloads only the per-layer input (hidden_states) to CPU and recomputes the intermediates, overlapping the transfer with compute on a side stream (async d2h with bounded in-flight + backward h2d prefetch). Unlike the existing TRL offloader (which offloads every saved tensor and is PCIe-saturated for full-parameter training), this moves one tensor/layer, so PCIe stays within budget and at long sequence the copy hides behind the recompute.

It replaces torch's reentrant CheckpointFunction, so it is framework-agnostic (no DeepSpeed dependency) and DTensor-aware (sequence/context parallel).

Why

For long-context full-parameter training, offloading everything (existing activation_offloading: true) is slower and uses more GPU memory than plain gradient checkpointing (the offload backlog exceeds PCIe bandwidth). Offloading only the checkpoint input is the right granularity.

Results (Qwen3-8B full-param, single GPU, bf16)

seq GC-only hidden_states vs GC
16k 63.3 GB / 2325 tok/s 58.9 GB / 2243 0.96× speed, 1.08× less mem
32k 73.1 GB / 1954 64.3 GB / 2000 1.02× speed, 1.14× less mem
64k 92.9 GB / 1605 75.7 GB / 1565 0.98× speed, 1.23× less mem
128k OOM fits reaches context where GC OOMs

Speed is tied with GC; the memory advantage widens with sequence length.

Also fixes the existing TRL offload paths

  • Decouple offload from recompute. _apply_activation_checkpointing is now adapter-aware: LoRA offloads instead of recomputing (pure offload is leaner and faster — recompute pins the offloaded tensors and balloons memory); full finetune keeps recompute and offloads only the checkpoint boundaries.
  • Wire legacy/disk. These string modes were dead (gated on is True, and the trainer arg type bool | None couldn't hold the string). legacy now selects the synchronous (non-streamed) TRL offloader.

Validation

  • Bit-exact gradients vs plain checkpointing (forward output, input grad, weight grads all 0.0 in fp32).
  • FSDP2, 2×GPU, full-parameter: runs clean under ZeRO-3 (reshard_after_forward: true) and ZeRO-2 (false) — reentrant checkpoint + offload + FSDP2 compose. (ZeRO-1 is not an FSDP2 sharding mode.)
  • Validation forces use_reentrant=True for hidden_states and warns when used with LoRA/QLoRA.
  • Tests: new test_hidden_states_offload_full_param, test_offload_mode_wiring, test_offload_is_adapter_aware + existing offloading/leak tests (11 pass).

Usage

gradient_checkpointing: true
activation_offloading: hidden_states   # full-parameter long-context

Summary by CodeRabbit

  • New Features

    • Added a new activation offloading mode that keeps checkpointed hidden states on CPU with streamed transfer support.
    • Expanded activation offloading options to support additional modes, including legacy and disk-based behavior.
  • Bug Fixes

    • Improved checkpointing behavior so the correct offloading and recomputation path is selected automatically.
    • Added safeguards for compatible checkpointing settings when hidden-states offloading is enabled.

@coderabbitai

coderabbitai Bot commented Jun 13, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: bceb890e-e261-421f-9f52-f3f7dea8d974

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR expands activation offloading from a boolean option to multiple modes, adds a hidden-states-specific checkpoint offload implementation, updates trainer and model-loading behavior for those modes, and adds end-to-end tests for stream selection, adapter handling, and hidden-states patching.

Changes

Activation offloading modes

Layer / File(s) Summary
Config and runtime wiring
src/axolotl/core/training_args_base.py, src/axolotl/utils/schemas/config.py, src/axolotl/utils/schemas/validation.py, src/axolotl/core/builders/base.py, src/axolotl/core/trainers/mixins/activation_checkpointing.py, src/axolotl/loaders/model.py
activation_offloading now accepts string modes, validation forces reentrant checkpointing for hidden_states, trainer setup distinguishes hidden-states from other truthy modes, the trainer mixin disables streams for legacy, and model loading either patches hidden-states checkpointing or applies manual wrapping for non-adapter offloading.
Hidden-states checkpoint implementation
src/axolotl/monkeypatch/activation_offload_checkpoint.py
Adds a custom checkpoint function and stream manager that offload the first eligible tensor input to CPU during forward, restore it during backward recomputation, preserve RNG and autocast state, and patch or unpatch torch.utils.checkpoint.CheckpointFunction.
End-to-end coverage
tests/e2e/test_activation_offloading.py
Adds E2E tests for offload mode stream behavior, adapter-aware checkpoint wrapping, and hidden-states mode forcing reentrant checkpointing while patching and restoring the checkpoint function.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested reviewers

  • NanoCode012
  • SalmanMohammadi
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 22.22% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title clearly and specifically describes the main changes: introducing hidden_states activation offloading and fixing legacy/full-param paths, matching the core objectives.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/activation-offloading-hidden-states

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

…param paths

Add `activation_offloading: hidden_states` (ALST-style): gradient checkpointing
that offloads only the per-layer input (hidden_states) to CPU, overlapped with
compute via a side stream (async d2h with bounded in-flight + backward h2d
prefetch). It replaces torch's reentrant CheckpointFunction, so it's
framework-agnostic (validated on FSDP2 ZeRO-2/ZeRO-3) and DTensor-aware for
sequence/context parallelism. Best for long-context full finetuning: memory is
superior to GC and speed is competitive (Qwen3-8B full-param 16k-64k: 0.96-1.02x
speed, 1.08-1.23x less reserved, widening with seq), and it reaches contexts
where GC OOMs.

Also fix the existing TRL-based offload paths:
- decouple offload from recompute: `_apply_activation_checkpointing` is now
  adapter-aware — LoRA offloads instead of recomputing (pure offload is leaner
  and faster), full finetune keeps recompute + offloads checkpoint boundaries.
- wire `activation_offloading: legacy`/`disk` (were dead: gated on `is True`
  and the arg type couldn't hold the string); legacy now selects the synchronous
  (non-streamed) TRL offloader.

Validation forces use_reentrant=True for hidden_states and warns on LoRA.
@winglian winglian force-pushed the feat/activation-offloading-hidden-states branch from 72f6d0c to 394d2d3 Compare June 13, 2026 23:11
@github-actions

github-actions Bot commented Jun 13, 2026

Copy link
Copy Markdown
Contributor

📖 Documentation Preview: https://6a300eab32db846266636c23--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit d6ef5d9

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🧹 Nitpick comments (2)
tests/e2e/test_activation_offloading.py (1)

84-87: ⚡ Quick win

Add 'disk' to the mode-wiring parameterization.

The new mode wiring test currently covers True and "legacy" only; include "disk" so regressions in the third supported mode are caught.

[Maintenance-only test coverage improvement]

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/e2e/test_activation_offloading.py` around lines 84 - 87, The
mode-wiring test parameterization in test_activation_offloading.py currently
covers only True and "legacy"; update the pytest.mark.parametrize set in the
test that uses offload_mode and expect_streams to also include "disk" so the
third supported mode is exercised. Keep the existing expectations aligned with
the test’s current wiring logic, and make sure the added case is reflected in
the same parameter list used by the offloading activation test.
src/axolotl/core/builders/base.py (1)

557-568: ⚡ Quick win

Trim this block to short WHY-only comments per project style.

Lines 557–568 are detailed multi-line WHAT/flow narration; this file’s style rule asks for max one short line focused on non-obvious WHY.

As per coding guidelines, src/axolotl/**/*.py comments should explain only the WHY for non-obvious logic and be “a maximum of one short line.”

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/core/builders/base.py` around lines 557 - 568, The multi-line
comments in the gradient checkpointing and activation offloading conditional
blocks (lines 557-560 and 563-568) violate the project style guide by being too
verbose and explaining WHAT/flow narration instead of WHY. Replace each comment
block with a single short line that explains only the non-obvious reason for
that branch: the first comment block above
`training_args_kwargs["gradient_checkpointing"] = True` should briefly state WHY
ALST-style keeps HF gradient checkpointing with reentrant enabled, and the
second comment block above the `elif self.cfg.activation_offloading` should
briefly state WHY the TRL offloader approach is different from HF recompute.

Source: Coding guidelines

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/axolotl/core/training_args_base.py`:
- Around line 238-243: Update the activation_offloading help text in
TrainingArgsBase so it includes the new 'hidden_states' mode alongside the
existing True, 'legacy', 'disk', and False values. Adjust the metadata help
string on activation_offloading to match the accepted config values used by the
training-arg CLI/help output.

In `@src/axolotl/monkeypatch/activation_offload_checkpoint.py`:
- Line 226: The return statement at line 226 uses tuple concatenation with the +
operator which triggers a Ruff linting error. Replace the concatenation syntax
with tuple unpacking by using the * operator to unpack the grads tuple directly
within the tuple literal, changing from return (None, None) + grads to return
(None, None, *grads) to satisfy the Ruff lint requirements.
- Around line 55-57: Guard the CUDA stream creation in the
`_StreamOffloadManager.__init__` method to only execute on CUDA-capable systems.
Check the device type (which is inferred at line 134) before calling
`torch.cuda.Stream()` on line 57. If the device is not CUDA, either skip stream
initialization or set it to None to prevent crashes on CPU/NPU/MPS-only systems
where CUDA operations are not available.

In `@src/axolotl/utils/schemas/validation.py`:
- Around line 1444-1464: The check_hidden_states_offloading method in the after
validator forces use_reentrant=True without checking the unfrozen_parameters
incompatibility that is guarded in a before validator. Add a validation check in
the check_hidden_states_offloading method to ensure unfrozen_parameters is
compatible with the hidden_states activation offloading mode, matching the
safety logic from the before validator at line 456, so that invalid config
shapes are caught during parsing rather than failing at runtime.

In `@tests/e2e/test_activation_offloading.py`:
- Around line 104-106: The test file tests/e2e/test_activation_offloading.py has
formatting issues flagged by ruff-format in the monkeypatch.setattr call blocks.
Run ruff-format on the file to automatically fix the formatting, then commit the
resulting changes. This affects two locations: the monkeypatch.setattr call for
ActivationOffloadingMixin "training_step" around lines 104-106 and a similar
monkeypatch.setattr call around lines 169-171. Apply the ruff-format output to
both locations to ensure consistent formatting across the test file and unblock
CI.

---

Nitpick comments:
In `@src/axolotl/core/builders/base.py`:
- Around line 557-568: The multi-line comments in the gradient checkpointing and
activation offloading conditional blocks (lines 557-560 and 563-568) violate the
project style guide by being too verbose and explaining WHAT/flow narration
instead of WHY. Replace each comment block with a single short line that
explains only the non-obvious reason for that branch: the first comment block
above `training_args_kwargs["gradient_checkpointing"] = True` should briefly
state WHY ALST-style keeps HF gradient checkpointing with reentrant enabled, and
the second comment block above the `elif self.cfg.activation_offloading` should
briefly state WHY the TRL offloader approach is different from HF recompute.

In `@tests/e2e/test_activation_offloading.py`:
- Around line 84-87: The mode-wiring test parameterization in
test_activation_offloading.py currently covers only True and "legacy"; update
the pytest.mark.parametrize set in the test that uses offload_mode and
expect_streams to also include "disk" so the third supported mode is exercised.
Keep the existing expectations aligned with the test’s current wiring logic, and
make sure the added case is reflected in the same parameter list used by the
offloading activation test.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 377647c0-7ba3-4cbf-8fad-d7c4c80a23f9

📥 Commits

Reviewing files that changed from the base of the PR and between 22bcb9a and 72f6d0c.

📒 Files selected for processing (8)
  • src/axolotl/core/builders/base.py
  • src/axolotl/core/trainers/mixins/activation_checkpointing.py
  • src/axolotl/core/training_args_base.py
  • src/axolotl/loaders/model.py
  • src/axolotl/monkeypatch/activation_offload_checkpoint.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/utils/schemas/validation.py
  • tests/e2e/test_activation_offloading.py

Comment thread src/axolotl/core/training_args_base.py Outdated
Comment thread src/axolotl/monkeypatch/activation_offload_checkpoint.py
Comment thread src/axolotl/monkeypatch/activation_offload_checkpoint.py
Comment thread src/axolotl/utils/schemas/validation.py
Comment thread tests/e2e/test_activation_offloading.py Outdated
@codecov

codecov Bot commented Jun 13, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 35.55556% with 116 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
...olotl/monkeypatch/activation_offload_checkpoint.py 21.67% 112 Missing ⚠️
src/axolotl/utils/schemas/validation.py 82.35% 3 Missing ⚠️
src/axolotl/core/builders/base.py 90.00% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

…ts/test

- hidden_states forces use_reentrant=True, which is incompatible with a
  partially frozen model (unfrozen_parameters). The before-validator guard
  ran before the force, so catch it in check_hidden_states_offloading.
- shorten verbose branch comments in _configure_gradient_checkpointing.
- cover the 'disk' offload mode in the mode-wiring test.
@@ -30,13 +30,18 @@ class ActivationOffloadingMixin(Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.args.activation_offloading:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we update the https://docs.axolotl.ai/docs/gradient_checkpointing.html docs as well for this new function and when to choose each?

json_schema_extra={
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
},
activation_offloading: Literal["legacy", "disk", "hidden_states"] | bool | None = (

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in patch manager, we check for "offload_disk" whereas the pydantic model uses disk

Comment thread src/axolotl/monkeypatch/activation_offload_checkpoint.py
Comment thread src/axolotl/core/training_args_base.py Outdated
Comment thread tests/e2e/test_activation_offloading.py
Comment on lines +69 to +73
s0 = torch.cuda.current_stream()
self.s1.wait_stream(s0) # input must be produced before we copy it
with torch.cuda.stream(self.s1):
cpu_t = t.detach().to("cpu", non_blocking=True)
ev = self.s1.record_event()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude's suggestion, potential perf improvement:

torch.empty(..., pin_memory=True) + .copy_(t, non_blocking=True)

upstream TRL creates an empty pinned buffer and copies to it unlike in this case .to("cpu", non_blocking=True) where it targets a pageable memory and is sync

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pinned memory usage is way worse and seems to provide no improvement on any axis.

Address review feedback on hidden_states activation offloading:
- restore(): record_stream the brought-back input on the compute stream so the
  allocator can't reuse its storage before the recompute consumes it
  (use-after-free guard, matching TRL's offloader).
- training_args activation_offloading typed Literal['legacy','disk'] | bool;
  'hidden_states' never reaches the trainer (handled in the model loader).
- add numerical-parity test: hidden_states offload vs plain reentrant
  checkpointing must match loss/grads (CUDA-guarded).
- document the hidden_states mode and a mode-selection table.

Investigated pinned-host-buffer d2h: benchmarked vs the current pageable copy
(Qwen-sized MLP stack, 8k-32k). No speed gain and higher peak GPU memory
(+2/+4/+8 GB, growing with seq) since true async holds the GPU source longer —
a regression for a memory-reduction feature. Keeping pageable. Disk-offload
mode wiring left for a follow-up PR.
@winglian winglian merged commit 277d524 into main Jun 15, 2026
31 of 32 checks passed
@winglian winglian deleted the feat/activation-offloading-hidden-states branch June 15, 2026 22:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants