2 changes: 2 additions & 0 deletions README.md
@@ -264,6 +264,8 @@ loss.backward()
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| InternVL3 | `liger_kernel.transformers.apply_liger_kernel_to_internvl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| HunyuanV1 | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_dense` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| HunyuanV1 MoE | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
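
For orientation, the two new rows follow the same usage pattern as the other entries in this table; a minimal sketch assuming the standard patch-before-load workflow, with a placeholder checkpoint name:

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_dense

# Patch the HF Hunyuan V1 dense modeling code (RoPE, RMSNorm, SwiGLU, fused
# linear cross entropy) before the model is instantiated so the class-level
# replacements take effect.
apply_liger_kernel_to_hunyuan_v1_dense()

# Placeholder checkpoint name; substitute the actual Hunyuan V1 dense repo id.
model = AutoModelForCausalLM.from_pretrained(
    "tencent/Hunyuan-7B-Instruct",
    torch_dtype=torch.bfloat16,
)
```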


## Low-level APIs
6 changes: 6 additions & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -42,6 +42,8 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
@@ -128,6 +130,8 @@ def __getattr__(name: str):
"apply_liger_kernel_to_qwen3_vl_moe",
"apply_liger_kernel_to_smollm3",
"apply_liger_kernel_to_smolvlm",
"apply_liger_kernel_to_hunyuan_v1_dense",
"apply_liger_kernel_to_hunyuan_v1_moe",
}

if name in monkey_patch_symbols:
@@ -202,5 +206,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_qwen3_vl_moe",
"apply_liger_kernel_to_smollm3",
"apply_liger_kernel_to_smolvlm",
"apply_liger_kernel_to_hunyuan_v1_dense",
"apply_liger_kernel_to_hunyuan_v1_moe",
]
)
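
With these exports in place, both new entry points are importable from the package root; a minimal smoke-test sketch (whether the names resolve eagerly or through the module-level `__getattr__` shown above depends on the surrounding machinery not visible in this hunk):

```python
# Both symbols should now resolve from liger_kernel.transformers.
from liger_kernel.transformers import (
    apply_liger_kernel_to_hunyuan_v1_dense,
    apply_liger_kernel_to_hunyuan_v1_moe,
)

assert callable(apply_liger_kernel_to_hunyuan_v1_dense)
assert callable(apply_liger_kernel_to_hunyuan_v1_moe)
```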
134 changes: 134 additions & 0 deletions src/liger_kernel/transformers/model/hunyuan_v1.py
@@ -0,0 +1,134 @@
from typing import List
from typing import Optional
from typing import Union

import torch

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> LigerCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).

Returns:

Example:

```python
>>> from transformers import AutoTokenizer, HunYuanDenseV1ForCausalLM

>>> model = HunYuanDenseV1ForCausalLM.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf")

>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)

hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]

shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
token_accuracy = None

if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")

if skip_logits is None:
# By default, if in training mode, don't materialize logits
skip_logits = self.training and (labels is not None or shift_labels is not None)

# Compute loss
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
loss, _, token_accuracy = unpack_cross_entropy_result(result)

else:
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.config.vocab_size,
**kwargs,
)

if not return_dict:
output = (logits,) + outputs[1:]
output = ((loss,) + output) if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
return output

# Return custom output class with accuracy field
return LigerCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
token_accuracy=token_accuracy,
)
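
For intuition, `skip_logits` controls whether the `[batch, seq, vocab]` logits tensor is ever materialized; the forward above either calls `LigerForCausalLMLoss` (no full logits) or falls back to `lm_head` plus the standard loss. A rough, framework-agnostic sketch of the two strategies follows — the names and chunking here are illustrative only, the real fused path runs as a Triton kernel:

```python
import torch
import torch.nn.functional as F


def unfused_loss(hidden, lm_head_weight, labels):
    # Baseline: materialize the full [batch, seq, vocab] logits, then cross entropy.
    logits = hidden @ lm_head_weight.T
    return F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        labels[:, 1:].reshape(-1),
        ignore_index=-100,
    )


def chunked_fused_loss(hidden, lm_head_weight, labels, chunk_size=1024):
    # Spirit of the fused path: project and reduce chunk by chunk so the full
    # logits tensor never exists at once (the actual kernel fuses projection
    # and cross entropy on the GPU; this is only a conceptual sketch).
    hidden = hidden[:, :-1].reshape(-1, hidden.size(-1))
    labels = labels[:, 1:].reshape(-1)
    total, count = hidden.new_zeros(()), 0
    for start in range(0, hidden.size(0), chunk_size):
        h, y = hidden[start:start + chunk_size], labels[start:start + chunk_size]
        valid = y != -100
        if valid.any():
            logits = h[valid] @ lm_head_weight.T
            total = total + F.cross_entropy(logits, y[valid], reduction="sum")
            count += int(valid.sum())
    return total / max(count, 1)
```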
119 changes: 119 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -2558,6 +2558,123 @@ def apply_liger_kernel_to_qwen3_next(
_patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)


def apply_liger_kernel_to_hunyuan_v1_dense(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)

from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model

from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP

if rope:
modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb

if rms_norm:
modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm

if cross_entropy:
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy

if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(hunyuan_v1_lce_forward, model)
else:
modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward

if swiglu:
modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP

if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules

# get the base model from the model instance
base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)

if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)


def apply_liger_kernel_to_hunyuan_v1_moe(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 MoE models.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)

from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model

from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP

if rope:
modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb

if rms_norm:
modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm

if cross_entropy:
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy

if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
else:
modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward

if swiglu:
modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP

if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules

# get the base model from the model instance
base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)

if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
for mlp_expert in decoder_layer.mlp.experts:
_patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)


# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
MODEL_TYPE_TO_APPLY_LIGER_FN = {
"gemma": apply_liger_kernel_to_gemma,
@@ -2595,6 +2712,8 @@ def apply_liger_kernel_to_qwen3_next(
"paligemma": apply_liger_kernel_to_paligemma,
"falcon_h1": apply_liger_kernel_to_falcon_h1,
"smolvlm": apply_liger_kernel_to_smolvlm,
"hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
"hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
}


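The registry entries above key the patch functions by `config.model_type`, and the `model=` argument supports patching an already-instantiated model; a minimal sketch of the instance-level path, with a placeholder checkpoint name:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_dense

# Placeholder checkpoint name for illustration.
model = AutoModelForCausalLM.from_pretrained("tencent/Hunyuan-7B-Instruct")

# Instance-level patching: rebinds `forward` via MethodType and walks
# `base_model.layers` to swap the already-constructed RMSNorm and MLP modules,
# as implemented in the diff above.
apply_liger_kernel_to_hunyuan_v1_dense(model=model, fused_linear_cross_entropy=True)
```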
17 changes: 17 additions & 0 deletions src/liger_kernel/transformers/swiglu.py
@@ -77,3 +77,20 @@ def __init__(self, config, intermediate_size=None):

def forward(self, x):
return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))


class LigerHunyuanV1SwiGLUMLP(nn.Module):
def __init__(self, config, layer_idx=None, is_shared_mlp=False):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.layer_idx = layer_idx
if config.hidden_act not in ["silu", "swish"]:
raise ValueError(f"Activation function {config.hidden_act} not supported.")

def forward(self, x):
return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
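
For reference, `LigerSiLUMulFunction.apply(gate, up)` computes the same result as the stock eager SwiGLU; a minimal equivalence sketch in plain PyTorch (the Liger version fuses the SiLU and elementwise multiply into a single kernel):

```python
import torch.nn.functional as F


def eager_swiglu_mlp(x, gate_proj, up_proj, down_proj):
    # Stock Hugging Face-style SwiGLU MLP: down(silu(gate(x)) * up(x)).
    return down_proj(F.silu(gate_proj(x)) * up_proj(x))

# LigerHunyuanV1SwiGLUMLP.forward produces the same output, but the
# silu(gate) * up step runs as one fused kernel, avoiding an extra
# intermediate tensor read/write.
```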