From 93dc123d24a2d25ecfaf5dd7741f0d6eaf1fc44c Mon Sep 17 00:00:00 2001
From: Michael Royzen
Date: Thu, 26 Feb 2026 13:53:37 -0500
Subject: [PATCH 1/7] Add support for Qwen3.5 MoE

---
 src/liger_kernel/transformers/__init__.py     |   3 +
 .../transformers/model/qwen3_5_moe.py         | 149 ++++++++++++++++++
 src/liger_kernel/transformers/monkey_patch.py |  83 ++++++++++
 test/convergence/fp32/test_mini_models.py     |  46 ++++++
 test/transformers/test_monkey_patch.py        |  76 +++++++++
 test/utils.py                                 |  12 ++
 6 files changed, 369 insertions(+)
 create mode 100644 src/liger_kernel/transformers/model/qwen3_5_moe.py

diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py
index 922ed61fa..67e5061ed 100644
--- a/src/liger_kernel/transformers/__init__.py
+++ b/src/liger_kernel/transformers/__init__.py
@@ -64,6 +64,7 @@
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_5_moe  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe  # noqa: F401
@@ -133,6 +134,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_qwen2_vl",
         "apply_liger_kernel_to_qwen3",
         "apply_liger_kernel_to_qwen3_moe",
+        "apply_liger_kernel_to_qwen3_5_moe",
         "apply_liger_kernel_to_qwen3_next",
         "apply_liger_kernel_to_qwen3_vl",
         "apply_liger_kernel_to_qwen3_vl_moe",
@@ -213,6 +215,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_qwen2_vl",
         "apply_liger_kernel_to_qwen3",
         "apply_liger_kernel_to_qwen3_moe",
+        "apply_liger_kernel_to_qwen3_5_moe",
         "apply_liger_kernel_to_qwen3_next",
         "apply_liger_kernel_to_qwen3_vl",
         "apply_liger_kernel_to_qwen3_vl_moe",

diff --git a/src/liger_kernel/transformers/model/qwen3_5_moe.py b/src/liger_kernel/transformers/model/qwen3_5_moe.py
new file mode 100644
index 000000000..c367cd339
--- /dev/null
+++ b/src/liger_kernel/transformers/model/qwen3_5_moe.py
@@ -0,0 +1,149 @@
+from typing import List
+from typing import Optional
+from typing import Union
+
+import torch
+
+from transformers.modeling_outputs import MoeModelOutputWithPast
+
+# load_balancing_loss_func is called at runtime in lce_forward when router logits
+# are requested, so it must be a real import, not a TYPE_CHECKING-only import.
+from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import load_balancing_loss_func
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast
+
+
+def lce_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    **kwargs,
+) -> LigerMoeCausalLMOutputWithPast:
+    r"""
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    logits_to_keep (`int` or `torch.Tensor`, *optional*):
+        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+        This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct")
+    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct")
+
+    >>> prompt = "Give me a short introduction to large language model."
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_router_logits = (
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs: MoeModelOutputWithPast = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs.last_hidden_state
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+    predicted_tokens = None
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
+    else:  # if in inference mode, materialize logits
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits,
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure it resides on the same device
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((aux_loss,) + output) if aux_loss is not None else output
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        output = output + (predicted_tokens,) if predicted_tokens is not None else output
+        return output
+
+    return LigerMoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+        token_accuracy=token_accuracy,
+        predicted_tokens=predicted_tokens,
+    )

diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py
index f7936cf8e..83f7f5cb6 100755
--- a/src/liger_kernel/transformers/monkey_patch.py
+++ b/src/liger_kernel/transformers/monkey_patch.py
@@ -2675,6 +2675,87 @@ def apply_liger_kernel_to_qwen3_next(
                 _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
 
 
+def apply_liger_kernel_to_qwen3_5_moe(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_5_moe import modeling_qwen3_5_moe
+    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM
+    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextModel
+
+    from liger_kernel.transformers.model.qwen3_5_moe import lce_forward as qwen3_5_moe_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3_5Moe models.")
+    if rms_norm:
+        modeling_qwen3_5_moe.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3Next
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            if isinstance(model, Qwen3_5MoeForCausalLM):
+                model.forward = MethodType(qwen3_5_moe_lce_forward, model)
+            else:
+                raise TypeError(
+                    f" fused_linear_cross_entropy is only applicable on Qwen3_5MoeForCausalLM. Got: {type(model)}"
+                )
+        else:
+            modeling_qwen3_5_moe.Qwen3_5MoeForCausalLM.forward = qwen3_5_moe_lce_forward
+    if swiglu:
+        modeling_qwen3_5_moe.Qwen3_5MoeExperts = LigerExperts
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (Qwen3_5MoeForCausalLM, Qwen3_5MoeTextModel)):
+            base_model: Qwen3_5MoeTextModel = getattr(model, model.base_model_prefix, model)
+        else:
+            raise TypeError(
+                f"Unsupported qwen3_5_moe model type. `model` must be `Qwen3_5MoeForCausalLM`, `Qwen3_5MoeTextModel`. Got: {type(model)}"
+            )
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+                experts = getattr(decoder_layer.mlp, "experts", None)
+                if experts is not None:
+                    _patch_swiglu_module(experts, LigerExperts)
+
+
 def apply_liger_kernel_to_hunyuan_v1_dense(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -2906,6 +2987,8 @@ def __init__(self, hidden_size, eps=1e-6, **kwargs):
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
     "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
     "qwen3_next": apply_liger_kernel_to_qwen3_next,
+    "qwen3_5_moe": apply_liger_kernel_to_qwen3_5_moe,
+    "qwen3_5_moe_text": apply_liger_kernel_to_qwen3_5_moe,
     "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
     "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
     "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,

diff --git a/test/convergence/fp32/test_mini_models.py b/test/convergence/fp32/test_mini_models.py
index 13f69d013..a52a516b4 100644
--- a/test/convergence/fp32/test_mini_models.py
+++ b/test/convergence/fp32/test_mini_models.py
@@ -51,6 +51,7 @@
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe
+from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl_moe
@@ -89,6 +90,7 @@
 from test.utils import revert_liger_kernel_to_qwen2_vl
 from test.utils import revert_liger_kernel_to_qwen3
 from test.utils import revert_liger_kernel_to_qwen3_moe
+from test.utils import revert_liger_kernel_to_qwen3_5_moe
 from test.utils import revert_liger_kernel_to_qwen3_next
 from test.utils import revert_liger_kernel_to_qwen3_vl
 from test.utils import revert_liger_kernel_to_qwen3_vl_moe
@@ -292,6 +294,14 @@
 except ImportError:
     QWEN3NEXT_AVAILABLE = False
 
+try:
+    from transformers.models.qwen3_5_moe.modular_qwen3_5_moe import Qwen3_5MoeTextConfig
+    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM
+
+    QWEN3_5_MOE_AVAILABLE = True
+except ImportError:
+    QWEN3_5_MOE_AVAILABLE = False
+
 try:
     from transformers.models.hunyuan_v1_dense.configuration_hunyuan_v1_dense import HunYuanDenseV1Config
     from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1ForCausalLM
@@ -1379,6 +1389,42 @@
         ),
     )
 
+if QWEN3_5_MOE_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_qwen3_5_moe"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_qwen3_5_moe,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_5_moe,
+        model_class=Qwen3_5MoeForCausalLM,
+        mini_model_config=Qwen3_5MoeTextConfig(
+            vocab_size=32000,
+            hidden_size=896,
+            num_hidden_layers=4,
+            num_attention_heads=8,
+            num_key_value_heads=2,
+            hidden_act="silu",
+            max_position_embeddings=32768,
+            initializer_range=0.02,
+            rms_norm_eps=1e-6,
+            use_cache=True,
+            tie_word_embeddings=False,
+            attention_bias=False,
+            attention_dropout=0.0,
+            head_dim=128,
+            linear_conv_kernel_dim=4,
+            linear_key_head_dim=64,
+            linear_value_head_dim=64,
+            linear_num_key_heads=8,
+            linear_num_value_heads=8,
+            moe_intermediate_size=768,
+            shared_expert_intermediate_size=768,
+            num_experts_per_tok=2,
+            num_experts=8,
+            output_router_logits=False,
+            router_aux_loss_coef=0.001,
+            # config.dtype must be set if fla is installed, since the original code references the nonexistent torch.get_current_dtype()
+            dtype=torch.float32,
+        ),
+    )
+
 if HUNYUAN_V1_AVAILABLE:
     MINI_MODEL_SETUPS["mini_hunyuan_v1"] = MiniModelConfig(
         liger_kernel_patch_func=apply_liger_kernel_to_hunyuan_v1_dense,

diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py
index e87a38481..93bf3c988 100755
--- a/test/transformers/test_monkey_patch.py
+++ b/test/transformers/test_monkey_patch.py
@@ -207,6 +207,15 @@ def is_qwen3_next_available():
         return False
 
 
+def is_qwen3_5_moe_available():
+    try:
+        import transformers.models.qwen3_5_moe  # noqa: F401
+
+        return True
+    except ImportError:
+        return False
+
+
 def is_pixtral_available():
     try:
         import transformers.models.pixtral  # noqa: F401
@@ -2899,6 +2908,73 @@ def test_apply_liger_kernel_to_instance_for_qwen3_next():
         pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")
 
 
+@pytest.mark.skipif(not is_qwen3_5_moe_available(), reason="qwen3_5_moe module not available")
+def test_apply_liger_kernel_to_instance_for_qwen3_5_moe():
+    # Ensure any monkey patching is cleaned up for subsequent tests
+    with patch("transformers.models.qwen3_5_moe.modeling_qwen3_5_moe"):
+        from liger_kernel.transformers.model.qwen3_5_moe import lce_forward as qwen3_5_moe_lce_forward
+
+        # Instantiate a dummy model
+        config = transformers.models.qwen3_5_moe.modular_qwen3_5_moe.Qwen3_5MoeTextConfig(
+            dtype=torch.bfloat16,
+            rms_norm_eps=1e-5,
+            hidden_size=32,
+            moe_intermediate_size=16,
+            shared_expert_intermediate_size=16,
+            hidden_act="silu",
+            num_hidden_layers=2,
+            num_attention_heads=2,
+            num_key_value_heads=1,
+            head_dim=16,
+            linear_conv_kernel_dim=4,
+            linear_key_head_dim=16,
+            linear_value_head_dim=16,
+            linear_num_key_heads=2,
+            linear_num_value_heads=2,
+            num_experts=2,
+            num_experts_per_tok=1,
+        )
+        dummy_model_instance = AutoModelForCausalLM.from_config(config)
+
+        # Check that model instance variables are not yet patched with Liger modules
+        assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(qwen3_5_moe_lce_forward)
+        assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward)
+        for layer in dummy_model_instance.model.layers:
+            if IS_TRANSFORMERS_V5_OR_LATER:
+                assert inspect.getsource(layer.mlp.experts.forward) != inspect.getsource(LigerExperts.forward)
+            else:
+                for expert in layer.mlp.experts:
+                    assert inspect.getsource(expert.forward) != inspect.getsource(LigerQwen3MoeSwiGLUMLP.forward)
+            assert inspect.getsource(layer.mlp.shared_expert.forward) != inspect.getsource(
+                LigerQwen3MoeSwiGLUMLP.forward
+            )
+            assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
+            assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
+
+        # Test applying kernels to the model instance
+        _apply_liger_kernel_to_instance(model=dummy_model_instance)
+
+        # Check that the model's instance variables were correctly patched with Liger modules
+        assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(qwen3_5_moe_lce_forward)
+        assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward)
+        for layer in dummy_model_instance.model.layers:
+            if IS_TRANSFORMERS_V5_OR_LATER:
+                assert inspect.getsource(layer.mlp.experts.forward) == inspect.getsource(LigerExperts.forward)
+            else:
+                for expert in layer.mlp.experts:
+                    assert inspect.getsource(expert.forward) == inspect.getsource(LigerQwen3MoeSwiGLUMLP.forward)
+            assert inspect.getsource(layer.mlp.shared_expert.forward) == inspect.getsource(
+                LigerQwen3MoeSwiGLUMLP.forward
+            )
+            assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
+            assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
+
+        try:
+            print(dummy_model_instance)
+        except Exception as e:
+            pytest.fail(f"An exception occurred in extra_expr: {type(e).__name__} - {e}")
+
+
 @pytest.mark.skipif(not is_hunyuan_v1_available(), reason="hunyuan_v1 module not available")
 def test_apply_liger_kernel_to_instance_for_hunyuan_v1_moe():
     # Ensure any monkey patching is cleaned up for subsequent tests

diff --git a/test/utils.py b/test/utils.py
index d678830d0..b5cb3ecc1 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -760,6 +760,18 @@ def revert_liger_kernel_to_qwen3_next(model_config: MiniModelConfig):
     print("Liger kernel patches have been reverted.")
 
 
+def revert_liger_kernel_to_qwen3_5_moe(model_config: MiniModelConfig):
+    """
+    Revert all Liger kernel patches applied to Qwen3.5 MoE.
+    """
+
+    from transformers.models.qwen3_5_moe import modeling_qwen3_5_moe
+
+    importlib.reload(modeling_qwen3_5_moe)
+    model_config.model_class = modeling_qwen3_5_moe.Qwen3_5MoeForCausalLM
+    print("Liger kernel patches have been reverted.")
+
+
 def revert_liger_kernel_to_hunyuan_v1(model_config: MiniModelConfig):
    """
    Revert all Liger kernel patches applied to Hunyuanv1.
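
A minimal usage sketch for the entry point this patch adds, following the two call
patterns its docstring describes (patch the Qwen3.5 MoE classes before loading, or
patch an already-loaded instance via `model=`). The checkpoint id is reused from the
docstring example above and is an assumption, not a confirmed published model id:

    from transformers import AutoModelForCausalLM

    from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe

    # Option 1: patch the modeling classes first, then construct the model.
    apply_liger_kernel_to_qwen3_5_moe(rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct")

    # Option 2: patch an already-instantiated model in place.
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct")
    apply_liger_kernel_to_qwen3_5_moe(model=model)
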
From 2d0f1bc68e8446510e7352093482b87bfcea115a Mon Sep 17 00:00:00 2001
From: Michael Royzen
Date: Thu, 26 Feb 2026 14:10:47 -0500
Subject: [PATCH 2/7] Both Qwen3.5 MoE and Qwen3-Next should use Gemma-style
 RMSNorm instead of Llama-style

---
 src/liger_kernel/transformers/monkey_patch.py | 20 +++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py
index 83f7f5cb6..9baefcf47 100755
--- a/src/liger_kernel/transformers/monkey_patch.py
+++ b/src/liger_kernel/transformers/monkey_patch.py
@@ -2652,13 +2652,17 @@ def apply_liger_kernel_to_qwen3_next(
                 f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
             )
 
+        _patch_rms_norm_module_for_qwen3_next = partial(
+            _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+        )
+
         if rms_norm:
-            _patch_rms_norm_module(base_model.norm)
+            _patch_rms_norm_module_for_qwen3_next(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if rms_norm:
-                _patch_rms_norm_module(decoder_layer.input_layernorm)
-                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                _patch_rms_norm_module_for_qwen3_next(decoder_layer.input_layernorm)
+                _patch_rms_norm_module_for_qwen3_next(decoder_layer.post_attention_layernorm)
 
             # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
             if swiglu:
@@ -2741,13 +2745,17 @@ def apply_liger_kernel_to_qwen3_5_moe(
                 f"Unsupported qwen3_5_moe model type. `model` must be `Qwen3_5MoeForCausalLM`, `Qwen3_5MoeTextModel`. Got: {type(model)}"
             )
 
+        _patch_rms_norm_module_for_qwen3_5_moe = partial(
+            _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+        )
+
         if rms_norm:
-            _patch_rms_norm_module(base_model.norm)
+            _patch_rms_norm_module_for_qwen3_5_moe(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if rms_norm:
-                _patch_rms_norm_module(decoder_layer.input_layernorm)
-                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                _patch_rms_norm_module_for_qwen3_5_moe(decoder_layer.input_layernorm)
+                _patch_rms_norm_module_for_qwen3_5_moe(decoder_layer.post_attention_layernorm)
 
             if swiglu:
                 _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)

From e2b0666aa884d7587ff6dae9bc766406df6edc29 Mon Sep 17 00:00:00 2001
From: Michael Royzen
Date: Thu, 26 Feb 2026 14:25:48 -0500
Subject: [PATCH 3/7] Convergence test fixes

---
 test/convergence/fp32/test_mini_models.py | 24 ++++++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/test/convergence/fp32/test_mini_models.py b/test/convergence/fp32/test_mini_models.py
index a52a516b4..280cf1b8e 100644
--- a/test/convergence/fp32/test_mini_models.py
+++ b/test/convergence/fp32/test_mini_models.py
@@ -1542,7 +1542,7 @@ def run_mini_model(
         "rms_norm": True,
     }
 
-    if "glm4" in model_name or "qwen3_next" in model_name:
+    if "glm4" in model_name or "qwen3_next" in model_name or "qwen3_5_moe" in model_name:
         kwargs["rope"] = False
 
     model_supports_layer_norm = "qwen2_vl" in model_name
@@ -1983,6 +1983,28 @@ def run_mini_model(
             ),
         ],
     ),
+    pytest.param(
+        "mini_qwen3_5_moe",
+        32,
+        1e-5,
+        torch.float32,
+        1e-8,
+        1e-5,
+        5e-3,
+        1e-5,
+        5e-3,
+        1e-5,
+        marks=[
+            pytest.mark.skipif(
+                not QWEN3_5_MOE_AVAILABLE,
+                reason="Qwen3_5Moe not available in this version of transformers",
+            ),
+            pytest.mark.skip(
+                reason="flash-linear-attention's ChunkGatedDeltaRuleFunction does not support float32.\n"
+                + " Torch's implementation takes too long"
+            ),
+        ],
+    ),
     pytest.param(
         "mini_hunyuan_v1",
         32,

From 849b2f853ed75fcd0ee6a06024709d05da62e020 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Thu, 26 Feb 2026 19:45:15 +0000
Subject: [PATCH 4/7] Fix test imports

---
 test/convergence/fp32/test_mini_models.py | 6 +-----
 test/transformers/test_monkey_patch.py    | 2 +-
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/test/convergence/fp32/test_mini_models.py b/test/convergence/fp32/test_mini_models.py
index 280cf1b8e..f6aea2fd0 100644
--- a/test/convergence/fp32/test_mini_models.py
+++ b/test/convergence/fp32/test_mini_models.py
@@ -295,7 +295,7 @@
     QWEN3NEXT_AVAILABLE = False
 
 try:
-    from transformers.models.qwen3_5_moe.modular_qwen3_5_moe import Qwen3_5MoeTextConfig
+    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextConfig
     from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM
 
     QWEN3_5_MOE_AVAILABLE = True
@@ -1999,10 +1999,6 @@ def run_mini_model(
                 not QWEN3_5_MOE_AVAILABLE,
                 reason="Qwen3_5Moe not available in this version of transformers",
             ),
-            pytest.mark.skip(
-                reason="flash-linear-attention's ChunkGatedDeltaRuleFunction does not support float32.\n"
-                + " Torch's implementation takes too long"
-            ),
         ],
     ),

diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py
index 93bf3c988..1fa4eea75 100755
--- a/test/transformers/test_monkey_patch.py
+++ b/test/transformers/test_monkey_patch.py
@@ -2915,7 +2915,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_5_moe():
         from liger_kernel.transformers.model.qwen3_5_moe import lce_forward as qwen3_5_moe_lce_forward
 
         # Instantiate a dummy model
-        config = transformers.models.qwen3_5_moe.modular_qwen3_5_moe.Qwen3_5MoeTextConfig(
+        config = transformers.models.qwen3_5_moe.configuration_qwen3_5_moe.Qwen3_5MoeTextConfig(
             dtype=torch.bfloat16,
             rms_norm_eps=1e-5,
             hidden_size=32,

From a017ed4ad18dec8437d2827570bfd00c6569ef2e Mon Sep 17 00:00:00 2001
From: Michael Royzen
Date: Fri, 27 Feb 2026 10:40:17 -0500
Subject: [PATCH 5/7] Add shift_labels to loss_function calls

---
 src/liger_kernel/transformers/model/qwen3_5_moe.py | 8 +++++++-
 src/liger_kernel/transformers/model/qwen3_next.py  | 8 +++++++-
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/liger_kernel/transformers/model/qwen3_5_moe.py b/src/liger_kernel/transformers/model/qwen3_5_moe.py
index c367cd339..2eb7fe072 100644
--- a/src/liger_kernel/transformers/model/qwen3_5_moe.py
+++ b/src/liger_kernel/transformers/model/qwen3_5_moe.py
@@ -115,7 +115,13 @@ def lce_forward(
     else:  # if in inference mode, materialize logits
         logits = self.lm_head(kept_hidden_states)
         if labels is not None or shift_labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
+            )
 
     aux_loss = None
     if output_router_logits:

diff --git a/src/liger_kernel/transformers/model/qwen3_next.py b/src/liger_kernel/transformers/model/qwen3_next.py
index e72fda4ab..5f6dd0062 100644
--- a/src/liger_kernel/transformers/model/qwen3_next.py
+++ b/src/liger_kernel/transformers/model/qwen3_next.py
@@ -115,7 +115,13 @@ def lce_forward(
     else:  # if in inference model materialize logits
         logits = self.lm_head(kept_hidden_states)
         if labels is not None or shift_labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
+            )
 
     aux_loss = None
     if output_router_logits:

From 4ef98080d6764188a777ab348c09055765806bde Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 27 Feb 2026 18:43:07 +0000
Subject: [PATCH 6/7] Match fp32 skip behavior for Qwen3.5 MoE (as with
 Qwen3-Next) and add bf16 test matching Qwen3 MoE tolerances

---
 test/convergence/bf16/test_mini_models.py | 66 ++++++++++++++++++++++-
 test/convergence/fp32/test_mini_models.py |  4 ++
 2 files changed, 69 insertions(+), 1 deletion(-)

diff --git a/test/convergence/bf16/test_mini_models.py b/test/convergence/bf16/test_mini_models.py
index 63b560a7e..61364fb7c 100644
--- a/test/convergence/bf16/test_mini_models.py
+++ b/test/convergence/bf16/test_mini_models.py
@@ -51,6 +51,7 @@
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe
+from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl_moe
@@ -89,6 +90,7 @@
 from test.utils import revert_liger_kernel_to_qwen2_vl
 from test.utils import revert_liger_kernel_to_qwen3
 from test.utils import revert_liger_kernel_to_qwen3_moe
+from test.utils import revert_liger_kernel_to_qwen3_5_moe
 from test.utils import revert_liger_kernel_to_qwen3_next
 from test.utils import revert_liger_kernel_to_qwen3_vl
 from test.utils import revert_liger_kernel_to_qwen3_vl_moe
@@ -294,6 +296,14 @@
 except ImportError:
     QWEN3NEXT_AVAILABLE = False
 
+try:
+    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextConfig
+    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM
+
+    QWEN3_5_MOE_AVAILABLE = True
+except ImportError:
+    QWEN3_5_MOE_AVAILABLE = False
+
 try:
     from transformers.models.hunyuan_v1_dense.configuration_hunyuan_v1_dense import HunYuanDenseV1Config
     from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1ForCausalLM
@@ -1387,6 +1397,41 @@
         ),
     )
 
+if QWEN3_5_MOE_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_qwen3_5_moe"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_qwen3_5_moe,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_5_moe,
+        model_class=Qwen3_5MoeForCausalLM,
+        mini_model_config=Qwen3_5MoeTextConfig(
+            vocab_size=32000,
+            hidden_size=896,
+            num_hidden_layers=4,
+            num_attention_heads=8,
+            num_key_value_heads=2,
+            hidden_act="silu",
+            max_position_embeddings=32768,
+            initializer_range=0.02,
+            rms_norm_eps=1e-6,
+            use_cache=True,
+            tie_word_embeddings=False,
+            attention_bias=False,
+            attention_dropout=0.0,
+            head_dim=128,
+            linear_conv_kernel_dim=4,
+            linear_key_head_dim=64,
+            linear_value_head_dim=64,
+            linear_num_key_heads=8,
+            linear_num_value_heads=8,
+            moe_intermediate_size=768,
+            shared_expert_intermediate_size=768,
+            num_experts_per_tok=2,
+            num_experts=8,
+            output_router_logits=False,
+            router_aux_loss_coef=0.001,
+            # config.dtype must be set if fla is installed, since the original code references the nonexistent torch.get_current_dtype()
+            dtype=torch.bfloat16,
+        ),
+    )
 
 if HUNYUAN_V1_AVAILABLE:
     MINI_MODEL_SETUPS["mini_hunyuan_v1"] = MiniModelConfig(
         liger_kernel_patch_func=apply_liger_kernel_to_hunyuan_v1_dense,
@@ -1509,7 +1554,7 @@ def run_mini_model(
         "rms_norm": True,
     }
 
-    if "glm4" in model_name or "qwen3_next" in model_name:
+    if "glm4" in model_name or "qwen3_next" in model_name or "qwen3_5_moe" in model_name:
         kwargs["rope"] = False
 
     model_supports_layer_norm = "qwen2_vl" in model_name
@@ -2093,6 +2138,25 @@ def run_mini_model(
             ),
         ],
     ),
+    pytest.param(
+        "mini_qwen3_5_moe",
+        32,
+        1e-5,
+        torch.bfloat16,
+        5e-2,
+        2e-1,
+        1e-1,
+        1e-1,
+        1e-2,
+        1e-2,
+        marks=[
+            pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+            pytest.mark.skipif(
+                not QWEN3_5_MOE_AVAILABLE,
+                reason="Qwen3_5Moe not available in this version of transformers",
+            ),
+        ],
+    ),
     pytest.param(
         "mini_hunyuan_v1",
         32,

diff --git a/test/convergence/fp32/test_mini_models.py b/test/convergence/fp32/test_mini_models.py
index f6aea2fd0..f937c327e 100644
--- a/test/convergence/fp32/test_mini_models.py
+++ b/test/convergence/fp32/test_mini_models.py
@@ -1999,6 +1999,10 @@ def run_mini_model(
                 not QWEN3_5_MOE_AVAILABLE,
                 reason="Qwen3_5Moe not available in this version of transformers",
             ),
+            pytest.mark.skip(
+                reason="flash-linear-attention's ChunkGatedDeltaRuleFunction does not support float32.\n"
+                + " Torch's implementation takes too long"
+            ),
         ],
     ),
     pytest.param(

From 7a830920c9b125bf3593d4b551457706a81cebd4 Mon Sep 17 00:00:00 2001
From: Michael Royzen
Date: Mon, 2 Mar 2026 13:06:33 -0500
Subject: [PATCH 7/7] Rebase and lint

---
 src/liger_kernel/transformers/__init__.py | 2 +-
 test/convergence/bf16/test_mini_models.py | 6 +++---
 test/convergence/fp32/test_mini_models.py | 6 +++---
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py
index 67e5061ed..f7ab2ace9 100644
--- a/src/liger_kernel/transformers/__init__.py
+++ b/src/liger_kernel/transformers/__init__.py
@@ -63,8 +63,8 @@
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3  # noqa: F401
-    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_5_moe  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe  # noqa: F401

diff --git a/test/convergence/bf16/test_mini_models.py b/test/convergence/bf16/test_mini_models.py
index 61364fb7c..56d13f740 100644
--- a/test/convergence/bf16/test_mini_models.py
+++ b/test/convergence/bf16/test_mini_models.py
@@ -50,8 +50,8 @@
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3
-from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe
+from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl_moe
@@ -89,8 +89,8 @@
 from test.utils import revert_liger_kernel_to_qwen2_5_vl
 from test.utils import revert_liger_kernel_to_qwen2_vl
 from test.utils import revert_liger_kernel_to_qwen3
-from test.utils import revert_liger_kernel_to_qwen3_moe
 from test.utils import revert_liger_kernel_to_qwen3_5_moe
+from test.utils import revert_liger_kernel_to_qwen3_moe
 from test.utils import revert_liger_kernel_to_qwen3_next
 from test.utils import revert_liger_kernel_to_qwen3_vl
 from test.utils import revert_liger_kernel_to_qwen3_vl_moe
@@ -297,8 +297,8 @@
     QWEN3NEXT_AVAILABLE = False
 
 try:
-    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextConfig
     from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM
+    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextConfig
 
     QWEN3_5_MOE_AVAILABLE = True
 except ImportError:

diff --git a/test/convergence/fp32/test_mini_models.py b/test/convergence/fp32/test_mini_models.py
index f937c327e..b37b86bbf 100644
--- a/test/convergence/fp32/test_mini_models.py
+++ b/test/convergence/fp32/test_mini_models.py
@@ -50,8 +50,8 @@
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3
-from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe
+from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl_moe
@@ -89,8 +89,8 @@
 from test.utils import revert_liger_kernel_to_qwen2_5_vl
 from test.utils import revert_liger_kernel_to_qwen2_vl
 from test.utils import revert_liger_kernel_to_qwen3
-from test.utils import revert_liger_kernel_to_qwen3_moe
 from test.utils import revert_liger_kernel_to_qwen3_5_moe
+from test.utils import revert_liger_kernel_to_qwen3_moe
 from test.utils import revert_liger_kernel_to_qwen3_next
 from test.utils import revert_liger_kernel_to_qwen3_vl
 from test.utils import revert_liger_kernel_to_qwen3_vl_moe
@@ -295,8 +295,8 @@
     QWEN3NEXT_AVAILABLE = False
 
 try:
-    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextConfig
     from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM
+    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextConfig
 
     QWEN3_5_MOE_AVAILABLE = True
 except ImportError: