Skip to content

Apply generic fused liger ce, cce, and tiledmlp for arbitrary models#2908

Merged
winglian merged 10 commits into
mainfrom
liger-generic
Jul 16, 2025
Merged

Apply generic fused liger ce, cce, and tiledmlp for arbitrary models#2908
winglian merged 10 commits into
mainfrom
liger-generic

Conversation

@winglian
Copy link
Copy Markdown
Collaborator

@winglian winglian commented Jul 12, 2025

Description

While Liger doesn't have patches for all models, we know most models have model forwards very similar to llama, so let's patch models with a generic flce if possible and log.

This is useful for experimenting if it works out of the box for newer models like LFM2

Summary by CodeRabbit

  • New Features

    • Added support for flexible, memory-efficient forward passes and loss computation for causal language models, including integration with PEFT and FSDP.
    • Introduced a configuration option to control the use of the original MLP in tiled MLP setups.
    • Expanded compatibility for cut cross-entropy and tiled MLP patches to additional model architectures.
    • Added a utility for consistent model class name handling across integrations.
  • Improvements

    • Enhanced plugin mechanisms for LIGER and cut cross-entropy integrations, enabling more robust and dynamic patching based on model type.
    • Streamlined internal logic for dynamic model patching and class prefix derivation, improving maintainability and reducing errors.
  • Bug Fixes

    • Improved error handling and logging for unsupported or misconfigured model types during plugin patching.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Jul 12, 2025

Caution

Review failed

The pull request is closed.

Walkthrough

The changes restructure and modularize the integration of LIGER and related kernel patching mechanisms in the codebase. The LigerPlugin implementation is moved to a new file, generic patching utilities are introduced for causal language models, and several components now use a centralized helper for model class prefix resolution. Additional configuration options and patching logic are also introduced.

Changes

File(s) Change Summary
src/axolotl/integrations/liger/init.py Removed the LigerPlugin class and all logic; now only re-exports LigerArgs and LigerPlugin from submodules and declares __all__.
src/axolotl/integrations/liger/plugin.py New file: Contains the full LigerPlugin class with logic for kernel patching and model-specific integration, previously in __init__.py.
src/axolotl/integrations/liger/models/base.py New file: Adds lce_forward, lce_maybe_trainable_lm_head, _liger_for_causal_lm_loss, and patch_lce_forward for flexible forward/loss computation and dynamic patching of causal LM models.
src/axolotl/integrations/cut_cross_entropy/init.py Adds patch_llama_like method to CutCrossEntropyPlugin for generic patching of llama-like models; updates pre_model_load to call this method.
src/axolotl/loaders/patch_manager.py Updates _apply_tiled_mlp to pass a new use_original_mlp argument from config to patch_tiled_mlp.
src/axolotl/monkeypatch/tiled_mlp.py Refactors prefix derivation to use get_causal_lm_model_cls_prefix; replaces fixed weight list in MLP patch with a cached, dynamically collected list of trainable parameters.
src/axolotl/monkeypatch/lora_kernels.py
src/axolotl/integrations/kd/kernels/models.py
Both refactor model class prefix construction to use the new get_causal_lm_model_cls_prefix utility function instead of manual string manipulation.
src/axolotl/utils/schemas/config.py Adds a new optional boolean config field: tiled_mlp_use_original_mlp.
src/axolotl/utils/callbacks/models.py Adds the helper function get_causal_lm_model_cls_prefix for standardized model class prefix resolution.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Axolotl
    participant LigerPlugin
    participant Model
    participant PatchUtils

    User->>Axolotl: Provide config and model type
    Axolotl->>LigerPlugin: pre_model_load(cfg)
    LigerPlugin->>PatchUtils: Determine patch strategy (generic/specific)
    PatchUtils->>Model: Apply kernel patches/monkey patches
    Model-->>LigerPlugin: Model patched
    LigerPlugin-->>Axolotl: Patching complete
Loading

Possibly related PRs

Suggested labels

ready to merge

Suggested reviewers

  • NanoCode012
  • djsaunde

Poem

In the warren, code hops anew,
Plugins and patches—now neatly askew!
Helpers for models, kernels to match,
With prefixes tidy and logic to catch.
Modular burrows, configs to tweak,
This rabbit’s release is both nimble and sleek!
🐇✨


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2ce0cf3 and 00571f9.

📒 Files selected for processing (2)
  • src/axolotl/integrations/cut_cross_entropy/__init__.py (3 hunks)
  • src/axolotl/integrations/liger/__init__.py (1 hunks)
✨ Finishing Touches
  • 📝 Generate Docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (1)
src/axolotl/integrations/liger/models/base.py (1)

173-190: Consider adding verification after patching.

The dynamic patching approach relies on consistent naming conventions. While error handling is present, consider verifying the patch was applied correctly.

Add verification after patching:

        model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")

        model_cls.forward = lce_forward
+        # Verify the patch was applied
+        if model_cls.forward != lce_forward:
+            raise RuntimeError(f"Failed to patch forward method for {model_type}")
    except (ImportError, AttributeError) as e:

Also, consider documenting which model naming patterns are supported to help users understand compatibility.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between eb66255 and 76b55dd.

📒 Files selected for processing (3)
  • src/axolotl/integrations/liger/__init__.py (1 hunks)
  • src/axolotl/integrations/liger/models/base.py (1 hunks)
  • src/axolotl/integrations/liger/plugin.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/integrations/liger/__init__.py (1)
src/axolotl/integrations/liger/plugin.py (1)
  • LigerPlugin (17-181)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
🔇 Additional comments (2)
src/axolotl/integrations/liger/__init__.py (1)

22-22: LGTM! Good refactoring for better code organization.

Moving the LigerPlugin implementation to a dedicated plugin.py file improves modularity and maintains a clean separation between the module interface and implementation.

src/axolotl/integrations/liger/plugin.py (1)

168-177: Good implementation of generic FLCE patching for unknown models.

This implementation aligns well with the PR objectives by providing a fallback mechanism for models without specific Liger patches. The error handling and logging ensure graceful degradation when patching fails.

Comment thread src/axolotl/integrations/liger/plugin.py Outdated
Comment thread src/axolotl/integrations/liger/plugin.py Outdated
Comment thread src/axolotl/integrations/liger/models/base.py
@codecov
Copy link
Copy Markdown

codecov Bot commented Jul 12, 2025

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (4)
src/axolotl/integrations/liger/plugin.py (4)

19-19: Fix typo in class docstring.

There's a typo in "integraton" which should be "integration".

-    Plugin for LIGER integraton with Axolotl.
+    Plugin for LIGER integration with Axolotl.

96-100: Inefficient model instantiation for module extraction.

Creating an empty model instance just to extract the modeling module is inefficient and potentially error-prone. Consider using dynamic import instead.

-            with init_empty_weights():
-                model = AutoModelForCausalLM.from_pretrained(
-                    cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
-                )
-                modeling_mod = sys.modules[model.__class__.__module__]
+            # Dynamically import the modeling module based on model name
+            import importlib
+            model_name = cfg.base_model.split('/')[-1].lower()
+            if 'deepseek' in model_name:
+                modeling_mod = importlib.import_module('transformers.models.deepseek_v2.modeling_deepseek_v2')
+            else:
+                # Fallback to current approach if pattern matching fails
+                with init_empty_weights():
+                    model = AutoModelForCausalLM.from_pretrained(
+                        cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
+                    )
+                    modeling_mod = sys.modules[model.__class__.__module__]

25-179: Consider breaking down the large pre_model_load method.

The method is over 150 lines long, handling multiple concerns. Consider extracting model-specific handling into separate methods for better maintainability.

    def pre_model_load(self, cfg):
+        self._setup_torch_compile_patches(cfg)
+        self._validate_config(cfg)
+        
+        if not self._apply_generic_patches(cfg):
+            self._apply_model_specific_patches(cfg)
+    
+    def _setup_torch_compile_patches(self, cfg):
        if cfg.torch_compile:
            # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
            import liger_kernel.ops.fused_linear_cross_entropy

            patch_with_compile_disable(
                liger_kernel.ops.fused_linear_cross_entropy,
                "fused_linear_cross_entropy_forward",
            )
            patch_with_compile_disable(
                liger_kernel.ops.fused_linear_cross_entropy,
                "fused_linear_cross_entropy_backward",
            )
+
+    def _validate_config(self, cfg):
        from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
        from liger_kernel.transformers.functional import liger_cross_entropy
        from liger_kernel.transformers.layer_norm import LigerLayerNorm
        from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
        from liger_kernel.transformers.rms_norm import LigerRMSNorm
        from liger_kernel.transformers.rope import liger_rotary_pos_emb
        from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

        if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
            raise ValueError(
                "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
            )

38-44: Consider moving imports to module level for better performance.

Importing inside the method means these imports are executed every time the method is called. Since these are core LIGER components, consider moving them to the module level.

+from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+from liger_kernel.transformers.functional import liger_cross_entropy
+from liger_kernel.transformers.layer_norm import LigerLayerNorm
+from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
+from liger_kernel.transformers.rms_norm import LigerRMSNorm
+from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

 class LigerPlugin(BasePlugin):
     """
     Plugin for LIGER integration with Axolotl.
     """

     def get_input_args(self):
         return "axolotl.integrations.liger.LigerArgs"

     def pre_model_load(self, cfg):
         # ... rest of the method without the imports
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 76b55dd and ccddc1e.

📒 Files selected for processing (1)
  • src/axolotl/integrations/liger/plugin.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
🔇 Additional comments (2)
src/axolotl/integrations/liger/plugin.py (2)

166-175: Excellent implementation of generic fallback patch.

This implementation aligns perfectly with the PR objectives to "apply generic fused liger ce for unknown models." The error handling and logging provide good user feedback when the generic patch succeeds or fails.


51-72: Well-designed dynamic parameter inspection.

The use of inspect.signature() to dynamically determine which parameters the LIGER functions support is elegant and maintains compatibility across different LIGER versions. The conditional parameter passing based on signature inspection is a robust approach.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (4)
src/axolotl/monkeypatch/tiled_mlp.py (1)

48-53: Consider moving parameter caching to initialization time.

While the lazy initialization works functionally, it adds overhead to the first forward pass and could have thread safety issues in multi-threaded environments.

Consider initializing _compute_params immediately after patching the forward method:

        mlp_cls.forward = tiled_mlp_forward
-        mlp_cls._compute_params = []  # pylint: disable=protected-access
+        # Initialize compute_params for all instances
+        original_init = mlp_cls.__init__
+        def patched_init(self, *args, **kwargs):
+            original_init(self, *args, **kwargs)
+            self._compute_params = [p for p in self.parameters() if p.requires_grad]
+        mlp_cls.__init__ = patched_init

This would eliminate the conditional check in the forward pass and ensure thread safety.

src/axolotl/integrations/cut_cross_entropy/__init__.py (3)

125-127: Consider avoiding protected member access.

Directly accessing _PATCH_OPTS violates encapsulation. Consider using a public API if available or documenting why this access is necessary.


130-134: Improve error handling specificity.

The error handling catches both ImportError and AttributeError but doesn't differentiate between them. Consider providing more specific error messages for each case.

-            except (ImportError, AttributeError) as e:
-                raise RuntimeError(
-                    f"Could not import ForCausalLM class for model_type: {model_type}. "
-                    f"Error: {str(e)}"
-                ) from e
+            except ImportError as e:
+                raise RuntimeError(
+                    f"Module '{module_path}' not found for model_type: {model_type}. "
+                    f"Error: {str(e)}"
+                ) from e
+            except AttributeError as e:
+                raise RuntimeError(
+                    f"Class '{model_cls_prefix}ForCausalLM' not found in module for model_type: {model_type}. "
+                    f"Error: {str(e)}"
+                ) from e

136-138: Document side effects of modifying global state.

The method modifies the global PATCH_FNS dictionary, which could have unintended side effects if called multiple times or from different contexts. Consider documenting this behavior or implementing safeguards.

Consider checking if the patch already exists before overwriting:

         if model_type not in PATCH_FNS:
             LOG.warning("Setting up generic cce patch for model type: %s", model_type)
             PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)
