Extend invariant suite with gradient-checkpointing equivalence #5689

Merged: qgallouedec merged 6 commits into main from invariance_grad_chkpt on May 4, 2026
@qgallouedec (Member) commented May 1, 2026

Adds gradient_checkpointing=False as a third member of the existing sft and dpo equivalence classes in tests/invariant/. Gradient checkpointing (GC) trades compute for memory, so it should produce trajectories identical to the canonical run (gc=False) within fp32 noise; any divergence here would indicate a real bug in the recompute path.

$ pytest tests/invariant/test_invariant.py -m invariant -v -k grad_ckpt
=============================================================== test session starts ===============================================================
platform linux -- Python 3.13.2, pytest-9.0.3, pluggy-1.6.0 -- /fsx/qgallouedec/trl/.venv/bin/python
cachedir: .pytest_cache
rootdir: /fsx/qgallouedec/trl
configfile: pyproject.toml
plugins: cov-7.1.0, rerunfailures-15.1, anyio-4.13.0, datadir-1.8.0, xdist-3.8.0
collected 6 items / 4 deselected / 2 selected                                                                                                     

tests/invariant/test_invariant.py::test_invariant[sft_grad_ckpt] PASSED                                                                     [ 50%]
tests/invariant/test_invariant.py::test_invariant[dpo_grad_ckpt] PASSED                                                                     [100%]

=================================================== 2 passed, 4 deselected in 187.82s (0:03:07) ===================================================
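
For readers unfamiliar with the suite, the kind of check described above (a member's loss/grad-norm trajectory matching the class reference within tolerance) can be sketched roughly as follows. This is a minimal illustration, not the code in tests/invariant/: the helper name `assert_trajectories_match`, the log-record layout, the tolerance values, and the numbers are all assumptions made for the example.

```python
import math


def assert_trajectories_match(reference, candidate, rtol=1e-5, atol=1e-6):
    """Hypothetical helper: fail if any logged metric diverges beyond tolerance.

    `reference` and `candidate` are lists of per-step dicts with "loss" and
    "grad_norm" keys. The tolerances are illustrative, not the suite's values.
    """
    assert len(reference) == len(candidate), "trajectories differ in length"
    for step, (ref, cand) in enumerate(zip(reference, candidate)):
        for key in ("loss", "grad_norm"):
            if not math.isclose(ref[key], cand[key], rel_tol=rtol, abs_tol=atol):
                raise AssertionError(
                    f"step {step}: {key} diverged: {ref[key]!r} vs {cand[key]!r}"
                )


# Made-up numbers: the member differs from the reference only by tiny
# floating-point noise, so the check passes silently.
reference = [{"loss": 2.5, "grad_norm": 1.25}, {"loss": 2.0, "grad_norm": 1.0}]
member = [{"loss": 2.5 + 1e-7, "grad_norm": 1.25}, {"loss": 2.0, "grad_norm": 1.0 - 1e-7}]
assert_trajectories_match(reference, member)
```

Comparing step by step rather than only the final loss means a divergence is reported at the first step where it appears, which is usually the most useful place to start debugging.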

Note

Low Risk
Low risk: changes are limited to the invariant test suite, adding extra configurations to validate trajectory equivalence when gradient_checkpointing is disabled, with no production code impact.

Overview
Extends the invariant equivalence tests for both sft and dpo by adding a third member configuration that runs with gradient_checkpointing=False and asserts its loss/grad-norm trajectory matches the class reference within existing tolerances.

Reviewed by Cursor Bugbot for commit faf55fb. Bugbot is set up for automated code reviews on this repo.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec qgallouedec merged commit e55d788 into main May 4, 2026
13 checks passed
@qgallouedec qgallouedec deleted the invariance_grad_chkpt branch May 4, 2026 15:03

3 participants