From 9cac50b657472b3a536d2a096b849a58703a9808 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 11 Jul 2025 22:54:37 -0400 Subject: [PATCH 01/10] Apply generic fused liger ce for unknown models --- src/axolotl/integrations/liger/__init__.py | 167 +-------------- src/axolotl/integrations/liger/models/base.py | 190 ++++++++++++++++++ src/axolotl/integrations/liger/plugin.py | 181 +++++++++++++++++ 3 files changed, 372 insertions(+), 166 deletions(-) create mode 100644 src/axolotl/integrations/liger/models/base.py create mode 100644 src/axolotl/integrations/liger/plugin.py diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 8de94c78be..672b3a1ff7 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -18,170 +18,5 @@ Liger Kernel is the collection of Triton-native kernels for LLM Training. It is designed to be performant, correct, and light-weight. """ -import inspect -import sys - -from axolotl.integrations.base import BasePlugin -from axolotl.utils.logging import get_logger - from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 -from .utils import patch_with_compile_disable - -LOG = get_logger(__name__) - - -class LigerPlugin(BasePlugin): - """ - Plugin for LIGER integraton with Axolotl. 
- """ - - def get_input_args(self): - return "axolotl.integrations.liger.LigerArgs" - - def pre_model_load(self, cfg): - if cfg.torch_compile: - # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled - import liger_kernel.ops.fused_linear_cross_entropy - - patch_with_compile_disable( - liger_kernel.ops.fused_linear_cross_entropy, - "fused_linear_cross_entropy_forward", - ) - patch_with_compile_disable( - liger_kernel.ops.fused_linear_cross_entropy, - "fused_linear_cross_entropy_backward", - ) - from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss - from liger_kernel.transformers.functional import liger_cross_entropy - from liger_kernel.transformers.layer_norm import LigerLayerNorm - from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN - from liger_kernel.transformers.rms_norm import LigerRMSNorm - from liger_kernel.transformers.rope import liger_rotary_pos_emb - from liger_kernel.transformers.swiglu import LigerSwiGLUMLP - - if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy: - raise ValueError( - "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set." 
- ) - - if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN: - apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type] - liger_fn_sig = inspect.signature(apply_liger_fn) - kwargs = {} - if "rope" in liger_fn_sig.parameters: - kwargs["rope"] = cfg.liger_rope - if "cross_entropy" in liger_fn_sig.parameters: - kwargs["cross_entropy"] = cfg.liger_cross_entropy - if "fused_linear_cross_entropy" in liger_fn_sig.parameters: - kwargs["fused_linear_cross_entropy"] = ( - cfg.liger_fused_linear_cross_entropy - ) - if "rms_norm" in liger_fn_sig.parameters: - kwargs["rms_norm"] = cfg.liger_rms_norm - if "layer_norm" in liger_fn_sig.parameters: - kwargs["layer_norm"] = cfg.liger_layer_norm - if "geglu" in liger_fn_sig.parameters: - kwargs["geglu"] = cfg.liger_glu_activation - elif "swiglu" in liger_fn_sig.parameters: - kwargs["swiglu"] = cfg.liger_glu_activation - LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}") - apply_liger_fn(**kwargs) - elif cfg.model_config_type == "jamba": - from transformers.models.jamba import modeling_jamba - - from .models.jamba import lce_forward as jamba_lce_forward - - if cfg.liger_rope: - modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb - if cfg.liger_rms_norm: - modeling_jamba.JambaRMSNorm = LigerRMSNorm - if cfg.liger_glu_activation: - modeling_jamba.JambaMLP = LigerSwiGLUMLP - if cfg.liger_layer_norm: - modeling_jamba.nn.LayerNorm = LigerLayerNorm - if cfg.liger_cross_entropy: - from transformers.loss.loss_utils import nn - - nn.functional.cross_entropy = liger_cross_entropy - if cfg.liger_fused_linear_cross_entropy: - modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward - elif cfg.model_config_type == "deepseek_v2": - from accelerate import init_empty_weights - from transformers import AutoModelForCausalLM - - with init_empty_weights(): - model = AutoModelForCausalLM.from_pretrained( - cfg.base_model, trust_remote_code=cfg.trust_remote_code or False - ) - modeling_mod = 
sys.modules[model.__class__.__module__] - - from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward - - if cfg.liger_rope: - # The DeepseekV2 version of RoPE is different than upstream LLaMA. - # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528 - LOG.warning("Fused liger_rope is not supported for DeepseekV2.") - if cfg.liger_glu_activation: - LOG.warning("liger_glu_activation is not supported for DeepseekV2.") - if cfg.liger_rms_norm: - modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm - if cfg.liger_glu_activation: - modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward - if cfg.liger_layer_norm: - modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward - if cfg.liger_cross_entropy: - # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses - # nn.CrossEntropyLoss in the forward method. - modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss - if cfg.liger_fused_linear_cross_entropy: - modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward - elif cfg.model_config_type == "llama4": - from axolotl.integrations.liger.models.llama4 import ( - apply_liger_kernel_to_llama4, - ) - - apply_liger_kernel_to_llama4( - cross_entropy=cfg.liger_cross_entropy, - fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, - glu_activation=cfg.liger_glu_activation, - rms_norm=cfg.liger_rms_norm, - layer_norm=cfg.liger_layer_norm, - ) - elif cfg.model_config_type == "qwen3": - from axolotl.integrations.liger.models.qwen3 import ( - apply_liger_kernel_to_qwen3, - ) - - apply_liger_kernel_to_qwen3( - cross_entropy=cfg.liger_cross_entropy, - fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, - glu_activation=cfg.liger_glu_activation, - rms_norm=cfg.liger_rms_norm, - layer_norm=cfg.liger_layer_norm, - ) - elif cfg.model_config_type == "qwen3_moe": - from axolotl.integrations.liger.models.qwen3_moe import ( - apply_liger_kernel_to_qwen3_moe, - ) - - apply_liger_kernel_to_qwen3_moe( - 
cross_entropy=cfg.liger_cross_entropy, - fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, - glu_activation=cfg.liger_glu_activation, - rms_norm=cfg.liger_rms_norm, - layer_norm=cfg.liger_layer_norm, - ) - elif cfg.model_config_type == "granitemoe": - from liger_kernel.transformers import apply_liger_kernel_to_granite - - apply_liger_kernel_to_granite( - rope=cfg.liger_rope, - cross_entropy=cfg.liger_cross_entropy, - fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, - rms_norm=cfg.liger_rms_norm, - swiglu=cfg.liger_glu_activation, - ) - else: - LOG.warning( - f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." - ) +from .plugin import LigerPlugin # pylint: disable=unused-import. # noqa: F401 diff --git a/src/axolotl/integrations/liger/models/base.py b/src/axolotl/integrations/liger/models/base.py new file mode 100644 index 0000000000..66fe34182f --- /dev/null +++ b/src/axolotl/integrations/liger/models/base.py @@ -0,0 +1,190 @@ +""" +Generic FLCE patch for untested models similar to Llama +""" + +from typing import Optional, Tuple, Union + +import torch +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.trainer.orpo_trainer import _FSDPForwardRedirection +from liger_kernel.utils import PEFT_AVAILABLE +from peft.utils import ModulesToSaveWrapper +from torch.distributed.fsdp import FullyShardedDataParallel +from transformers.modeling_outputs import CausalLMOutputWithPast + + +def lce_forward( + self, + *args, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked 
language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + """ + + # pylint: disable=duplicate-code + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + *args, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + + # if in training mode, don't materialize logits + if skip_logits and labels is None and shift_labels is None: + raise 
ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + loss = lce_maybe_trainable_lm_head( + self, + hidden_states=kept_hidden_states, + hidden_size=self.config.hidden_size, + labels=labels, + shift_labels=shift_labels, + **kwargs, + ) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def lce_maybe_trainable_lm_head( + self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs +): + lm_head = self.lm_head + + # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration, + # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read + # from the unwrapped module. + # See https://huggingface.co/docs/peft/package_reference/lora for reference. + if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper): + lm_head = lm_head.modules_to_save.default + + # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA, + # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass + # so the module entire parameters are summoned and kept in memory during the kernel execution. 
+ if isinstance(lm_head, FullyShardedDataParallel): + return _FSDPForwardRedirection()( + lm_head, + _liger_for_causal_lm_loss, + lm_head.module, + hidden_states, + hidden_size, + labels, + shift_labels, + **loss_kwargs, + ) + + # FSDP is not used so we can read the lm_head weights and call the kernel directly + return _liger_for_causal_lm_loss( + lm_head=self.lm_head, + hidden_states=hidden_states, + hidden_size=hidden_size, + labels=labels, + shift_labels=shift_labels, + **loss_kwargs, + ) + + +def _liger_for_causal_lm_loss( + lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs +): + return LigerForCausalLMLoss( + hidden_states=hidden_states, + lm_head_weight=lm_head.weight, + labels=labels, + hidden_size=hidden_size, + shift_labels=shift_labels, + **loss_kwargs, + ) + + +def patch_lce_forward( + model_type, +): + try: + # Dynamically import the module and MLP class + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + model_cls_prefix = "".join( + [part.capitalize() for part in model_type.split("_")] + ) + module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) + model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") + + model_cls.forward = lce_forward + except (ImportError, AttributeError) as e: + raise RuntimeError( + f"Could not import ForCausalLM class for model_type: {model_type}. 
" + f"Error: {str(e)}" + ) from e diff --git a/src/axolotl/integrations/liger/plugin.py b/src/axolotl/integrations/liger/plugin.py new file mode 100644 index 0000000000..11a0b5d553 --- /dev/null +++ b/src/axolotl/integrations/liger/plugin.py @@ -0,0 +1,181 @@ +""" +Liger-Kernel Plugin for Axolotl +""" + +import inspect +import sys + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.logging import get_logger + +from .models.base import patch_lce_forward +from .utils import patch_with_compile_disable + +LOG = get_logger(__name__) + + +class LigerPlugin(BasePlugin): + """ + Plugin for LIGER integraton with Axolotl. + """ + + def get_input_args(self): + return "axolotl.integrations.liger.LigerArgs" + + def pre_model_load(self, cfg): + if cfg.torch_compile: + # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled + import liger_kernel.ops.fused_linear_cross_entropy + + patch_with_compile_disable( + liger_kernel.ops.fused_linear_cross_entropy, + "fused_linear_cross_entropy_forward", + ) + patch_with_compile_disable( + liger_kernel.ops.fused_linear_cross_entropy, + "fused_linear_cross_entropy_backward", + ) + from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss + from liger_kernel.transformers.functional import liger_cross_entropy + from liger_kernel.transformers.layer_norm import LigerLayerNorm + from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN + from liger_kernel.transformers.rms_norm import LigerRMSNorm + from liger_kernel.transformers.rope import liger_rotary_pos_emb + from liger_kernel.transformers.swiglu import LigerSwiGLUMLP + + if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy: + raise ValueError( + "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set." 
+ ) + + if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN: + apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type] + liger_fn_sig = inspect.signature(apply_liger_fn) + kwargs = {} + if "rope" in liger_fn_sig.parameters: + kwargs["rope"] = cfg.liger_rope + if "cross_entropy" in liger_fn_sig.parameters: + kwargs["cross_entropy"] = cfg.liger_cross_entropy + if "fused_linear_cross_entropy" in liger_fn_sig.parameters: + kwargs["fused_linear_cross_entropy"] = ( + cfg.liger_fused_linear_cross_entropy + ) + if "rms_norm" in liger_fn_sig.parameters: + kwargs["rms_norm"] = cfg.liger_rms_norm + if "layer_norm" in liger_fn_sig.parameters: + kwargs["layer_norm"] = cfg.liger_layer_norm + if "geglu" in liger_fn_sig.parameters: + kwargs["geglu"] = cfg.liger_glu_activation + elif "swiglu" in liger_fn_sig.parameters: + kwargs["swiglu"] = cfg.liger_glu_activation + LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}") + apply_liger_fn(**kwargs) + elif cfg.model_config_type == "jamba": + from transformers.models.jamba import modeling_jamba + + from .models.jamba import lce_forward as jamba_lce_forward + + if cfg.liger_rope: + modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_jamba.JambaRMSNorm = LigerRMSNorm + if cfg.liger_glu_activation: + modeling_jamba.JambaMLP = LigerSwiGLUMLP + if cfg.liger_layer_norm: + modeling_jamba.nn.LayerNorm = LigerLayerNorm + if cfg.liger_cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + if cfg.liger_fused_linear_cross_entropy: + modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward + elif cfg.model_config_type == "deepseek_v2": + from accelerate import init_empty_weights + from transformers import AutoModelForCausalLM + + with init_empty_weights(): + model = AutoModelForCausalLM.from_pretrained( + cfg.base_model, trust_remote_code=cfg.trust_remote_code or False + ) + modeling_mod = 
sys.modules[model.__class__.__module__] + + from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward + + if cfg.liger_rope: + # The DeepseekV2 version of RoPE is different than upstream LLaMA. + # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528 + LOG.warning("Fused liger_rope is not supported for DeepseekV2.") + if cfg.liger_glu_activation: + LOG.warning("liger_glu_activation is not supported for DeepseekV2.") + if cfg.liger_rms_norm: + modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm + if cfg.liger_glu_activation: + modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward + if cfg.liger_layer_norm: + modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward + if cfg.liger_cross_entropy: + # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses + # nn.CrossEntropyLoss in the forward method. + modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward + elif cfg.model_config_type == "llama4": + from axolotl.integrations.liger.models.llama4 import ( + apply_liger_kernel_to_llama4, + ) + + apply_liger_kernel_to_llama4( + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + glu_activation=cfg.liger_glu_activation, + rms_norm=cfg.liger_rms_norm, + layer_norm=cfg.liger_layer_norm, + ) + elif cfg.model_config_type == "qwen3": + from axolotl.integrations.liger.models.qwen3 import ( + apply_liger_kernel_to_qwen3, + ) + + apply_liger_kernel_to_qwen3( + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + glu_activation=cfg.liger_glu_activation, + rms_norm=cfg.liger_rms_norm, + layer_norm=cfg.liger_layer_norm, + ) + elif cfg.model_config_type == "qwen3_moe": + from axolotl.integrations.liger.models.qwen3_moe import ( + apply_liger_kernel_to_qwen3_moe, + ) + + apply_liger_kernel_to_qwen3_moe( + 
cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + glu_activation=cfg.liger_glu_activation, + rms_norm=cfg.liger_rms_norm, + layer_norm=cfg.liger_layer_norm, + ) + elif cfg.model_config_type == "granitemoe": + from liger_kernel.transformers import apply_liger_kernel_to_granite + + apply_liger_kernel_to_granite( + rope=cfg.liger_rope, + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + rms_norm=cfg.liger_rms_norm, + swiglu=cfg.liger_glu_activation, + ) + elif cfg.liger_fused_linear_cross_entropy: + try: + patch_lce_forward(cfg.model_config_type) + LOG.info( + f"Applied ONLY liger_fused_linear_cross_entropy patches for model type: {cfg.model_config_type}" + ) + except RuntimeError: + LOG.warning( + f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." + ) + else: + LOG.warning( + f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." + ) From b403616c3c54f3fafefa4695ac12c5f43e437752 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 12 Jul 2025 00:03:52 -0400 Subject: [PATCH 02/10] fix deepseek liger modeling --- src/axolotl/integrations/liger/plugin.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/axolotl/integrations/liger/plugin.py b/src/axolotl/integrations/liger/plugin.py index 11a0b5d553..f3ac9f72b9 100644 --- a/src/axolotl/integrations/liger/plugin.py +++ b/src/axolotl/integrations/liger/plugin.py @@ -105,14 +105,12 @@ def pre_model_load(self, cfg): # The DeepseekV2 version of RoPE is different than upstream LLaMA. 
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528 LOG.warning("Fused liger_rope is not supported for DeepseekV2.") - if cfg.liger_glu_activation: - LOG.warning("liger_glu_activation is not supported for DeepseekV2.") if cfg.liger_rms_norm: modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm if cfg.liger_glu_activation: modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward if cfg.liger_layer_norm: - modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward + LOG.warning("liger_layer_norm is not supported for DeepseekV2.") if cfg.liger_cross_entropy: # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses # nn.CrossEntropyLoss in the forward method. From 58c92a10061408fd28f13df97cb4a16d5a294acf Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 12 Jul 2025 01:38:29 -0400 Subject: [PATCH 03/10] generic cce and config tiled mlp to use original mlp and auto detect compute params --- .../cut_cross_entropy/__init__.py | 43 +++++++++++++++++++ src/axolotl/loaders/patch_manager.py | 6 ++- src/axolotl/monkeypatch/tiled_mlp.py | 12 +++--- src/axolotl/utils/schemas/config.py | 7 +++ 4 files changed, 62 insertions(+), 6 deletions(-) diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index a2f0d52d75..69adc0ea9f 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -19,6 +19,7 @@ from Apple's ML team. 
""" import importlib +from functools import partial import torch @@ -84,6 +85,7 @@ def pre_model_load(self, cfg): """Apply cut cross entropy before model loading if enabled.""" if cfg.cut_cross_entropy: self._check_requirements() + self.patch_llama_like(cfg.model_config_type) from cut_cross_entropy.transformers.patch import cce_patch @@ -93,3 +95,44 @@ def pre_model_load(self, cfg): # The patch checks model_type internally cce_patch(cfg.model_config_type) + + def patch_llama_like( + self, + model_type: str, + ) -> None: + """ + Generic patch for model architectures with causal lm similar to llama + """ + from cut_cross_entropy.transformers.patch import PATCH_FNS + + def patch_generic( + maybe_model, patch_options, model_type: str + ): # pylint: disable=unused-argument + import cut_cross_entropy.transformers.llama + from cut_cross_entropy.transformers.llama import cce_forward + + try: + # Dynamically import the module and CausalLM class + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + model_cls_prefix = "".join( + [part.capitalize() for part in model_type.split("_")] + ) + module = __import__( + module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"] + ) + model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") + + cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access + patch_options + ) + + model_cls.forward = cce_forward + except (ImportError, AttributeError) as e: + raise RuntimeError( + f"Could not import ForCausalLM class for model_type: {model_type}. 
" + f"Error: {str(e)}" + ) from e + + if model_type not in PATCH_FNS: + LOG.warning("Setting up generic cce patch for model type: %s", model_type) + PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 84e6b33def..f346c56e04 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -272,7 +272,11 @@ def _apply_tiled_mlp(self, model_type: str): if self.cfg.tiled_mlp: from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp - patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards) + patch_tiled_mlp( + model_type, + use_original_mlp=self.cfg.tiled_mlp_use_original_mlp, + cfg_num_shards=self.cfg.tiled_mlp_num_shards, + ) def _patch_attention(self): """Apply attention-specific patches based on model type.""" diff --git a/src/axolotl/monkeypatch/tiled_mlp.py b/src/axolotl/monkeypatch/tiled_mlp.py index 99a10df9c0..8b509101fc 100644 --- a/src/axolotl/monkeypatch/tiled_mlp.py +++ b/src/axolotl/monkeypatch/tiled_mlp.py @@ -45,11 +45,12 @@ def tiled_mlp_forward(self, x): else: num_shards = cfg_num_shards - compute_params = [ - self.down_proj.weight, - self.gate_proj.weight, - self.up_proj.weight, - ] + if not self._compute_params: # pylint: disable=protected-access + self._compute_params = [ # pylint: disable=protected-access + p.weight for p in self.parameters() if p.requires_grad + ] + + compute_params = self._compute_params # pylint: disable=protected-access down_res = TiledMLP.apply( mlp_forward, @@ -61,6 +62,7 @@ def tiled_mlp_forward(self, x): return down_res mlp_cls.forward = tiled_mlp_forward + mlp_cls._compute_params = [] # pylint: disable=protected-access except (ImportError, AttributeError) as e: raise RuntimeError( f"Could not import MLP class for model_type: {model_type}. 
" diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index e20cdaf47b..06212a27f0 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -576,6 +576,13 @@ class AxolotlInputConfig( }, ) + tiled_mlp_use_original_mlp: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama." + }, + ) + llama4_linearized_experts: bool | None = None deepspeed: str | dict[str, Any] | None = Field( From 0e512cb1df9d04697e47dd8a4a7d8c825aed6c39 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 12 Jul 2025 01:50:41 -0400 Subject: [PATCH 04/10] fix weight and lint --- src/axolotl/integrations/cut_cross_entropy/__init__.py | 1 + src/axolotl/integrations/liger/models/base.py | 1 + src/axolotl/monkeypatch/tiled_mlp.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 69adc0ea9f..265f85f95e 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -127,6 +127,7 @@ def patch_generic( ) model_cls.forward = cce_forward + # pylint: disable=duplicate-code except (ImportError, AttributeError) as e: raise RuntimeError( f"Could not import ForCausalLM class for model_type: {model_type}. 
" diff --git a/src/axolotl/integrations/liger/models/base.py b/src/axolotl/integrations/liger/models/base.py index 66fe34182f..48ac55f3aa 100644 --- a/src/axolotl/integrations/liger/models/base.py +++ b/src/axolotl/integrations/liger/models/base.py @@ -183,6 +183,7 @@ def patch_lce_forward( model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") model_cls.forward = lce_forward + # pylint: disable=duplicate-code except (ImportError, AttributeError) as e: raise RuntimeError( f"Could not import ForCausalLM class for model_type: {model_type}. " diff --git a/src/axolotl/monkeypatch/tiled_mlp.py b/src/axolotl/monkeypatch/tiled_mlp.py index 8b509101fc..cf06eaba4b 100644 --- a/src/axolotl/monkeypatch/tiled_mlp.py +++ b/src/axolotl/monkeypatch/tiled_mlp.py @@ -47,7 +47,7 @@ def tiled_mlp_forward(self, x): if not self._compute_params: # pylint: disable=protected-access self._compute_params = [ # pylint: disable=protected-access - p.weight for p in self.parameters() if p.requires_grad + p for p in self.parameters() if p.requires_grad ] compute_params = self._compute_params # pylint: disable=protected-access From 3220217afe5d27b1161e5172e365bbe415ad7b50 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 12 Jul 2025 14:59:50 -0400 Subject: [PATCH 05/10] update warnings --- src/axolotl/integrations/cut_cross_entropy/__init__.py | 7 ++++++- src/axolotl/integrations/liger/plugin.py | 7 +++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 265f85f95e..c1c0000ef1 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -135,5 +135,10 @@ def patch_generic( ) from e if model_type not in PATCH_FNS: - LOG.warning("Setting up generic cce patch for model type: %s", model_type) + LOG.warning_once( + "Setting up generic cce patch for model type: %s", model_type + ) + 
LOG.warning_once( + f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected." + ) PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type) diff --git a/src/axolotl/integrations/liger/plugin.py b/src/axolotl/integrations/liger/plugin.py index f3ac9f72b9..89f7c37b71 100644 --- a/src/axolotl/integrations/liger/plugin.py +++ b/src/axolotl/integrations/liger/plugin.py @@ -166,8 +166,11 @@ def pre_model_load(self, cfg): elif cfg.liger_fused_linear_cross_entropy: try: patch_lce_forward(cfg.model_config_type) - LOG.info( - f"Applied ONLY liger_fused_linear_cross_entropy patches for model type: {cfg.model_config_type}" + LOG.warning_once( + f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}" + ) + LOG.warning_once( + f"Liger + {cfg.model_config_type} generic FLCE support is experimental and may not work as expected." ) except RuntimeError: LOG.warning( From 5e935eb794345d8c03f15c02b79880561e73a69e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 14 Jul 2025 12:00:44 -0400 Subject: [PATCH 06/10] address PR feedback --- src/axolotl/integrations/liger/__init__.py | 5 +++++ src/axolotl/integrations/liger/models/base.py | 2 -- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 672b3a1ff7..bcdb983d4c 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -20,3 +20,8 @@ """ from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 from .plugin import LigerPlugin # pylint: disable=unused-import. 
# noqa: F401 + +__all__ = [ + "LigerArgs", + "LigerPlugin", +] diff --git a/src/axolotl/integrations/liger/models/base.py b/src/axolotl/integrations/liger/models/base.py index 48ac55f3aa..2dea8ce05e 100644 --- a/src/axolotl/integrations/liger/models/base.py +++ b/src/axolotl/integrations/liger/models/base.py @@ -37,8 +37,6 @@ def lce_forward( token can save memory, which becomes pretty significant for long sequences or large vocabulary size. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: """ # pylint: disable=duplicate-code From 2ce0cf3594f1a22d0b9a7a84027d772c9b0d6419 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 15 Jul 2025 21:02:19 -0400 Subject: [PATCH 07/10] use lookup for model class prefixes --- setup.py | 4 ++-- .../cut_cross_entropy/__init__.py | 5 ++-- src/axolotl/integrations/kd/kernels/models.py | 4 +++- src/axolotl/integrations/liger/models/base.py | 6 ++--- src/axolotl/monkeypatch/lora_kernels.py | 5 ++-- src/axolotl/monkeypatch/tiled_mlp.py | 6 ++--- src/axolotl/utils/callbacks/models.py | 23 +++++++++++++++++++ 7 files changed, 38 insertions(+), 15 deletions(-) create mode 100644 src/axolotl/utils/callbacks/models.py diff --git a/setup.py b/setup.py index df9a231544..9606fd0afa 100644 --- a/setup.py +++ b/setup.py @@ -114,9 +114,9 @@ def get_package_version(): extras_require = { - "flash-attn": ["flash-attn==2.8.0.post2"], + "flash-attn": ["flash-attn>=2.7.4.post1"], "ring-flash-attn": [ - "flash-attn==2.8.0.post2", + "flash-attn>=2.7.4.post1", "ring-flash-attn>=0.1.5", "yunchang==0.6.0", ], diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index c1c0000ef1..699ca4033d 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -27,6 
+27,7 @@ from axolotl.utils import get_pytorch_version from axolotl.utils.logging import get_logger +from ...utils.callbacks.models import get_causal_lm_model_cls_prefix from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 LOG = get_logger(__name__) @@ -114,9 +115,7 @@ def patch_generic( try: # Dynamically import the module and CausalLM class module_path = f"transformers.models.{model_type}.modeling_{model_type}" - model_cls_prefix = "".join( - [part.capitalize() for part in model_type.split("_")] - ) + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) module = __import__( module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"] ) diff --git a/src/axolotl/integrations/kd/kernels/models.py b/src/axolotl/integrations/kd/kernels/models.py index 6a8b6da1c9..4319f5f7dd 100644 --- a/src/axolotl/integrations/kd/kernels/models.py +++ b/src/axolotl/integrations/kd/kernels/models.py @@ -22,6 +22,8 @@ class TransformersKwargs(FlashAttentionKwargs, LossKwargs): TransformersKwargs, ) +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix + def kldiv_forward_llama_like( self, @@ -97,7 +99,7 @@ def kldiv_forward_llama_like( def apply_kernel(model_type): # Dynamically import the module and attention class module_path = f"transformers.models.{model_type}.modeling_{model_type}" - model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")]) + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") model_cls.forward = kldiv_forward_llama_like diff --git a/src/axolotl/integrations/liger/models/base.py b/src/axolotl/integrations/liger/models/base.py index 2dea8ce05e..f3cf4299ad 100644 --- a/src/axolotl/integrations/liger/models/base.py +++ b/src/axolotl/integrations/liger/models/base.py @@ -12,6 +12,8 @@ from torch.distributed.fsdp import 
FullyShardedDataParallel from transformers.modeling_outputs import CausalLMOutputWithPast +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix + def lce_forward( self, @@ -174,9 +176,7 @@ def patch_lce_forward( try: # Dynamically import the module and MLP class module_path = f"transformers.models.{model_type}.modeling_{model_type}" - model_cls_prefix = "".join( - [part.capitalize() for part in model_type.split("_")] - ) + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 586412dd78..4702ad19d6 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -18,6 +18,7 @@ apply_lora_qkv, ) from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -153,9 +154,7 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]: try: # Dynamically import the module and attention class module_path = f"transformers.models.{model_type}.modeling_{model_type}" - model_cls_prefix = "".join( - [part.capitalize() for part in model_type.split("_")] - ) + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"]) attention_cls = getattr(module, f"{model_cls_prefix}Attention") diff --git a/src/axolotl/monkeypatch/tiled_mlp.py b/src/axolotl/monkeypatch/tiled_mlp.py index cf06eaba4b..3818c6b355 100644 --- a/src/axolotl/monkeypatch/tiled_mlp.py +++ b/src/axolotl/monkeypatch/tiled_mlp.py @@ -6,6 +6,8 @@ import torch import torch.distributed as dist +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix + def 
patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None): from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP @@ -13,9 +15,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None): try: # Dynamically import the module and MLP class module_path = f"transformers.models.{model_type}.modeling_{model_type}" - model_cls_prefix = "".join( - [part.capitalize() for part in model_type.split("_")] - ) + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"]) mlp_cls = getattr(module, f"{model_cls_prefix}MLP") diff --git a/src/axolotl/utils/callbacks/models.py b/src/axolotl/utils/callbacks/models.py new file mode 100644 index 0000000000..5a20d70d9c --- /dev/null +++ b/src/axolotl/utils/callbacks/models.py @@ -0,0 +1,23 @@ +"""Helper functions for model classes""" + +from typing import Tuple + +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + + +def get_causal_lm_model_cls_prefix(model_type: str) -> Tuple[str, str]: + if model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + causal_lm_cls = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type] + causal_lm_cls_prefix = causal_lm_cls + for suffix in [ + "ForCausalLM", + "ForConditionalGeneration", + "LMHeadModel", + "GenerationDecoder", + ]: + causal_lm_cls_prefix = causal_lm_cls_prefix.replace(suffix, "") + return causal_lm_cls_prefix, causal_lm_cls + causal_lm_cls_prefix = "".join( + [part.capitalize() for part in model_type.split("_")] + ) + return causal_lm_cls_prefix, f"{causal_lm_cls_prefix}ForCausalLM" From 0a1eb8e66bb42dc4a36c44f96ba3a6b4f0abd6f0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 15 Jul 2025 22:32:48 -0400 Subject: [PATCH 08/10] revert inadvertent change to flash attn version --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 9606fd0afa..df9a231544 100644 --- a/setup.py +++ b/setup.py @@
-114,9 +114,9 @@ def get_package_version(): extras_require = { - "flash-attn": ["flash-attn>=2.7.4.post1"], + "flash-attn": ["flash-attn==2.8.0.post2"], "ring-flash-attn": [ - "flash-attn>=2.7.4.post1", + "flash-attn==2.8.0.post2", "ring-flash-attn>=0.1.5", "yunchang==0.6.0", ], From aaf96eec838b1e883a7a050f0964f937a620e541 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 15 Jul 2025 22:37:24 -0400 Subject: [PATCH 09/10] remove un-needed pylint annotations --- src/axolotl/integrations/liger/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index bcdb983d4c..86d56be802 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -18,8 +18,8 @@ Liger Kernel is the collection of Triton-native kernels for LLM Training. It is designed to be performant, correct, and light-weight. """ -from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 -from .plugin import LigerPlugin # pylint: disable=unused-import. 
# noqa: F401 +from .args import LigerArgs +from .plugin import LigerPlugin __all__ = [ "LigerArgs", From 00571f9fd6ef714754599e2126f57aca4af9fc47 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 15 Jul 2025 22:38:16 -0400 Subject: [PATCH 10/10] fix import --- src/axolotl/integrations/cut_cross_entropy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 699ca4033d..6c47097b73 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -25,9 +25,9 @@ from axolotl.integrations.base import BasePlugin from axolotl.utils import get_pytorch_version +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix from axolotl.utils.logging import get_logger -from ...utils.callbacks.models import get_causal_lm_model_cls_prefix from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 LOG = get_logger(__name__)