+        else:
+            LOG.debug("Patch already exists for model type: %s", model_type)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ccddc1e and 0197363.

📒 Files selected for processing (4)
  • src/axolotl/integrations/cut_cross_entropy/__init__.py (3 hunks)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/monkeypatch/tiled_mlp.py (2 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • src/axolotl/utils/schemas/config.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (4)
src/axolotl/loaders/patch_manager.py (1)

281-285: LGTM! Configuration parameter correctly passed through.

The addition of the use_original_mlp parameter enables flexible MLP implementation selection based on configuration, which aligns well with the PR's goal of supporting generic patches for unknown models.

src/axolotl/monkeypatch/tiled_mlp.py (1)

65-65: LGTM! Proper attribute initialization.

Initializing _compute_params as a class attribute ensures all instances have this attribute defined, preventing AttributeError on first access.

src/axolotl/integrations/cut_cross_entropy/__init__.py (2)

22-22: LGTM!

Import is correctly added to support the partial function usage.


84-98: LGTM!

The generic patch registration is correctly placed before the standard patching process.

Comment thread src/axolotl/monkeypatch/tiled_mlp.py
Comment thread src/axolotl/integrations/cut_cross_entropy/__init__.py
Comment thread src/axolotl/integrations/cut_cross_entropy/__init__.py
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jul 12, 2025

📖 Documentation Preview: https://687711e4d78eac4bfb8dbe6e--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 00571f9

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

♻️ Duplicate comments (1)
src/axolotl/integrations/liger/models/base.py (1)

150-157: Fix incorrect variable reference in non-FSDP path.

The function uses self.lm_head instead of the lm_head parameter that was passed in and potentially unwrapped.

    # FSDP is not used so we can read the lm_head weights and call the kernel directly
    return _liger_for_causal_lm_loss(
-        lm_head=self.lm_head,
+        lm_head=lm_head,
        hidden_states=hidden_states,
        hidden_size=hidden_size,
        labels=labels,
        shift_labels=shift_labels,
        **loss_kwargs,
    )
🧹 Nitpick comments (1)
src/axolotl/integrations/liger/models/base.py (1)

44-58: Consider extracting configuration handling to reduce duplication.

The pattern of checking if parameter is None and falling back to config values is repeated multiple times. Consider extracting this to a helper function.

+def _resolve_config_param(param_value, config_value):
+    return param_value if param_value is not None else config_value
+
 # pylint: disable=duplicate-code
-output_attentions = (
-    output_attentions
-    if output_attentions is not None
-    else self.config.output_attentions
-)
-output_hidden_states = (
-    output_hidden_states
-    if output_hidden_states is not None
-    else self.config.output_hidden_states
-)
-
-return_dict = (
-    return_dict if return_dict is not None else self.config.use_return_dict
-)
+output_attentions = _resolve_config_param(output_attentions, self.config.output_attentions)
+output_hidden_states = _resolve_config_param(output_hidden_states, self.config.output_hidden_states)
+return_dict = _resolve_config_param(return_dict, self.config.use_return_dict)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0197363 and c83866d.

📒 Files selected for processing (7)
  • src/axolotl/integrations/cut_cross_entropy/__init__.py (3 hunks)
  • src/axolotl/integrations/liger/__init__.py (1 hunks)
  • src/axolotl/integrations/liger/models/base.py (1 hunks)
  • src/axolotl/integrations/liger/plugin.py (1 hunks)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/monkeypatch/tiled_mlp.py (2 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • src/axolotl/integrations/liger/init.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/loaders/patch_manager.py
  • src/axolotl/monkeypatch/tiled_mlp.py
  • src/axolotl/integrations/cut_cross_entropy/init.py
  • src/axolotl/integrations/liger/plugin.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (4)
src/axolotl/integrations/liger/models/base.py (4)

1-14: LGTM! Well-structured imports and documentation.

The imports are properly organized and the module docstring clearly describes the purpose.


16-43: LGTM! Well-documented function signature.

The function signature and docstring provide clear documentation of the parameters, especially the detailed explanation of logits_to_keep parameter.


68-87: LGTM! Solid logic for handling logits computation.

The slicing logic for logits_to_keep correctly handles both int and tensor types, and the default skip_logits behavior is well-designed.


160-170: LGTM! Clean wrapper function.

The function properly delegates to LigerForCausalLMLoss with correct parameter mapping.

Comment on lines +173 to +189
def patch_lce_forward(
model_type,
):
try:
# Dynamically import the module and MLP class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")

model_cls.forward = lce_forward
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add validation for model_type parameter and improve error handling.

The dynamic import approach is fragile and should include input validation. Also consider caching successful imports for performance.

+_PATCHED_MODELS = set()
+
 def patch_lce_forward(
     model_type,
 ):
+    if not model_type or not isinstance(model_type, str):
+        raise ValueError(f"model_type must be a non-empty string, got: {model_type}")
+    
+    # Avoid double-patching
+    if model_type in _PATCHED_MODELS:
+        return
+    
     try:
         # Dynamically import the module and MLP class
         module_path = f"transformers.models.{model_type}.modeling_{model_type}"
         model_cls_prefix = "".join(
             [part.capitalize() for part in model_type.split("_")]
         )
         module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
         model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
 
         model_cls.forward = lce_forward
+        _PATCHED_MODELS.add(model_type)
     # pylint: disable=duplicate-code
     except (ImportError, AttributeError) as e:
         raise RuntimeError(
             f"Could not import ForCausalLM class for model_type: {model_type}. "
             f"Error: {str(e)}"
         ) from e
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def patch_lce_forward(
model_type,
):
try:
# Dynamically import the module and MLP class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = lce_forward
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
# Cache of model_types already patched to avoid re-applying
_PATCHED_MODELS = set()
def patch_lce_forward(
model_type,
):
# Validate input
if not model_type or not isinstance(model_type, str):
raise ValueError(f"model_type must be a non-empty string, got: {model_type}")
# Avoid double-patching
if model_type in _PATCHED_MODELS:
return
try:
# Dynamically import the module and MLP class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = lce_forward
_PATCHED_MODELS.add(model_type)
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
🤖 Prompt for AI Agents
In src/axolotl/integrations/liger/models/base.py around lines 173 to 191, add
validation to check that the model_type parameter is a non-empty string and
matches expected patterns before attempting the dynamic import to prevent import
errors. Enhance error handling by providing clearer messages and consider
implementing a caching mechanism (e.g., a dictionary) to store and reuse
successfully imported model classes to improve performance and avoid repeated
imports.

Comment on lines +89 to +108
if skip_logits:
loss = lce_maybe_trainable_lm_head(
self,
hidden_states=kept_hidden_states,
hidden_size=self.config.hidden_size,
labels=labels,
shift_labels=shift_labels,
**kwargs,
)

else:
logits = self.lm_head(kept_hidden_states)
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Ensure loss_function is available when labels are provided.

The code assumes self.loss_function exists when computing loss in the else branch. Verify that all target models have this attribute or add error handling.


🏁 Script executed:

#!/bin/bash
# Verify if loss_function is a standard attribute in transformers causal LM models
ast-grep --pattern 'class $_ {
  $$$
  loss_function = $$$
  $$$
}' | head -20

# Also check for loss_function method definitions
ast-grep --pattern 'def loss_function($$$) {
  $$$
}' | head -20

