Apply generic fused liger ce, cce, and tiledmlp for arbitrary models#2908
Conversation
|
Caution Review failedThe pull request is closed. WalkthroughThe changes restructure and modularize the integration of LIGER and related kernel patching mechanisms in the codebase. The LigerPlugin implementation is moved to a new file, generic patching utilities are introduced for causal language models, and several components now use a centralized helper for model class prefix resolution. Additional configuration options and patching logic are also introduced. Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant Axolotl
participant LigerPlugin
participant Model
participant PatchUtils
User->>Axolotl: Provide config and model type
Axolotl->>LigerPlugin: pre_model_load(cfg)
LigerPlugin->>PatchUtils: Determine patch strategy (generic/specific)
PatchUtils->>Model: Apply kernel patches/monkey patches
Model-->>LigerPlugin: Model patched
LigerPlugin-->>Axolotl: Patching complete
Possibly related PRs
Suggested labels
Suggested reviewers
Poem
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (2)
✨ Finishing Touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (1)
src/axolotl/integrations/liger/models/base.py (1)
173-190: Consider adding verification after patching.The dynamic patching approach relies on consistent naming conventions. While error handling is present, consider verifying the patch was applied correctly.
Add verification after patching:
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") model_cls.forward = lce_forward + # Verify the patch was applied + if model_cls.forward != lce_forward: + raise RuntimeError(f"Failed to patch forward method for {model_type}") except (ImportError, AttributeError) as e:Also, consider documenting which model naming patterns are supported to help users understand compatibility.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
src/axolotl/integrations/liger/__init__.py(1 hunks)src/axolotl/integrations/liger/models/base.py(1 hunks)src/axolotl/integrations/liger/plugin.py(1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/integrations/liger/__init__.py (1)
src/axolotl/integrations/liger/plugin.py (1)
LigerPlugin(17-181)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: pre-commit
🔇 Additional comments (2)
src/axolotl/integrations/liger/__init__.py (1)
22-22: LGTM! Good refactoring for better code organization.Moving the
LigerPluginimplementation to a dedicatedplugin.pyfile improves modularity and maintains a clean separation between the module interface and implementation.src/axolotl/integrations/liger/plugin.py (1)
168-177: Good implementation of generic FLCE patching for unknown models.This implementation aligns well with the PR objectives by providing a fallback mechanism for models without specific Liger patches. The error handling and logging ensure graceful degradation when patching fails.
Codecov ReportAttention: Patch coverage is 📢 Thoughts on this report? Let us know! |
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (4)
src/axolotl/integrations/liger/plugin.py (4)
19-19: Fix typo in class docstring.There's a typo in "integraton" which should be "integration".
- Plugin for LIGER integraton with Axolotl. + Plugin for LIGER integration with Axolotl.
96-100: Inefficient model instantiation for module extraction.Creating an empty model instance just to extract the modeling module is inefficient and potentially error-prone. Consider using dynamic import instead.
- with init_empty_weights(): - model = AutoModelForCausalLM.from_pretrained( - cfg.base_model, trust_remote_code=cfg.trust_remote_code or False - ) - modeling_mod = sys.modules[model.__class__.__module__] + # Dynamically import the modeling module based on model name + import importlib + model_name = cfg.base_model.split('/')[-1].lower() + if 'deepseek' in model_name: + modeling_mod = importlib.import_module('transformers.models.deepseek_v2.modeling_deepseek_v2') + else: + # Fallback to current approach if pattern matching fails + with init_empty_weights(): + model = AutoModelForCausalLM.from_pretrained( + cfg.base_model, trust_remote_code=cfg.trust_remote_code or False + ) + modeling_mod = sys.modules[model.__class__.__module__]
25-179: Consider breaking down the large pre_model_load method.The method is over 150 lines long, handling multiple concerns. Consider extracting model-specific handling into separate methods for better maintainability.
def pre_model_load(self, cfg): + self._setup_torch_compile_patches(cfg) + self._validate_config(cfg) + + if not self._apply_generic_patches(cfg): + self._apply_model_specific_patches(cfg) + + def _setup_torch_compile_patches(self, cfg): if cfg.torch_compile: # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled import liger_kernel.ops.fused_linear_cross_entropy patch_with_compile_disable( liger_kernel.ops.fused_linear_cross_entropy, "fused_linear_cross_entropy_forward", ) patch_with_compile_disable( liger_kernel.ops.fused_linear_cross_entropy, "fused_linear_cross_entropy_backward", ) + + def _validate_config(self, cfg): from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.swiglu import LigerSwiGLUMLP if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy: raise ValueError( "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set." )
38-44: Consider moving imports to module level for better performance.Importing inside the method means these imports are executed every time the method is called. Since these are core LIGER components, consider moving them to the module level.
+from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.transformers.functional import liger_cross_entropy +from liger_kernel.transformers.layer_norm import LigerLayerNorm +from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN +from liger_kernel.transformers.rms_norm import LigerRMSNorm +from liger_kernel.transformers.rope import liger_rotary_pos_emb +from liger_kernel.transformers.swiglu import LigerSwiGLUMLP class LigerPlugin(BasePlugin): """ Plugin for LIGER integration with Axolotl. """ def get_input_args(self): return "axolotl.integrations.liger.LigerArgs" def pre_model_load(self, cfg): # ... rest of the method without the imports
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/axolotl/integrations/liger/plugin.py(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: pre-commit
- GitHub Check: pre-commit
🔇 Additional comments (2)
src/axolotl/integrations/liger/plugin.py (2)
166-175: Excellent implementation of generic fallback patch.This implementation aligns perfectly with the PR objectives to "apply generic fused liger ce for unknown models." The error handling and logging provide good user feedback when the generic patch succeeds or fails.
51-72: Well-designed dynamic parameter inspection.The use of
inspect.signature()to dynamically determine which parameters the LIGER functions support is elegant and maintains compatibility across different LIGER versions. The conditional parameter passing based on signature inspection is a robust approach.
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (4)
src/axolotl/monkeypatch/tiled_mlp.py (1)
48-53: Consider moving parameter caching to initialization time.While the lazy initialization works functionally, it adds overhead to the first forward pass and could have thread safety issues in multi-threaded environments.
Consider initializing
_compute_paramsimmediately after patching the forward method:mlp_cls.forward = tiled_mlp_forward - mlp_cls._compute_params = [] # pylint: disable=protected-access + # Initialize compute_params for all instances + original_init = mlp_cls.__init__ + def patched_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + self._compute_params = [p for p in self.parameters() if p.requires_grad] + mlp_cls.__init__ = patched_initThis would eliminate the conditional check in the forward pass and ensure thread safety.
src/axolotl/integrations/cut_cross_entropy/__init__.py (3)
125-127: Consider avoiding protected member access.Directly accessing
_PATCH_OPTSviolates encapsulation. Consider using a public API if available or documenting why this access is necessary.
130-134: Improve error handling specificity.The error handling catches both
ImportErrorandAttributeErrorbut doesn't differentiate between them. Consider providing more specific error messages for each case.- except (ImportError, AttributeError) as e: - raise RuntimeError( - f"Could not import ForCausalLM class for model_type: {model_type}. " - f"Error: {str(e)}" - ) from e + except ImportError as e: + raise RuntimeError( + f"Module '{module_path}' not found for model_type: {model_type}. " + f"Error: {str(e)}" + ) from e + except AttributeError as e: + raise RuntimeError( + f"Class '{model_cls_prefix}ForCausalLM' not found in module for model_type: {model_type}. " + f"Error: {str(e)}" + ) from e
136-138: Document side effects of modifying global state.The method modifies the global
PATCH_FNSdictionary, which could have unintended side effects if called multiple times or from different contexts. Consider documenting this behavior or implementing safeguards.Consider checking if the patch already exists before overwriting:
if model_type not in PATCH_FNS: LOG.warning("Setting up generic cce patch for model type: %s", model_type) PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type) + else: + LOG.debug("Patch already exists for model type: %s", model_type)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
src/axolotl/integrations/cut_cross_entropy/__init__.py(3 hunks)src/axolotl/loaders/patch_manager.py(1 hunks)src/axolotl/monkeypatch/tiled_mlp.py(2 hunks)src/axolotl/utils/schemas/config.py(1 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/axolotl/utils/schemas/config.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: preview
- GitHub Check: pre-commit
🔇 Additional comments (4)
src/axolotl/loaders/patch_manager.py (1)
281-285: LGTM! Configuration parameter correctly passed through.The addition of the
use_original_mlpparameter enables flexible MLP implementation selection based on configuration, which aligns well with the PR's goal of supporting generic patches for unknown models.src/axolotl/monkeypatch/tiled_mlp.py (1)
65-65: LGTM! Proper attribute initialization.Initializing
_compute_paramsas a class attribute ensures all instances have this attribute defined, preventing AttributeError on first access.src/axolotl/integrations/cut_cross_entropy/__init__.py (2)
22-22: LGTM!Import is correctly added to support the
partialfunction usage.
84-98: LGTM!The generic patch registration is correctly placed before the standard patching process.
|
📖 Documentation Preview: https://687711e4d78eac4bfb8dbe6e--resonant-treacle-0fd729.netlify.app Deployed on Netlify from commit 00571f9 |
There was a problem hiding this comment.
Actionable comments posted: 2
♻️ Duplicate comments (1)
src/axolotl/integrations/liger/models/base.py (1)
150-157: Fix incorrect variable reference in non-FSDP path.The function uses
self.lm_headinstead of thelm_headparameter that was passed in and potentially unwrapped.# FSDP is not used so we can read the lm_head weights and call the kernel directly return _liger_for_causal_lm_loss( - lm_head=self.lm_head, + lm_head=lm_head, hidden_states=hidden_states, hidden_size=hidden_size, labels=labels, shift_labels=shift_labels, **loss_kwargs, )
🧹 Nitpick comments (1)
src/axolotl/integrations/liger/models/base.py (1)
44-58: Consider extracting configuration handling to reduce duplication.The pattern of checking if parameter is None and falling back to config values is repeated multiple times. Consider extracting this to a helper function.
+def _resolve_config_param(param_value, config_value): + return param_value if param_value is not None else config_value + # pylint: disable=duplicate-code -output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions -) -output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states -) - -return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict -) +output_attentions = _resolve_config_param(output_attentions, self.config.output_attentions) +output_hidden_states = _resolve_config_param(output_hidden_states, self.config.output_hidden_states) +return_dict = _resolve_config_param(return_dict, self.config.use_return_dict)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
src/axolotl/integrations/cut_cross_entropy/__init__.py(3 hunks)src/axolotl/integrations/liger/__init__.py(1 hunks)src/axolotl/integrations/liger/models/base.py(1 hunks)src/axolotl/integrations/liger/plugin.py(1 hunks)src/axolotl/loaders/patch_manager.py(1 hunks)src/axolotl/monkeypatch/tiled_mlp.py(2 hunks)src/axolotl/utils/schemas/config.py(1 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/axolotl/integrations/liger/init.py
🚧 Files skipped from review as they are similar to previous changes (5)
- src/axolotl/utils/schemas/config.py
- src/axolotl/loaders/patch_manager.py
- src/axolotl/monkeypatch/tiled_mlp.py
- src/axolotl/integrations/cut_cross_entropy/init.py
- src/axolotl/integrations/liger/plugin.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: pre-commit
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: preview
- GitHub Check: pre-commit
🔇 Additional comments (4)
src/axolotl/integrations/liger/models/base.py (4)
1-14: LGTM! Well-structured imports and documentation.The imports are properly organized and the module docstring clearly describes the purpose.
16-43: LGTM! Well-documented function signature.The function signature and docstring provide clear documentation of the parameters, especially the detailed explanation of
logits_to_keepparameter.
68-87: LGTM! Solid logic for handling logits computation.The slicing logic for
logits_to_keepcorrectly handles both int and tensor types, and the defaultskip_logitsbehavior is well-designed.
160-170: LGTM! Clean wrapper function.The function properly delegates to
LigerForCausalLMLosswith correct parameter mapping.
| def patch_lce_forward( | ||
| model_type, | ||
| ): | ||
| try: | ||
| # Dynamically import the module and MLP class | ||
| module_path = f"transformers.models.{model_type}.modeling_{model_type}" | ||
| model_cls_prefix = "".join( | ||
| [part.capitalize() for part in model_type.split("_")] | ||
| ) | ||
| module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) | ||
| model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") | ||
|
|
||
| model_cls.forward = lce_forward | ||
| # pylint: disable=duplicate-code | ||
| except (ImportError, AttributeError) as e: | ||
| raise RuntimeError( | ||
| f"Could not import ForCausalLM class for model_type: {model_type}. " | ||
| f"Error: {str(e)}" | ||
| ) from e |
There was a problem hiding this comment.
🛠️ Refactor suggestion
Add validation for model_type parameter and improve error handling.
The dynamic import approach is fragile and should include input validation. Also consider caching successful imports for performance.
+_PATCHED_MODELS = set()
+
def patch_lce_forward(
model_type,
):
+ if not model_type or not isinstance(model_type, str):
+ raise ValueError(f"model_type must be a non-empty string, got: {model_type}")
+
+ # Avoid double-patching
+ if model_type in _PATCHED_MODELS:
+ return
+
try:
# Dynamically import the module and MLP class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = lce_forward
+ _PATCHED_MODELS.add(model_type)
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def patch_lce_forward( | |
| model_type, | |
| ): | |
| try: | |
| # Dynamically import the module and MLP class | |
| module_path = f"transformers.models.{model_type}.modeling_{model_type}" | |
| model_cls_prefix = "".join( | |
| [part.capitalize() for part in model_type.split("_")] | |
| ) | |
| module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) | |
| model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") | |
| model_cls.forward = lce_forward | |
| # pylint: disable=duplicate-code | |
| except (ImportError, AttributeError) as e: | |
| raise RuntimeError( | |
| f"Could not import ForCausalLM class for model_type: {model_type}. " | |
| f"Error: {str(e)}" | |
| ) from e | |
| # Cache of model_types already patched to avoid re-applying | |
| _PATCHED_MODELS = set() | |
| def patch_lce_forward( | |
| model_type, | |
| ): | |
| # Validate input | |
| if not model_type or not isinstance(model_type, str): | |
| raise ValueError(f"model_type must be a non-empty string, got: {model_type}") | |
| # Avoid double-patching | |
| if model_type in _PATCHED_MODELS: | |
| return | |
| try: | |
| # Dynamically import the module and MLP class | |
| module_path = f"transformers.models.{model_type}.modeling_{model_type}" | |
| model_cls_prefix = "".join( | |
| [part.capitalize() for part in model_type.split("_")] | |
| ) | |
| module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) | |
| model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") | |
| model_cls.forward = lce_forward | |
| _PATCHED_MODELS.add(model_type) | |
| # pylint: disable=duplicate-code | |
| except (ImportError, AttributeError) as e: | |
| raise RuntimeError( | |
| f"Could not import ForCausalLM class for model_type: {model_type}. " | |
| f"Error: {str(e)}" | |
| ) from e |
🤖 Prompt for AI Agents
In src/axolotl/integrations/liger/models/base.py around lines 173 to 191, add
validation to check that the model_type parameter is a non-empty string and
matches expected patterns before attempting the dynamic import to prevent import
errors. Enhance error handling by providing clearer messages and consider
implementing a caching mechanism (e.g., a dictionary) to store and reuse
successfully imported model classes to improve performance and avoid repeated
imports.
| if skip_logits: | ||
| loss = lce_maybe_trainable_lm_head( | ||
| self, | ||
| hidden_states=kept_hidden_states, | ||
| hidden_size=self.config.hidden_size, | ||
| labels=labels, | ||
| shift_labels=shift_labels, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| else: | ||
| logits = self.lm_head(kept_hidden_states) | ||
| if labels is not None: | ||
| loss = self.loss_function( | ||
| logits=logits, | ||
| labels=labels, | ||
| vocab_size=self.config.vocab_size, | ||
| **kwargs, | ||
| ) | ||
|
|
There was a problem hiding this comment.
💡 Verification agent
🧩 Analysis chain
Ensure loss_function is available when labels are provided.
The code assumes self.loss_function exists when computing loss in the else branch. Verify that all target models have this attribute or add error handling.
🏁 Script executed:
#!/bin/bash
# Verify if loss_function is a standard attribute in transformers causal LM models
ast-grep --pattern 'class $_ {
$$$
loss_function = $$$
$$$
}' | head -20
# Also check for loss_function method definitions
ast-grep --pattern 'def loss_function($$$) {
$$$
}' | head -20Length of output: 147
🏁 Script executed:
#!/bin/bash
set -e
echo "Searching for loss_function definitions, assignments, and references across the repo..."
# Definitions of loss_function methods
rg -n "def loss_function" .
# Attribute assignments to loss_function
rg -n "loss_function\s*=" .
# References to self.loss_function
rg -n "self\.loss_function" .Length of output: 223
Ensure loss_function is defined or add clear error handling
In the else branch you call self.loss_function(...) if labels is set, but no such method or attribute exists in this class (or its base classes)—this will raise an unexpected AttributeError at runtime. Please either provide a default implementation for loss_function or guard against its absence.
• File: src/axolotl/integrations/liger/models/base.py
• Lines: ~89–108
Example patch:
else:
logits = self.lm_head(kept_hidden_states)
if labels is not None:
+ if not hasattr(self, 'loss_function'):
+ raise AttributeError(
+ "Model must define `loss_function` to compute loss with labels"
+ )
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if skip_logits: | |
| loss = lce_maybe_trainable_lm_head( | |
| self, | |
| hidden_states=kept_hidden_states, | |
| hidden_size=self.config.hidden_size, | |
| labels=labels, | |
| shift_labels=shift_labels, | |
| **kwargs, | |
| ) | |
| else: | |
| logits = self.lm_head(kept_hidden_states) | |
| if labels is not None: | |
| loss = self.loss_function( | |
| logits=logits, | |
| labels=labels, | |
| vocab_size=self.config.vocab_size, | |
| **kwargs, | |
| ) | |
| if skip_logits: | |
| loss = lce_maybe_trainable_lm_head( | |
| self, | |
| hidden_states=kept_hidden_states, | |
| hidden_size=self.config.hidden_size, | |
| labels=labels, | |
| shift_labels=shift_labels, | |
| **kwargs, | |
| ) | |
| else: | |
| logits = self.lm_head(kept_hidden_states) | |
| if labels is not None: | |
| if not hasattr(self, 'loss_function'): | |
| raise AttributeError( | |
| "Model must define `loss_function` to compute loss with labels" | |
| ) | |
| loss = self.loss_function( | |
| logits=logits, | |
| labels=labels, | |
| vocab_size=self.config.vocab_size, | |
| **kwargs, | |
| ) |
🤖 Prompt for AI Agents
In src/axolotl/integrations/liger/models/base.py around lines 89 to 108, the
code calls self.loss_function when labels are provided, but this method or
attribute is not defined in the class or its base classes, which will cause an
AttributeError at runtime. To fix this, either implement a default loss_function
method in the class or add a check before calling it to ensure it exists,
raising a clear error or handling the case when it is absent.
djsaunde
left a comment
There was a problem hiding this comment.
Cool stuff!
Some questions / comments but otherwise looks good.
| If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. | ||
| This is useful when using packed tensor format (single dimension for batch and sequence length). | ||
|
|
||
| Returns: |
There was a problem hiding this comment.
going to drop this, the annotation lists the types, and upstream in liger, they just have a code sample there 🤷
|
|
||
| def lce_forward( | ||
| self, | ||
| *args, |
There was a problem hiding this comment.
Do we need to support possibly variable-length args here?
There was a problem hiding this comment.
I was trying to seperate what's required from not, so generally all non-kwargs were required.
| if not self._compute_params: # pylint: disable=protected-access | ||
| self._compute_params = [ # pylint: disable=protected-access | ||
| p for p in self.parameters() if p.requires_grad | ||
| ] | ||
|
|
||
| compute_params = self._compute_params # pylint: disable=protected-access |
There was a problem hiding this comment.
Are these pylint disables needed? We're accessing class members.
There was a problem hiding this comment.
yeah, it was giving me some errors when I was linting
There was a problem hiding this comment.
Actionable comments posted: 0
♻️ Duplicate comments (4)
src/axolotl/integrations/liger/plugin.py (1)
108-111: Fix conflicting GLU activation logic for DeepseekV2.The code warns that
liger_glu_activationis not supported for DeepseekV2 but then applies it anyway. This is contradictory behavior.Either remove the warning or skip the GLU activation patching:
if cfg.liger_glu_activation: LOG.warning("liger_glu_activation is not supported for DeepseekV2.") if cfg.liger_rms_norm: modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm - if cfg.liger_glu_activation: - modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forwardsrc/axolotl/integrations/liger/models/base.py (3)
98-106: Ensureloss_functionis defined or add clear error handlingIn the
elsebranch you callself.loss_function(...)iflabelsis set, but no such method or attribute exists in this class (or its base classes)—this will raise an unexpectedAttributeErrorat runtime. Please either provide a default implementation forloss_functionor guard against its absence.Example patch:
else: logits = self.lm_head(kept_hidden_states) if labels is not None: + if not hasattr(self, 'loss_function'): + raise AttributeError( + "Model must define `loss_function` to compute loss with labels" + ) loss = self.loss_function( logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs, )
148-155: Fix incorrect variable reference in non-FSDP path.The function uses
self.lm_headinstead of thelm_headparameter that was passed in.# FSDP is not used so we can read the lm_head weights and call the kernel directly return _liger_for_causal_lm_loss( - lm_head=self.lm_head, + lm_head=lm_head, hidden_states=hidden_states, hidden_size=hidden_size, labels=labels, shift_labels=shift_labels, **loss_kwargs, )
171-189: Add validation for model_type parameter and improve error handling.The dynamic import approach is fragile and should include input validation. Also consider caching successful imports for performance.
+_PATCHED_MODELS = set() + def patch_lce_forward( model_type, ): + if not model_type or not isinstance(model_type, str): + raise ValueError(f"model_type must be a non-empty string, got: {model_type}") + + # Avoid double-patching + if model_type in _PATCHED_MODELS: + return + try: # Dynamically import the module and MLP class module_path = f"transformers.models.{model_type}.modeling_{model_type}" model_cls_prefix = "".join( [part.capitalize() for part in model_type.split("_")] ) module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") model_cls.forward = lce_forward + _PATCHED_MODELS.add(model_type) # pylint: disable=duplicate-code except (ImportError, AttributeError) as e: raise RuntimeError( f"Could not import ForCausalLM class for model_type: {model_type}. " f"Error: {str(e)}" ) from e
🧹 Nitpick comments (1)
src/axolotl/integrations/liger/plugin.py (1)
169-171: Fix typo in warning message.There's a typo in the warning message.
LOG.warning_once( - f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}" + f"Applied ONLY liger_fused_linear_cross_entropy generic patches for model type: {cfg.model_config_type}" )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
src/axolotl/integrations/cut_cross_entropy/__init__.py(3 hunks)src/axolotl/integrations/liger/__init__.py(1 hunks)src/axolotl/integrations/liger/models/base.py(1 hunks)src/axolotl/integrations/liger/plugin.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/axolotl/integrations/cut_cross_entropy/init.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: preview
- GitHub Check: pre-commit
🔇 Additional comments (1)
src/axolotl/integrations/liger/__init__.py (1)
24-27: Great implementation of the suggested export pattern!The addition of
__all__properly declares the public API and allows removal of the linter disables as suggested in the past review.
There was a problem hiding this comment.
Actionable comments posted: 2
♻️ Duplicate comments (5)
src/axolotl/integrations/cut_cross_entropy/__init__.py (1)
109-111: Address the unused parameter issue.The
maybe_modelparameter is not used in the function body and should be removed as noted in previous reviews.src/axolotl/integrations/liger/plugin.py (1)
110-111: Address the conflicting GLU activation logic for DeepseekV2.Based on past review comments, there's conflicting behavior where the code warns that
liger_glu_activationis not supported for DeepseekV2 but then applies it anyway on line 111.src/axolotl/integrations/liger/models/base.py (3)
89-108: Missingloss_functionattribute handling remains unresolved.The code calls
self.loss_functionwhen labels are provided, but this method/attribute is not guaranteed to exist in the target models, which will cause anAttributeErrorat runtime.
150-157: Incorrect variable reference remains unresolved.The function still uses
self.lm_headinstead of thelm_headparameter that was passed in and potentially unwrapped by PEFT handling.
173-189: Missing validation and caching mechanism remains unresolved.The dynamic import approach is still fragile and lacks input validation for
model_typeparameter. The suggestion to add caching to prevent double-patching has not been implemented.
🧹 Nitpick comments (1)
src/axolotl/integrations/liger/plugin.py (1)
170-170: Fix typo in log message.There's a typo in the log message: "genericpatches" should be "generic patches".
- f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}" + f"Applied ONLY liger_fused_linear_cross_entropy generic patches for model type: {cfg.model_config_type}"
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
setup.py(1 hunks)src/axolotl/integrations/cut_cross_entropy/__init__.py(3 hunks)src/axolotl/integrations/kd/kernels/models.py(2 hunks)src/axolotl/integrations/liger/__init__.py(1 hunks)src/axolotl/integrations/liger/models/base.py(1 hunks)src/axolotl/integrations/liger/plugin.py(1 hunks)src/axolotl/loaders/patch_manager.py(1 hunks)src/axolotl/monkeypatch/lora_kernels.py(2 hunks)src/axolotl/monkeypatch/tiled_mlp.py(3 hunks)src/axolotl/utils/callbacks/models.py(1 hunks)src/axolotl/utils/schemas/config.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
- src/axolotl/utils/schemas/config.py
- src/axolotl/loaders/patch_manager.py
- src/axolotl/integrations/liger/init.py
- src/axolotl/monkeypatch/tiled_mlp.py
🧰 Additional context used
🧬 Code Graph Analysis (3)
src/axolotl/integrations/kd/kernels/models.py (1)
src/axolotl/utils/callbacks/models.py (1)
get_causal_lm_model_cls_prefix(8-23)
src/axolotl/monkeypatch/lora_kernels.py (1)
src/axolotl/utils/callbacks/models.py (1)
get_causal_lm_model_cls_prefix(8-23)
src/axolotl/integrations/liger/models/base.py (2)
src/axolotl/utils/callbacks/models.py (1)
get_causal_lm_model_cls_prefix(8-23)src/axolotl/integrations/kd/topk_logprob/forward_kl.py (1)
loss(23-96)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: pre-commit
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: preview
- GitHub Check: pre-commit
🔇 Additional comments (15)
setup.py (1)
117-117: Flash-attn version constraint is compatible and secureThe minimum requirement
flash-attn>=2.7.4.post1aligns with the latest PyPI release (2.8.1) and there are no published security advisories for this package. The existing code and tests already import and exercise flash-attn APIs without issue.No further changes needed here—this relaxation is safe.
src/axolotl/monkeypatch/lora_kernels.py (2)
21-21: LGTM: Consistent import addition.The import of
get_causal_lm_model_cls_prefixaligns with the centralization effort across the codebase.
157-157: LGTM: Improved consistency with centralized helper.Replacing the manual prefix construction with
get_causal_lm_model_cls_prefiximproves code consistency and maintainability across the codebase.src/axolotl/integrations/kd/kernels/models.py (2)
25-26: LGTM: Consistent centralization of model prefix logic.The import and usage of
get_causal_lm_model_cls_prefiximproves consistency across the codebase.
102-102: LGTM: Refactored to use centralized helper.The change from manual prefix construction to using the centralized helper function improves maintainability.
src/axolotl/integrations/cut_cross_entropy/__init__.py (5)
22-22: LGTM: Added import for partial function.The
partialimport is properly added to support the new patching functionality.
30-30: LGTM: Consistent use of centralized helper.The import of
get_causal_lm_model_cls_prefixmaintains consistency with the codebase refactoring.
89-89: LGTM: Good integration of generic patching.The call to
patch_llama_likebefore the standardcce_patchprovides a good fallback mechanism for unsupported model types.
115-134: LGTM: Robust dynamic import with proper error handling.The dynamic import logic correctly uses the centralized helper function and includes proper error handling with descriptive error messages. The try-catch block appropriately handles
ImportErrorandAttributeErrorcases.
136-143: LGTM: Well-structured conditional patching with appropriate warnings.The logic correctly checks if the model type already has a patch function and only registers the generic patch when needed. The warning messages appropriately inform users about the experimental nature of generic support.
src/axolotl/integrations/liger/models/base.py (5)
1-16: LGTM! Well-organized imports and good use of utility functions.The imports are appropriate for the functionality and the use of
get_causal_lm_model_cls_prefixaligns well with the centralized approach for dynamic model class handling.
18-88: LGTM! Well-structured forward method with good parameter handling.The function properly handles output flags, conditional logits computation, and the
skip_logitslogic. The parameter validation and default value assignment are implemented correctly.
109-119: LGTM! Proper return value handling.The function correctly handles both tuple and dictionary return formats based on the
return_dictparameter.
122-149: LGTM! Proper PEFT and FSDP integration handling.The function correctly unwraps PEFT
ModulesToSaveWrapperand handles FSDP forward redirection with appropriate parameter passing.
160-170: LGTM! Clean wrapper function.The function provides a clear interface to
LigerForCausalLMLosswith appropriate parameter mapping.
Description
While Liger doesn't have patches for all models, we know most models have model forwards very similar to llama, so let's patch models with a generic flce if possible and log.
This is useful for experimenting if it works out of the box for newer models like LFM2
Summary by CodeRabbit
New Features
Improvements
Bug Fixes