Skip to content
Merged
48 changes: 48 additions & 0 deletions src/axolotl/integrations/cut_cross_entropy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
from Apple's ML team.
"""
import importlib
from functools import partial

import torch

from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.logging import get_logger

from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
Expand Down Expand Up @@ -84,6 +86,7 @@ def pre_model_load(self, cfg):
"""Apply cut cross entropy before model loading if enabled."""
if cfg.cut_cross_entropy:
self._check_requirements()
self.patch_llama_like(cfg.model_config_type)

from cut_cross_entropy.transformers.patch import cce_patch

Expand All @@ -93,3 +96,48 @@ def pre_model_load(self, cfg):

# The patch checks model_type internally
cce_patch(cfg.model_config_type)

def patch_llama_like(
self,
model_type: str,
) -> None:
"""
Generic patch for model architectures with causal lm similar to llama
"""
from cut_cross_entropy.transformers.patch import PATCH_FNS

def patch_generic(
maybe_model, patch_options, model_type: str
): # pylint: disable=unused-argument
Comment thread
winglian marked this conversation as resolved.
import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward

try:
# Dynamically import the module and CausalLM class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(
module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]
)
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")

cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access
patch_options
)

model_cls.forward = cce_forward
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e

if model_type not in PATCH_FNS:
LOG.warning_once(
"Setting up generic cce patch for model type: %s", model_type
)
LOG.warning_once(
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
)
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)
4 changes: 3 additions & 1 deletion src/axolotl/integrations/kd/kernels/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class TransformersKwargs(FlashAttentionKwargs, LossKwargs):
TransformersKwargs,
)

from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix


def kldiv_forward_llama_like(
self,
Expand Down Expand Up @@ -97,7 +99,7 @@ def kldiv_forward_llama_like(
def apply_kernel(model_type):
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = kldiv_forward_llama_like
172 changes: 6 additions & 166 deletions src/axolotl/integrations/liger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,170 +18,10 @@
Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
import inspect
import sys
from .args import LigerArgs
from .plugin import LigerPlugin

from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger

from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable

LOG = get_logger(__name__)


class LigerPlugin(BasePlugin):
"""
Plugin for LIGER integraton with Axolotl.
"""

def get_input_args(self):
return "axolotl.integrations.liger.LigerArgs"

def pre_model_load(self, cfg):
if cfg.torch_compile:
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
import liger_kernel.ops.fused_linear_cross_entropy

patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_forward",
)
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_backward",
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
raise ValueError(
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)

if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
kwargs = {}
if "rope" in liger_fn_sig.parameters:
kwargs["rope"] = cfg.liger_rope
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs["fused_linear_cross_entropy"] = (
cfg.liger_fused_linear_cross_entropy
)
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:
kwargs["layer_norm"] = cfg.liger_layer_norm
if "geglu" in liger_fn_sig.parameters:
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
apply_liger_fn(**kwargs)
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba

from .models.jamba import lce_forward as jamba_lce_forward

if cfg.liger_rope:
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_layer_norm:
modeling_jamba.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM

with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(
cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
)
modeling_mod = sys.modules[model.__class__.__module__]

from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward

if cfg.liger_rope:
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_glu_activation:
LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_layer_norm:
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
if cfg.liger_cross_entropy:
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
# nn.CrossEntropyLoss in the forward method.
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4,
)

apply_liger_kernel_to_llama4(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3":
from axolotl.integrations.liger.models.qwen3 import (
apply_liger_kernel_to_qwen3,
)

apply_liger_kernel_to_qwen3(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
)

apply_liger_kernel_to_qwen3_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "granitemoe":
from liger_kernel.transformers import apply_liger_kernel_to_granite

apply_liger_kernel_to_granite(
rope=cfg.liger_rope,
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
rms_norm=cfg.liger_rms_norm,
swiglu=cfg.liger_glu_activation,
)
else:
LOG.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
)
__all__ = [
"LigerArgs",
"LigerPlugin",
]
Loading
Loading