Length of output: 147


🏁 Script executed:

#!/bin/bash
set -e

echo "Searching for loss_function definitions, assignments, and references across the repo..."

# Definitions of loss_function methods
rg -n "def loss_function" .

# Attribute assignments to loss_function
rg -n "loss_function\s*=" .

# References to self.loss_function
rg -n "self\.loss_function" .

Length of output: 223


Ensure loss_function is defined or add clear error handling
In the else branch you call self.loss_function(...) if labels is set, but no such method or attribute exists in this class (or its base classes)—this will raise an unexpected AttributeError at runtime. Please either provide a default implementation for loss_function or guard against its absence.

• File: src/axolotl/integrations/liger/models/base.py
• Lines: ~89–108

Example patch:

     else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
+            if not hasattr(self, 'loss_function'):
+                raise AttributeError(
+                    "Model must define `loss_function` to compute loss with labels"
+                )
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
                 **kwargs,
             )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if skip_logits:
loss = lce_maybe_trainable_lm_head(
self,
hidden_states=kept_hidden_states,
hidden_size=self.config.hidden_size,
labels=labels,
shift_labels=shift_labels,
**kwargs,
)
else:
logits = self.lm_head(kept_hidden_states)
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if skip_logits:
loss = lce_maybe_trainable_lm_head(
self,
hidden_states=kept_hidden_states,
hidden_size=self.config.hidden_size,
labels=labels,
shift_labels=shift_labels,
**kwargs,
)
else:
logits = self.lm_head(kept_hidden_states)
if labels is not None:
if not hasattr(self, 'loss_function'):
raise AttributeError(
"Model must define `loss_function` to compute loss with labels"
)
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
🤖 Prompt for AI Agents
In src/axolotl/integrations/liger/models/base.py around lines 89 to 108, the
code calls self.loss_function when labels are provided, but this method or
attribute is not defined in the class or its base classes, which will cause an
AttributeError at runtime. To fix this, either implement a default loss_function
method in the class or add a check before calling it to ensure it exists,
raising a clear error or handling the case when it is absent.

@winglian winglian requested a review from NanoCode012 July 12, 2025 16:19
@winglian winglian changed the title Apply generic fused liger ce for unknown models Apply generic fused liger ce, cce, and tiledmlp for arbitrary models Jul 12, 2025
Comment thread src/axolotl/integrations/liger/models/base.py
@winglian winglian requested review from NanoCode012 and djsaunde July 14, 2025 11:54
Copy link
Copy Markdown
Collaborator

@djsaunde djsaunde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool stuff!

Some questions / comments but otherwise looks good.

Comment thread src/axolotl/integrations/liger/__init__.py Outdated
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).

Returns:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: missing return

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

going to drop this, the annotation lists the types, and upstream in liger, they just have a code sample there 🤷


def lce_forward(
self,
*args,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to support possibly variable-length args here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to seperate what's required from not, so generally all non-kwargs were required.

Comment thread src/axolotl/utils/schemas/config.py
Comment on lines +48 to +53
if not self._compute_params: # pylint: disable=protected-access
self._compute_params = [ # pylint: disable=protected-access
p for p in self.parameters() if p.requires_grad
]

compute_params = self._compute_params # pylint: disable=protected-access
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these pylint disables needed? We're accessing class members.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, it was giving me some errors when I was linting

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (4)
src/axolotl/integrations/liger/plugin.py (1)

108-111: Fix conflicting GLU activation logic for DeepseekV2.

The code warns that liger_glu_activation is not supported for DeepseekV2 but then applies it anyway. This is contradictory behavior.

Either remove the warning or skip the GLU activation patching:

            if cfg.liger_glu_activation:
                LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
            if cfg.liger_rms_norm:
                modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
-            if cfg.liger_glu_activation:
-                modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
src/axolotl/integrations/liger/models/base.py (3)

98-106: Ensure loss_function is defined or add clear error handling

In the else branch you call self.loss_function(...) if labels is set, but no such method or attribute exists in this class (or its base classes)—this will raise an unexpected AttributeError at runtime. Please either provide a default implementation for loss_function or guard against its absence.

Example patch:

     else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
+            if not hasattr(self, 'loss_function'):
+                raise AttributeError(
+                    "Model must define `loss_function` to compute loss with labels"
+                )
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
                 **kwargs,
             )

148-155: Fix incorrect variable reference in non-FSDP path.

The function uses self.lm_head instead of the lm_head parameter that was passed in.

    # FSDP is not used so we can read the lm_head weights and call the kernel directly
    return _liger_for_causal_lm_loss(
-        lm_head=self.lm_head,
+        lm_head=lm_head,
        hidden_states=hidden_states,
        hidden_size=hidden_size,
        labels=labels,
        shift_labels=shift_labels,
        **loss_kwargs,
    )

