Skip to content

feat: add arg to enable dft in liger#3125

Merged
NanoCode012 merged 4 commits into
mainfrom
feat/liger-dft
Nov 10, 2025
Merged

feat: add arg to enable dft in liger#3125
NanoCode012 merged 4 commits into
mainfrom
feat/liger-dft

Conversation

@NanoCode012

@NanoCode012 NanoCode012 commented Sep 2, 2025

Copy link
Copy Markdown
Collaborator

Description

Adds support for linkedin/Liger-Kernel#860

Enable via liger_use_token_scaling: true

  • Requires upstream release from Liger, >0.6.2

Motivation and Context

How has this been tested?

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features

    • Added optional setting to enable token scaling for fused linear cross-entropy. When turned on with the existing FLCE option, each token’s loss is scaled by its predicted probability (detached), and this behavior is enforced automatically during training.
  • Documentation

    • Updated usage example to include the new FLCE-specific option, demonstrating how to enable token scaling alongside fused linear cross-entropy.

@coderabbitai

coderabbitai Bot commented Sep 2, 2025

Copy link
Copy Markdown
Contributor

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

Introduces an optional configuration liger_use_token_scaling for Liger’s fused linear cross-entropy (FLCE). Updates README usage. Adds the field to LigerArgs. In plugin pre_model_load, when both FLCE and token scaling are enabled, runtime patches force use_token_scaling=True for both the FLCE function and loss class initializer.

Changes

Cohort / File(s) Summary
Docs: Liger usage update
src/axolotl/integrations/liger/README.md
Adds FLCE-specific option example: liger_use_token_scaling: true under a new section, without altering existing options.
Config schema: LigerArgs
src/axolotl/integrations/liger/args.py
Adds optional bool field liger_use_token_scaling (default None) with description; imports Pydantic Field; no other logic changes.
Plugin: FLCE token-scaling patching
src/axolotl/integrations/liger/plugin.py
If liger_fused_linear_cross_entropy and liger_use_token_scaling are true, wraps functional.liger_fused_linear_cross_entropy and LigerFusedLinearCrossEntropyLoss.__init__ to inject use_token_scaling=True on every call; other branches unchanged.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • djsaunde
✨ Finishing touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch feat/liger-dft

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

github-actions Bot commented Sep 2, 2025

Copy link
Copy Markdown
Contributor

📖 Documentation Preview: https://6911fa5228124ecc53bf17ff--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 0b2795f

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (2)
src/axolotl/integrations/liger/README.md (1)

21-23: Clarify version/feature dependency for token scaling.

Readers need to know this only works with FLCE and a Liger-Kernel build that includes the token-scaling feature (PR #860+).

 # FLCE-specific
-liger_use_token_scaling: true
+# Requires Liger-Kernel with token-scaling support (PR #860+) and FLCE enabled
+liger_use_token_scaling: true
src/axolotl/integrations/liger/args.py (1)

38-46: Guard misuse and clarify description.

Warn when token scaling is set without FLCE, and note the dependency in the field description.

     liger_use_token_scaling: bool | None = Field(
         default=None,
         json_schema_extra={
             "description": (
-                "Enables use_token_scaling in fused_linear_cross_entropy. "
-                "When True, each token's loss is multiplied by its predicted probability (detached from gradients)."
+                "Enables use_token_scaling in fused_linear_cross_entropy (FLCE). "
+                "When True, each token's loss is multiplied by its predicted probability (detached from gradients). "
+                "Requires `liger_fused_linear_cross_entropy: true` and a Liger-Kernel build with token-scaling support."
             )
         },
     )

Add a validator (outside this hunk) to warn when ineffective:

# place near other @model_validator(mode="before")
@model_validator(mode="before")
@classmethod
def check_token_scaling_requires_flce(cls, data):
    if data.get("liger_use_token_scaling") and not data.get("liger_fused_linear_cross_entropy"):
        LOG.warning(
            "`liger_use_token_scaling: true` has no effect unless `liger_fused_linear_cross_entropy: true` is also set."
        )
    return data
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 0094a2d and 4aafd5f.

📒 Files selected for processing (3)
  • src/axolotl/integrations/liger/README.md (1 hunks)
  • src/axolotl/integrations/liger/args.py (2 hunks)
  • src/axolotl/integrations/liger/plugin.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: preview
🔇 Additional comments (2)
src/axolotl/integrations/liger/args.py (1)

38-46: LGTM: adds well-scoped opt-in flag with clear schema.

src/axolotl/integrations/liger/plugin.py (1)

51-76: Add missing inspect import and verify patch logic

  • In src/axolotl/integrations/liger/plugin.py, add import inspect before using inspect.signature.
  • Ensure the guard skips the monkey-patch when use_token_scaling isn’t in the function or __init__ signature; manually test in an environment with Liger-Kernel installed to confirm no TypeError is raised.

@codecov

codecov Bot commented Sep 2, 2025

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 42.85714% with 12 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/axolotl/integrations/liger/plugin.py 7.69% 12 Missing ⚠️

📢 Thoughts on this report? Let us know!

@winglian

winglian commented Sep 2, 2025

Copy link
Copy Markdown
Collaborator

This should be a draft, right? Since it needs a new Liger release

@NanoCode012 NanoCode012 marked this pull request as draft September 2, 2025 13:38
@NanoCode012 NanoCode012 marked this pull request as ready for review November 10, 2025 14:37
@NanoCode012 NanoCode012 merged commit 11eb365 into main Nov 10, 2025
20 of 22 checks passed
@NanoCode012 NanoCode012 deleted the feat/liger-dft branch November 10, 2025 14:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants