3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -63,6 +63,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_5_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl # noqa: F401
@@ -133,6 +134,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_qwen2_vl",
"apply_liger_kernel_to_qwen3",
"apply_liger_kernel_to_qwen3_moe",
"apply_liger_kernel_to_qwen3_5_moe",
"apply_liger_kernel_to_qwen3_next",
"apply_liger_kernel_to_qwen3_vl",
"apply_liger_kernel_to_qwen3_vl_moe",
@@ -213,6 +215,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_qwen2_vl",
"apply_liger_kernel_to_qwen3",
"apply_liger_kernel_to_qwen3_moe",
"apply_liger_kernel_to_qwen3_5_moe",
"apply_liger_kernel_to_qwen3_next",
"apply_liger_kernel_to_qwen3_vl",
"apply_liger_kernel_to_qwen3_vl_moe",
155 changes: 155 additions & 0 deletions src/liger_kernel/transformers/model/qwen3_5_moe.py
@@ -0,0 +1,155 @@
from typing import TYPE_CHECKING
from typing import List
from typing import Optional
from typing import Union

import torch

from transformers.modeling_outputs import MoeModelOutputWithPast

if TYPE_CHECKING:
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import load_balancing_loss_func

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast


def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> LigerMoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).

Returns:

Example:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct")

>>> prompt = "Give me a short introduction to large language model."
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)

hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]

shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
token_accuracy = None
predicted_tokens = None

if skip_logits is None:
skip_logits = self.training and (labels is not None or shift_labels is not None)

if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
    else:  # if in inference mode, materialize logits
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.vocab_size,
**kwargs,
)

aux_loss = None
    if output_router_logits:
        # `load_balancing_loss_func` is only imported above under TYPE_CHECKING,
        # so resolve it at runtime when router logits are actually requested.
        from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import load_balancing_loss_func

        aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # keep aux_loss on the same device as loss

if not return_dict:
output = (logits,) + outputs[1:]
output = ((aux_loss,) + output) if aux_loss is not None else output
output = ((loss,) + output) if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output

return LigerMoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
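A quick, hedged sketch of how this forward path is exercised once the monkey patch below is applied: during training with `labels` supplied, `skip_logits` defaults to `True`, so the fused linear cross-entropy kernel computes the loss without materializing full-vocab logits. The checkpoint id is the hypothetical one reused from the docstring above and assumes a transformers release that ships Qwen3.5 MoE.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe

# Patch the modeling classes before instantiation so Qwen3_5MoeForCausalLM.forward
# is replaced with the Liger lce_forward defined in this file.
apply_liger_kernel_to_qwen3_5_moe()

model_id = "Qwen/Qwen3.5-35B-A3B-Instruct"  # hypothetical id, reused from the docstring example
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

batch = tokenizer("Give me a short introduction to large language models.", return_tensors="pt")
labels = batch["input_ids"].clone()

model.train()
out = model(**batch, labels=labels)
print(out.loss)    # loss from the fused linear cross-entropy path
print(out.logits)  # None here: logits were never materialized
```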
8 changes: 7 additions & 1 deletion src/liger_kernel/transformers/model/qwen3_next.py
@@ -115,7 +115,13 @@ def lce_forward(
     else:  # if in inference model materialize logits
         logits = self.lm_head(kept_hidden_states)
         if labels is not None or shift_labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
+            )
 
     aux_loss = None
     if output_router_logits:
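For context on this keyword-argument change: the old positional call never forwarded `shift_labels`, so callers that pre-shift labels had them silently ignored on the non-fused path. Below is a minimal sketch of the two call forms against transformers' standard causal-LM loss helper; it assumes a recent transformers version in which `ForCausalLMLoss` accepts a `shift_labels` keyword, as the new call site implies.

```python
import torch
from transformers.loss.loss_utils import ForCausalLMLoss

vocab_size = 8
logits = torch.randn(2, 5, vocab_size)
labels = torch.randint(0, vocab_size, (2, 5))

# Un-shifted labels: the helper shifts them internally (tokens < n predict token n).
loss = ForCausalLMLoss(logits=logits, labels=labels, vocab_size=vocab_size)

# Pre-shifted labels: forwarded via `shift_labels`, which the old positional
# call `self.loss_function(logits, labels, self.vocab_size, **kwargs)` dropped.
shift_labels = torch.cat([labels[:, 1:], torch.full((2, 1), -100, dtype=torch.long)], dim=1)
loss_pre_shifted = ForCausalLMLoss(logits=logits, labels=None, shift_labels=shift_labels, vocab_size=vocab_size)

print(loss.item(), loss_pre_shifted.item())
```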
97 changes: 94 additions & 3 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -2652,13 +2652,17 @@ def apply_liger_kernel_to_qwen3_next(
                 f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
             )
 
+        _patch_rms_norm_module_for_qwen3_next = partial(
+            _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+        )
+
         if rms_norm:
-            _patch_rms_norm_module(base_model.norm)
+            _patch_rms_norm_module_for_qwen3_next(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if rms_norm:
-                _patch_rms_norm_module(decoder_layer.input_layernorm)
-                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                _patch_rms_norm_module_for_qwen3_next(decoder_layer.input_layernorm)
+                _patch_rms_norm_module_for_qwen3_next(decoder_layer.post_attention_layernorm)
 
         # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
         if swiglu:
@@ -2675,6 +2679,91 @@ def apply_liger_kernel_to_qwen3_next(
_patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)


def apply_liger_kernel_to_qwen3_5_moe(
rope: bool = False,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.

Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)

from transformers.models.qwen3_5_moe import modeling_qwen3_5_moe
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextModel

from liger_kernel.transformers.model.qwen3_5_moe import lce_forward as qwen3_5_moe_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP

if rope:
raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3_5Moe models.")
if rms_norm:
modeling_qwen3_5_moe.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3Next
if cross_entropy:
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
if isinstance(model, Qwen3_5MoeForCausalLM):
model.forward = MethodType(qwen3_5_moe_lce_forward, model)
else:
raise TypeError(
f" fused_linear_cross_entropy is only applicable on Qwen3_5MoeForCausalLM. Got: {type(model)}"
)
else:
modeling_qwen3_5_moe.Qwen3_5MoeForCausalLM.forward = qwen3_5_moe_lce_forward
if swiglu:
modeling_qwen3_5_moe.Qwen3_5MoeExperts = LigerExperts

if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, (Qwen3_5MoeForCausalLM, Qwen3_5MoeTextModel)):
base_model: Qwen3_5MoeTextModel = getattr(model, model.base_model_prefix, model)
else:
raise TypeError(
f"Unsupported qwen3_5_moe model type. `model` must be `Qwen3_5MoeForCausalLM`, `Qwen3_5MoeTextModel`. Got: {type(model)}"
)

_patch_rms_norm_module_for_qwen3_5_moe = partial(
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
)

if rms_norm:
_patch_rms_norm_module_for_qwen3_5_moe(base_model.norm)

for decoder_layer in base_model.layers:
if rms_norm:
_patch_rms_norm_module_for_qwen3_5_moe(decoder_layer.input_layernorm)
_patch_rms_norm_module_for_qwen3_5_moe(decoder_layer.post_attention_layernorm)

if swiglu:
_patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
experts = getattr(decoder_layer.mlp, "experts", None)
if experts is not None:
_patch_swiglu_module(experts, LigerExperts)


def apply_liger_kernel_to_hunyuan_v1_dense(
rope: bool = True,
cross_entropy: bool = False,
@@ -2906,6 +2995,8 @@ def __init__(self, hidden_size, eps=1e-6, **kwargs):
"qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
"qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
"qwen3_next": apply_liger_kernel_to_qwen3_next,
"qwen3_5_moe": apply_liger_kernel_to_qwen3_5_moe,
"qwen3_5_moe_text": apply_liger_kernel_to_qwen3_5_moe,
"qwen3_vl": apply_liger_kernel_to_qwen3_vl,
"qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
"qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,