feat: add arg to enable dft in liger#3125
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the You can disable this status message by setting the 📝 WalkthroughWalkthroughIntroduces an optional configuration liger_use_token_scaling for Liger’s fused linear cross-entropy (FLCE). Updates README usage. Adds the field to LigerArgs. In plugin pre_model_load, when both FLCE and token scaling are enabled, runtime patches force use_token_scaling=True for both the FLCE function and loss class initializer. Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
✨ Finishing touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
📖 Documentation Preview: https://6911fa5228124ecc53bf17ff--resonant-treacle-0fd729.netlify.app Deployed on Netlify from commit 0b2795f |
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (2)
src/axolotl/integrations/liger/README.md (1)
21-23: Clarify version/feature dependency for token scaling.Readers need to know this only works with FLCE and a Liger-Kernel build that includes the token-scaling feature (PR #860+).
# FLCE-specific -liger_use_token_scaling: true +# Requires Liger-Kernel with token-scaling support (PR #860+) and FLCE enabled +liger_use_token_scaling: truesrc/axolotl/integrations/liger/args.py (1)
38-46: Guard misuse and clarify description.Warn when token scaling is set without FLCE, and note the dependency in the field description.
liger_use_token_scaling: bool | None = Field( default=None, json_schema_extra={ "description": ( - "Enables use_token_scaling in fused_linear_cross_entropy. " - "When True, each token's loss is multiplied by its predicted probability (detached from gradients)." + "Enables use_token_scaling in fused_linear_cross_entropy (FLCE). " + "When True, each token's loss is multiplied by its predicted probability (detached from gradients). " + "Requires `liger_fused_linear_cross_entropy: true` and a Liger-Kernel build with token-scaling support." ) }, )Add a validator (outside this hunk) to warn when ineffective:
# place near other @model_validator(mode="before") @model_validator(mode="before") @classmethod def check_token_scaling_requires_flce(cls, data): if data.get("liger_use_token_scaling") and not data.get("liger_fused_linear_cross_entropy"): LOG.warning( "`liger_use_token_scaling: true` has no effect unless `liger_fused_linear_cross_entropy: true` is also set." ) return data
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (3)
src/axolotl/integrations/liger/README.md(1 hunks)src/axolotl/integrations/liger/args.py(2 hunks)src/axolotl/integrations/liger/plugin.py(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: preview
🔇 Additional comments (2)
src/axolotl/integrations/liger/args.py (1)
38-46: LGTM: adds well-scoped opt-in flag with clear schema.src/axolotl/integrations/liger/plugin.py (1)
51-76: Add missinginspectimport and verify patch logic
- In
src/axolotl/integrations/liger/plugin.py, addimport inspectbefore usinginspect.signature.- Ensure the guard skips the monkey-patch when
use_token_scalingisn’t in the function or__init__signature; manually test in an environment with Liger-Kernel installed to confirm no TypeError is raised.
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
|
This should be a draft, right? Since it needs a new Liger release |
Description
Adds support for linkedin/Liger-Kernel#860
Enable via
liger_use_token_scaling: true>0.6.2Motivation and Context
How has this been tested?
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit
New Features
Documentation