171-189: Add validation for model_type parameter and improve error handling.

The dynamic import approach is fragile and should include input validation. Also consider caching successful imports for performance.

+_PATCHED_MODELS = set()
+
 def patch_lce_forward(
     model_type,
 ):
+    if not model_type or not isinstance(model_type, str):
+        raise ValueError(f"model_type must be a non-empty string, got: {model_type}")
+    
+    # Avoid double-patching
+    if model_type in _PATCHED_MODELS:
+        return
+    
     try:
         # Dynamically import the module and MLP class
         module_path = f"transformers.models.{model_type}.modeling_{model_type}"
         model_cls_prefix = "".join(
             [part.capitalize() for part in model_type.split("_")]
         )
         module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
         model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
 
         model_cls.forward = lce_forward
+        _PATCHED_MODELS.add(model_type)
     # pylint: disable=duplicate-code
     except (ImportError, AttributeError) as e:
         raise RuntimeError(
             f"Could not import ForCausalLM class for model_type: {model_type}. "
             f"Error: {str(e)}"
         ) from e
🧹 Nitpick comments (1)
src/axolotl/integrations/liger/plugin.py (1)

169-171: Fix typo in warning message.

There's a typo in the warning message.

                LOG.warning_once(
-                    f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}"
+                    f"Applied ONLY liger_fused_linear_cross_entropy generic patches for model type: {cfg.model_config_type}"
                )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c83866d and 4f0cf11.

📒 Files selected for processing (4)
  • src/axolotl/integrations/cut_cross_entropy/__init__.py (3 hunks)
  • src/axolotl/integrations/liger/__init__.py (1 hunks)
  • src/axolotl/integrations/liger/models/base.py (1 hunks)
  • src/axolotl/integrations/liger/plugin.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/integrations/cut_cross_entropy/init.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (1)
src/axolotl/integrations/liger/__init__.py (1)

24-27: Great implementation of the suggested export pattern!

The addition of __all__ properly declares the public API and allows removal of the linter disables as suggested in the past review.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

♻️ Duplicate comments (5)
src/axolotl/integrations/cut_cross_entropy/__init__.py (1)

109-111: Address the unused parameter issue.

The maybe_model parameter is not used in the function body and should be removed as noted in previous reviews.

src/axolotl/integrations/liger/plugin.py (1)

110-111: Address the conflicting GLU activation logic for DeepseekV2.

Based on past review comments, there's conflicting behavior where the code warns that liger_glu_activation is not supported for DeepseekV2 but then applies it anyway on line 111.

src/axolotl/integrations/liger/models/base.py (3)

89-108: Missing loss_function attribute handling remains unresolved.

The code calls self.loss_function when labels are provided, but this method/attribute is not guaranteed to exist in the target models, which will cause an AttributeError at runtime.


150-157: Incorrect variable reference remains unresolved.

The function still uses self.lm_head instead of the lm_head parameter that was passed in and potentially unwrapped by PEFT handling.


173-189: Missing validation and caching mechanism remains unresolved.

The dynamic import approach is still fragile and lacks input validation for model_type parameter. The suggestion to add caching to prevent double-patching has not been implemented.

🧹 Nitpick comments (1)
src/axolotl/integrations/liger/plugin.py (1)

170-170: Fix typo in log message.

There's a typo in the log message: "genericpatches" should be "generic patches".

-                    f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}"
+                    f"Applied ONLY liger_fused_linear_cross_entropy generic patches for model type: {cfg.model_config_type}"
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4f0cf11 and 2ce0cf3.

📒 Files selected for processing (11)
  • setup.py (1 hunks)
  • src/axolotl/integrations/cut_cross_entropy/__init__.py (3 hunks)
  • src/axolotl/integrations/kd/kernels/models.py (2 hunks)
  • src/axolotl/integrations/liger/__init__.py (1 hunks)
  • src/axolotl/integrations/liger/models/base.py (1 hunks)
  • src/axolotl/integrations/liger/plugin.py (1 hunks)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/monkeypatch/lora_kernels.py (2 hunks)
  • src/axolotl/monkeypatch/tiled_mlp.py (3 hunks)
  • src/axolotl/utils/callbacks/models.py (1 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/loaders/patch_manager.py
  • src/axolotl/integrations/liger/init.py
  • src/axolotl/monkeypatch/tiled_mlp.py
🧰 Additional context used
🧬 Code Graph Analysis (3)
src/axolotl/integrations/kd/kernels/models.py (1)
src/axolotl/utils/callbacks/models.py (1)
  • get_causal_lm_model_cls_prefix (8-23)
src/axolotl/monkeypatch/lora_kernels.py (1)
src/axolotl/utils/callbacks/models.py (1)
  • get_causal_lm_model_cls_prefix (8-23)
src/axolotl/integrations/liger/models/base.py (2)
src/axolotl/utils/callbacks/models.py (1)
  • get_causal_lm_model_cls_prefix (8-23)
src/axolotl/integrations/kd/topk_logprob/forward_kl.py (1)
  • loss (23-96)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (15)
setup.py (1)

117-117: Flash-attn version constraint is compatible and secure

The minimum requirement flash-attn>=2.7.4.post1 aligns with the latest PyPI release (2.8.1) and there are no published security advisories for this package. The existing code and tests already import and exercise flash-attn APIs without issue.

No further changes needed here—this relaxation is safe.

src/axolotl/monkeypatch/lora_kernels.py (2)

21-21: LGTM: Consistent import addition.

The import of get_causal_lm_model_cls_prefix aligns with the centralization effort across the codebase.


157-157: LGTM: Improved consistency with centralized helper.

Replacing the manual prefix construction with get_causal_lm_model_cls_prefix improves code consistency and maintainability across the codebase.

src/axolotl/integrations/kd/kernels/models.py (2)

25-26: LGTM: Consistent centralization of model prefix logic.

The import and usage of get_causal_lm_model_cls_prefix improves consistency across the codebase.


102-102: LGTM: Refactored to use centralized helper.

The change from manual prefix construction to using the centralized helper function improves maintainability.

src/axolotl/integrations/cut_cross_entropy/__init__.py (5)

22-22: LGTM: Added import for partial function.

The partial import is properly added to support the new patching functionality.


30-30: LGTM: Consistent use of centralized helper.

The import of get_causal_lm_model_cls_prefix maintains consistency with the codebase refactoring.


89-89: LGTM: Good integration of generic patching.

The call to patch_llama_like before the standard cce_patch provides a good fallback mechanism for unsupported model types.


115-134: LGTM: Robust dynamic import with proper error handling.

The dynamic import logic correctly uses the centralized helper function and includes proper error handling with descriptive error messages. The try-catch block appropriately handles ImportError and AttributeError cases.


136-143: LGTM: Well-structured conditional patching with appropriate warnings.

The logic correctly checks if the model type already has a patch function and only registers the generic patch when needed. The warning messages appropriately inform users about the experimental nature of generic support.

src/axolotl/integrations/liger/models/base.py (5)

1-16: LGTM! Well-organized imports and good use of utility functions.

The imports are appropriate for the functionality and the use of get_causal_lm_model_cls_prefix aligns well with the centralized approach for dynamic model class handling.


18-88: LGTM! Well-structured forward method with good parameter handling.

The function properly handles output flags, conditional logits computation, and the skip_logits logic. The parameter validation and default value assignment are implemented correctly.


109-119: LGTM! Proper return value handling.

The function correctly handles both tuple and dictionary return formats based on the return_dict parameter.


122-149: LGTM! Proper PEFT and FSDP integration handling.

The function correctly unwraps PEFT ModulesToSaveWrapper and handles FSDP forward redirection with appropriate parameter passing.


160-170: LGTM! Clean wrapper function.

The function provides a clear interface to LigerForCausalLMLoss with appropriate parameter mapping.

Comment thread src/axolotl/utils/callbacks/models.py
Comment thread src/axolotl/integrations/liger/plugin.py
@winglian winglian merged commit 2c408b5 into main Jul 16, 2025
9 checks passed
@winglian winglian deleted the liger-generic branch July 16, 2025 02:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants