
Align KTO with DPO: Align _precompute_ref_logps #5714

Merged
albertvillanova merged 8 commits into main from align-kto-dpo-precompute_ref_logps on May 7, 2026

@albertvillanova albertvillanova commented May 6, 2026

Member

Align KTO with DPO: Align _precompute_ref_logps.

This PR adds caching for the precomputed reference log probabilities in the KTOTrainer, so that subsequent runs load them from disk instead of recomputing them. The main changes are new imports, a hash-keyed cache file written with numpy, and updating the dataset columns from the cached results.

Part of:

Changes

Caching and efficiency improvements:

  • Added logic to cache computed reference log probabilities (reference_logps and reference_KL_logps) to a numpy .npz file, identified by a hash of the dataset and model. On subsequent runs, the code loads from cache if available, reducing redundant computation.
  • Used hash_module and datasets.fingerprint.Hasher to generate cache fingerprints from the model and dataset, so that a change to either one produces a different cache file.
  • Updated dataset columns with cached or newly computed log probabilities, and ensured new fingerprints are set for reproducibility.
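
The cache-or-compute flow described in the list above can be sketched as follows. This is a minimal illustration, not the code from this PR: `load_or_compute`, `fake_compute`, and the file name pattern are hypothetical, and the fingerprint is passed in as a plain string rather than derived from hash_module and Hasher.

```python
import os
import tempfile

import numpy as np


def load_or_compute(path, compute_fn):
    """Return (arrays, cache_hit). Load the .npz at `path` if it exists,
    otherwise call `compute_fn`, save its result compressed, and return it."""
    if os.path.exists(path):
        with np.load(path) as data:
            # Copy every stored array out before the file is closed.
            return {name: data[name] for name in data.files}, True
    arrays = compute_fn()
    np.savez_compressed(path, **arrays)
    return arrays, False


calls = []


def fake_compute():
    """Hypothetical stand-in for the expensive reference-model forward passes."""
    calls.append(1)
    return {
        "reference_logps": np.array([-1.5, -0.25, -3.0], dtype=np.float32),
        "reference_KL_logps": np.array([-2.0, -0.5, -1.0], dtype=np.float32),
    }


with tempfile.TemporaryDirectory() as cache_dir:
    # "abc123" is a placeholder fingerprint, assumed to be computed elsewhere.
    path = os.path.join(cache_dir, "ref_logps_abc123.npz")
    first, hit_first = load_or_compute(path, fake_compute)
    second, hit_second = load_or_compute(path, fake_compute)
```

In this sketch the second call reads the arrays back from disk, so `fake_compute` runs only once. A real multi-process run would additionally need to decide which process writes the file and when the others may read it, which is not shown here.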

Dependency and import updates:

  • Added imports for os, numpy, and hash_module to support caching and hashing functionality.

These changes collectively improve the performance and reproducibility of the reference log probability computation process.


Note

Medium Risk
Adds cross-run caching and fingerprint manipulation in KTOTrainer._precompute_ref_logps, which can affect training correctness if the cache key or synchronization is wrong and introduces file I/O in distributed runs.

Overview
KTOTrainer._precompute_ref_logps now caches precomputed reference log probabilities to a compressed .npz file keyed by a fingerprint of the dataset, the (ref) model weights (hash_module), and whether KL is computed.

On subsequent runs it loads reference_logps (and reference_KL_logps when enabled) from cache instead of recomputing, and updates dataset columns while setting new_fingerprint to reflect the cached content for reproducibility across runs/processes.
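
To illustrate why the cache key covers the dataset, the model weights, and the KL flag, here is a rough sketch using only hashlib. It is an assumption-laden stand-in, not the PR's implementation: `hash_weights` and `cache_fingerprint` are hypothetical helpers, weights are represented as raw bytes, and the actual code relies on hash_module and datasets.fingerprint.Hasher instead.

```python
import hashlib


def hash_weights(state):
    """Hypothetical stand-in for hashing a model: digest names and raw bytes."""
    h = hashlib.sha256()
    for name in sorted(state):
        h.update(name.encode("utf-8"))
        h.update(state[name])
    return h.hexdigest()


def cache_fingerprint(dataset_fingerprint, state, calculate_kl):
    """Combine dataset fingerprint, weight hash, and KL flag into one short key."""
    parts = [dataset_fingerprint, hash_weights(state), f"kl={calculate_kl}"]
    return hashlib.sha256("|".join(parts).encode("utf-8")).hexdigest()[:16]


weights = {"layer.weight": bytes([1, 2, 3, 4]), "layer.bias": bytes([0, 0])}
changed = {"layer.weight": bytes([1, 2, 3, 5]), "layer.bias": bytes([0, 0])}

base = cache_fingerprint("ds-abc", weights, calculate_kl=True)
same = cache_fingerprint("ds-abc", weights, calculate_kl=True)
other_model = cache_fingerprint("ds-abc", changed, calculate_kl=True)
other_data = cache_fingerprint("ds-xyz", weights, calculate_kl=True)
no_kl = cache_fingerprint("ds-abc", weights, calculate_kl=False)
```

Identical inputs give the same key, so a later run finds the existing file. Changing any one of the three inputs gives a different key, so stale log probabilities are not picked up.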

Reviewed by Cursor Bugbot for commit a396005.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.

Reviewed by Cursor Bugbot for commit d4484bc.

Comment thread trl/experimental/kto/kto_trainer.py Outdated
Comment thread trl/experimental/kto/kto_trainer.py Outdated

@qgallouedec qgallouedec left a comment

Member

It looks good!

@albertvillanova albertvillanova merged commit acbd53f into main May 7, 2026
5 checks passed
@albertvillanova albertvillanova deleted the align-kto-dpo-precompute_ref_logps branch May 7, 2026 06:08