diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 2622cea34..8b206486a 100755 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -430,7 +430,7 @@ def apply_liger_kernel_to_llava( f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n" f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}" ) - text_kwargs["model"] = model.language_model + text_kwargs["model"] = model.model.language_model text_liger_fn(**text_kwargs) elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN: logger.warning(f"{text_model_name} is not supported by Liger kernel.") @@ -445,7 +445,7 @@ def apply_liger_kernel_to_llava( f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n" f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}" ) - vision_kwargs["model"] = model.vision_tower + vision_kwargs["model"] = model.model.vision_tower vision_liger_fn(**vision_kwargs) elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN: logger.warning(f"{vision_model_name} is not supported by Liger kernel.") @@ -615,8 +615,8 @@ def apply_liger_kernel_to_mllama( # instance variables that reference already-instantiated modules if isinstance(model, MllamaForConditionalGeneration): - language_model: MllamaForCausalLM = model.language_model - vision_model: MllamaVisionModel = model.vision_model + language_model: MllamaForCausalLM = model.model.language_model + vision_model: MllamaVisionModel = model.model.vision_model if isinstance(language_model, MllamaForCausalLM): text_model: MllamaTextModel = language_model.model else: @@ -1118,8 +1118,8 @@ def apply_liger_kernel_to_gemma3( # instance variables that reference already-instantiated modules if isinstance(model, Gemma3ForConditionalGeneration): - if isinstance(model.vision_tower, SiglipVisionModel): - vision_tower = model.vision_tower + if isinstance(model.model.vision_tower, SiglipVisionModel): + vision_tower = model.model.vision_tower _patch_layer_norm_module(vision_tower.vision_model.post_layernorm) @@ -1132,7 +1132,7 @@ def apply_liger_kernel_to_gemma3( raise TypeError("The vision tower must be SiglipVisionModel") if rms_norm: - _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm) + _patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm) apply_liger_kernel_to_gemma3_text( rope=rope, @@ -1140,7 +1140,7 @@ def apply_liger_kernel_to_gemma3( fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu, - model=model.language_model, + model=model.model.language_model, ) else: @@ -1228,7 +1228,7 @@ def apply_liger_kernel_to_paligemma( if not isinstance(model, PaliGemmaForConditionalGeneration): raise TypeError("model have to be of type PaliGemmaForConditionalGeneration") - vision_tower: SiglipVisionModel = model.vision_tower + vision_tower: SiglipVisionModel = model.model.vision_tower _patch_layer_norm_module(vision_tower.vision_model.post_layernorm) @@ -1238,7 +1238,7 @@ def apply_liger_kernel_to_paligemma( _patch_layer_norm_module(layer.layer_norm1) _patch_layer_norm_module(layer.layer_norm2) - language_model = model.language_model + language_model = model.model.language_model if isinstance(language_model, (GemmaForCausalLM, GemmaModel)): apply_liger_kernel_to_gemma( @@ -1593,11 +1593,10 @@ def apply_liger_kernel_to_qwen2_vl( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - - if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)): - # Note: language_model and visual properties can be accessed throught conditional class for BC. - # Not sure if it is subject to changes in the future. - # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698 + if isinstance(model, Qwen2VLForConditionalGeneration): + text_model: Qwen2VLTextModel = model.model.language_model + vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual + elif isinstance(model, Qwen2VLModel): text_model: Qwen2VLTextModel = model.language_model vision_model: Qwen2VisionTransformerPretrainedModel = model.visual elif isinstance(model, Qwen2VLTextModel): @@ -1684,11 +1683,10 @@ def apply_liger_kernel_to_qwen2_5_vl( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - - if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)): - # Note: language_model and visual properties can be accessed throught conditional class for BC. - # Not sure if it is subject to changes in the future. - # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823 + if isinstance(model, Qwen2_5_VLForConditionalGeneration): + text_model: Qwen2_5_VLTextModel = model.model.language_model + vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual + elif isinstance(model, Qwen2_5_VLModel): text_model: Qwen2_5_VLTextModel = model.language_model vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual elif isinstance(model, Qwen2_5_VLTextModel): @@ -1702,7 +1700,7 @@ def apply_liger_kernel_to_qwen2_5_vl( if vision_model is not None: # Patch Qwen2_5_VisionTransformerPretrainedModel - for vision_block in model.visual.blocks: + for vision_block in vision_model.blocks: if rms_norm: _patch_rms_norm_module(vision_block.norm1) _patch_rms_norm_module(vision_block.norm2) @@ -1771,7 +1769,9 @@ def apply_liger_kernel_to_qwen3_vl( modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward if model is not None and rms_norm: - if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)): + if isinstance(model, Qwen3VLForConditionalGeneration): + text_model: Qwen3VLTextModel = model.model.language_model + elif isinstance(model, Qwen3VLModel): text_model: Qwen3VLTextModel = model.language_model elif isinstance(model, Qwen3VLTextModel): text_model = model @@ -1846,7 +1846,9 @@ def apply_liger_kernel_to_qwen3_vl_moe( modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward if model is not None and rms_norm: - if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)): + if isinstance(model, Qwen3VLMoeForConditionalGeneration): + text_model: Qwen3VLMoeTextModel = model.model.language_model + elif isinstance(model, Qwen3VLMoeModel): text_model: Qwen3VLMoeTextModel = model.language_model elif isinstance(model, Qwen3VLMoeTextModel): text_model = model @@ -2191,10 +2193,10 @@ def apply_liger_kernel_to_glm4v( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)): - # Note: language_model and visual properties can be accessed throught conditional class for BC. - # Not sure if it is subject to changes in the future. - # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305 + if isinstance(model, Glm4vForConditionalGeneration): + text_model: Glm4vTextModel = model.model.language_model + vision_model: Glm4vVisionModel = model.model.visual + elif isinstance(model, Glm4vModel): text_model: Glm4vTextModel = model.language_model vision_model: Glm4vVisionModel = model.visual elif isinstance(model, Glm4vTextModel): @@ -2281,10 +2283,11 @@ def apply_liger_kernel_to_glm4v_moe( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)): - # Note: language_model and visual properties can be accessed throught conditional class for BC. - # Not sure if it is subject to changes in the future. - # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337 + if isinstance(model, Glm4vMoeForConditionalGeneration): + text_model: Glm4vMoeTextModel = model.model.language_model + vision_model: Glm4vMoeVisionModel = model.model.visual + Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE + elif isinstance(model, Glm4vMoeModel): text_model: Glm4vMoeTextModel = model.language_model vision_model: Glm4vMoeVisionModel = model.visual Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE @@ -2387,8 +2390,10 @@ def apply_liger_kernel_to_internvl( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)): - # NOTE: language_model and visual properties can be accessed throught conditional class. + if isinstance(model, InternVLForConditionalGeneration): + text_model = model.model.language_model + vision_model: InternVLVisionModel = model.model.vision_tower + elif isinstance(model, InternVLModel): text_model = model.language_model vision_model: InternVLVisionModel = model.vision_tower else: diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index ce06f767a..f5b5e6355 100755 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -512,10 +512,10 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_for_conditional_generation( # Check that model instance variables are not yet patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(qwen3_vl_lce_forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) != inspect.getsource( LigerRMSNorm.forward ) - for decoder_layer in dummy_model_instance.language_model.layers: + for decoder_layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(decoder_layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) != inspect.getsource( LigerRMSNorm.forward @@ -532,10 +532,10 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_for_conditional_generation( # Check that the model's instance variables were correctly patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(qwen3_vl_lce_forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) == inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) == inspect.getsource( LigerRMSNorm.forward ) - for decoder_layer in dummy_model_instance.language_model.layers: + for decoder_layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource( LigerRMSNorm.forward @@ -791,10 +791,10 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_moe_for_conditional_generat # Check that model instance variables are not yet patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(qwen3_vl_moe_lce_forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) != inspect.getsource( LigerRMSNorm.forward ) - for decoder_layer in dummy_model_instance.language_model.layers: + for decoder_layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(decoder_layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) != inspect.getsource( LigerRMSNorm.forward @@ -811,10 +811,10 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_moe_for_conditional_generat # Check that the model's instance variables were correctly patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(qwen3_vl_moe_lce_forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) == inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) == inspect.getsource( LigerRMSNorm.forward ) - for decoder_layer in dummy_model_instance.language_model.layers: + for decoder_layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource( LigerRMSNorm.forward @@ -1130,10 +1130,10 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation(): # Check that model instance variables are not yet patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(mllama_lce_forward) - if isinstance(dummy_model_instance.language_model, MllamaTextModel): - language_model = dummy_model_instance.language_model + if isinstance(dummy_model_instance.model.language_model, MllamaTextModel): + language_model = dummy_model_instance.model.language_model else: - language_model = dummy_model_instance.language_model.model + language_model = dummy_model_instance.model.language_model.model assert inspect.getsource(language_model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in language_model.layers: @@ -1141,18 +1141,18 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation(): assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(dummy_model_instance.vision_model.layernorm_pre.forward) != inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.vision_model.layernorm_pre.forward) != inspect.getsource( LigerLayerNorm.forward ) - assert inspect.getsource(dummy_model_instance.vision_model.layernorm_post.forward) != inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.vision_model.layernorm_post.forward) != inspect.getsource( LigerLayerNorm.forward ) - for layer in dummy_model_instance.vision_model.transformer.layers: + for layer in dummy_model_instance.model.vision_model.transformer.layers: assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerLayerNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource( LigerLayerNorm.forward ) - for layer in dummy_model_instance.vision_model.global_transformer.layers: + for layer in dummy_model_instance.model.vision_model.global_transformer.layers: assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerLayerNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource( LigerLayerNorm.forward @@ -1169,18 +1169,18 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation(): assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(dummy_model_instance.vision_model.layernorm_pre.forward) == inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.vision_model.layernorm_pre.forward) == inspect.getsource( LigerLayerNorm.forward ) - assert inspect.getsource(dummy_model_instance.vision_model.layernorm_post.forward) == inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.vision_model.layernorm_post.forward) == inspect.getsource( LigerLayerNorm.forward ) - for layer in dummy_model_instance.vision_model.transformer.layers: + for layer in dummy_model_instance.model.vision_model.transformer.layers: assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerLayerNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource( LigerLayerNorm.forward ) - for layer in dummy_model_instance.vision_model.global_transformer.layers: + for layer in dummy_model_instance.model.vision_model.global_transformer.layers: assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerLayerNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource( LigerLayerNorm.forward @@ -1588,10 +1588,10 @@ def test_apply_liger_kernel_to_instance_for_paligemma(): # Check that model instance variables are not yet patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(paligemma_lce_forward) assert inspect.getsource( - dummy_model_instance.vision_tower.vision_model.post_layernorm.forward + dummy_model_instance.model.vision_tower.vision_model.post_layernorm.forward ) != inspect.getsource(LigerLayerNorm.forward) - for layer in dummy_model_instance.vision_tower.vision_model.encoder.layers: + for layer in dummy_model_instance.model.vision_tower.vision_model.encoder.layers: assert inspect.getsource(layer.layer_norm1.forward) != inspect.getsource(LigerLayerNorm.forward) assert inspect.getsource(layer.layer_norm2.forward) != inspect.getsource(LigerLayerNorm.forward) @@ -1601,10 +1601,10 @@ def test_apply_liger_kernel_to_instance_for_paligemma(): # Check that the model's instance variables were correctly patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(paligemma_lce_forward) assert inspect.getsource( - dummy_model_instance.vision_tower.vision_model.post_layernorm.forward + dummy_model_instance.model.vision_tower.vision_model.post_layernorm.forward ) == inspect.getsource(LigerLayerNorm.forward) - for layer in dummy_model_instance.vision_tower.vision_model.encoder.layers: + for layer in dummy_model_instance.model.vision_tower.vision_model.encoder.layers: assert inspect.getsource(layer.layer_norm1.forward) == inspect.getsource(LigerLayerNorm.forward) assert inspect.getsource(layer.layer_norm2.forward) == inspect.getsource(LigerLayerNorm.forward) @@ -1697,22 +1697,22 @@ def test_apply_liger_kernel_to_instance_for_gemma3_conditional_generation(): # Check that model instance variables are not yet patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(gemma3_multimodal_forward) assert inspect.getsource( - dummy_model_instance.vision_tower.vision_model.post_layernorm.forward + dummy_model_instance.model.vision_tower.vision_model.post_layernorm.forward ) != inspect.getsource(LigerLayerNorm.forward) - for layer in dummy_model_instance.vision_tower.vision_model.encoder.layers: + for layer in dummy_model_instance.model.vision_tower.vision_model.encoder.layers: assert inspect.getsource(layer.layer_norm1.forward) != inspect.getsource(LigerLayerNorm.forward) assert inspect.getsource(layer.layer_norm2.forward) != inspect.getsource(LigerLayerNorm.forward) assert inspect.getsource( - dummy_model_instance.multi_modal_projector.mm_soft_emb_norm.forward + dummy_model_instance.model.multi_modal_projector.mm_soft_emb_norm.forward ) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) != inspect.getsource( LigerRMSNorm.forward ) - for layer in dummy_model_instance.language_model.layers: + for layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerGEGLUMLP.forward) assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) @@ -1729,21 +1729,21 @@ def test_apply_liger_kernel_to_instance_for_gemma3_conditional_generation(): # Check that the model's instance variables were correctly patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(gemma3_multimodal_forward) assert inspect.getsource( - dummy_model_instance.vision_tower.vision_model.post_layernorm.forward + dummy_model_instance.model.vision_tower.vision_model.post_layernorm.forward ) == inspect.getsource(LigerLayerNorm.forward) - for layer in dummy_model_instance.vision_tower.vision_model.encoder.layers: + for layer in dummy_model_instance.model.vision_tower.vision_model.encoder.layers: assert inspect.getsource(layer.layer_norm1.forward) == inspect.getsource(LigerLayerNorm.forward) assert inspect.getsource(layer.layer_norm2.forward) == inspect.getsource(LigerLayerNorm.forward) assert inspect.getsource( - dummy_model_instance.multi_modal_projector.mm_soft_emb_norm.forward + dummy_model_instance.model.multi_modal_projector.mm_soft_emb_norm.forward ) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) == inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) == inspect.getsource( LigerRMSNorm.forward ) - for layer in dummy_model_instance.language_model.layers: + for layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerGEGLUMLP.forward) assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) @@ -1920,14 +1920,14 @@ def test_apply_liger_kernel_to_instance_for_qwen2_vl_for_conditional_generation( # Check that model instance variables are not yet patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(qwen2_vl_lce_forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) != inspect.getsource( LigerRMSNorm.forward ) - for layer in dummy_model_instance.language_model.layers: + for layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - for vision_block in dummy_model_instance.visual.blocks: + for vision_block in dummy_model_instance.model.visual.blocks: assert inspect.getsource(vision_block.norm1.forward) != inspect.getsource(LigerLayerNorm.forward) assert inspect.getsource(vision_block.norm2.forward) != inspect.getsource(LigerLayerNorm.forward) @@ -1936,14 +1936,14 @@ def test_apply_liger_kernel_to_instance_for_qwen2_vl_for_conditional_generation( # Check that the model's instance variables were correctly patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(qwen2_vl_lce_forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) == inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) == inspect.getsource( LigerRMSNorm.forward ) - for layer in dummy_model_instance.language_model.layers: + for layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) - for vision_block in dummy_model_instance.visual.blocks: + for vision_block in dummy_model_instance.model.visual.blocks: assert inspect.getsource(vision_block.norm1.forward) == inspect.getsource(LigerLayerNorm.forward) assert inspect.getsource(vision_block.norm2.forward) == inspect.getsource(LigerLayerNorm.forward) @@ -2178,14 +2178,14 @@ def test_apply_liger_kernel_to_instance_for_qwen2_5_vl_for_conditional_generatio # Check that model instance variables are not yet patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(qwen2_5_vl_lce_forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) != inspect.getsource( LigerRMSNorm.forward ) - for layer in dummy_model_instance.language_model.layers: + for layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - for vision_block in dummy_model_instance.visual.blocks: + for vision_block in dummy_model_instance.model.visual.blocks: assert inspect.getsource(vision_block.norm1.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(vision_block.norm2.forward) != inspect.getsource(LigerRMSNorm.forward) @@ -2194,14 +2194,14 @@ def test_apply_liger_kernel_to_instance_for_qwen2_5_vl_for_conditional_generatio # Check that the model's instance variables were correctly patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(qwen2_5_vl_lce_forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) == inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) == inspect.getsource( LigerRMSNorm.forward ) - for layer in dummy_model_instance.language_model.layers: + for layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) - for vision_block in dummy_model_instance.visual.blocks: + for vision_block in dummy_model_instance.model.visual.blocks: assert inspect.getsource(vision_block.norm1.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(vision_block.norm2.forward) == inspect.getsource(LigerRMSNorm.forward) @@ -2300,10 +2300,10 @@ def test_apply_liger_kernel_to_instance_for_internvl(): assert isinstance(dummy_model_instance, InternVLForConditionalGeneration) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) != inspect.getsource( LigerRMSNorm.forward ) - for layer in dummy_model_instance.language_model.layers: + for layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) @@ -2312,10 +2312,10 @@ def test_apply_liger_kernel_to_instance_for_internvl(): _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) == inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) == inspect.getsource( LigerRMSNorm.forward ) - for layer in dummy_model_instance.language_model.layers: + for layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) @@ -2609,16 +2609,16 @@ def test_apply_liger_kernel_to_instance_for_glm4v(): # Check that model instance variables are not yet patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(glm4v_lce_forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) != inspect.getsource( LigerRMSNorm.forward ) - for layer in dummy_model_instance.language_model.layers: + for layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerPhi3SwiGLUMLP.forward) assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_self_attn_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_mlp_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - for vision_block in dummy_model_instance.visual.blocks: + for vision_block in dummy_model_instance.model.visual.blocks: assert inspect.getsource(vision_block.norm1.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(vision_block.norm2.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(vision_block.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) @@ -2628,16 +2628,16 @@ def test_apply_liger_kernel_to_instance_for_glm4v(): # Check that the model's instance variables were correctly patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(glm4v_lce_forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) == inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) == inspect.getsource( LigerRMSNorm.forward ) - for layer in dummy_model_instance.language_model.layers: + for layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerPhi3SwiGLUMLP.forward) assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_self_attn_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_mlp_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) - for vision_block in dummy_model_instance.visual.blocks: + for vision_block in dummy_model_instance.model.visual.blocks: assert inspect.getsource(vision_block.norm1.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(vision_block.norm2.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(vision_block.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) @@ -2684,17 +2684,17 @@ def test_apply_liger_kernel_to_instance_for_glm4v_moe(): # Check that model instance variables are not yet patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(glm4v_moe_lce_forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) != inspect.getsource( LigerRMSNormForGlm4.forward ) - assert inspect.getsource(dummy_model_instance.visual.post_conv_layernorm.forward) != inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.visual.post_conv_layernorm.forward) != inspect.getsource( LigerRMSNormForGlm4.forward ) - assert inspect.getsource(dummy_model_instance.visual.post_layernorm.forward) != inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.visual.post_layernorm.forward) != inspect.getsource( LigerRMSNormForGlm4.forward ) - for decoder_layer in dummy_model_instance.language_model.layers: + for decoder_layer in dummy_model_instance.model.language_model.layers: assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) != inspect.getsource( LigerRMSNormForGlm4.forward @@ -2709,7 +2709,7 @@ def test_apply_liger_kernel_to_instance_for_glm4v_moe(): assert inspect.getsource(decoder_layer.mlp.shared_experts.forward) != inspect.getsource( LigerSwiGLUMLP.forward ) - for vision_block in dummy_model_instance.visual.blocks: + for vision_block in dummy_model_instance.model.visual.blocks: assert inspect.getsource(vision_block.norm1.forward) != inspect.getsource(LigerRMSNormForGlm4.forward) assert inspect.getsource(vision_block.norm2.forward) != inspect.getsource(LigerRMSNormForGlm4.forward) assert inspect.getsource(vision_block.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) @@ -2719,17 +2719,17 @@ def test_apply_liger_kernel_to_instance_for_glm4v_moe(): # Check that model instance variables are not yet patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(glm4v_moe_lce_forward) - assert inspect.getsource(dummy_model_instance.language_model.norm.forward) == inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) == inspect.getsource( LigerRMSNormForGlm4.forward ) - assert inspect.getsource(dummy_model_instance.visual.post_conv_layernorm.forward) == inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.visual.post_conv_layernorm.forward) == inspect.getsource( LigerRMSNormForGlm4.forward ) - assert inspect.getsource(dummy_model_instance.visual.post_layernorm.forward) == inspect.getsource( + assert inspect.getsource(dummy_model_instance.model.visual.post_layernorm.forward) == inspect.getsource( LigerRMSNormForGlm4.forward ) - for decoder_layer in dummy_model_instance.language_model.layers: + for decoder_layer in dummy_model_instance.model.language_model.layers: if decoder_layer.mlp is not None: assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource( @@ -2745,7 +2745,7 @@ def test_apply_liger_kernel_to_instance_for_glm4v_moe(): assert inspect.getsource(decoder_layer.mlp.shared_experts.forward) == inspect.getsource( LigerSwiGLUMLP.forward ) - for vision_block in dummy_model_instance.visual.blocks: + for vision_block in dummy_model_instance.model.visual.blocks: assert inspect.getsource(vision_block.norm1.forward) == inspect.getsource(LigerRMSNormForGlm4.forward) assert inspect.getsource(vision_block.norm2.forward) == inspect.getsource(LigerRMSNormForGlm4.forward) assert inspect.getsource(vision_block.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward)