Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions src/liger_kernel/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def apply_liger_kernel_to_llava(
f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
)
text_kwargs["model"] = model.language_model
text_kwargs["model"] = model.model.language_model
Comment thread
Tcc0403 marked this conversation as resolved.
text_liger_fn(**text_kwargs)
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{text_model_name} is not supported by Liger kernel.")
Expand All @@ -445,7 +445,7 @@ def apply_liger_kernel_to_llava(
f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
)
vision_kwargs["model"] = model.vision_tower
vision_kwargs["model"] = model.model.vision_tower
Comment thread
Tcc0403 marked this conversation as resolved.
vision_liger_fn(**vision_kwargs)
elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
Expand Down Expand Up @@ -615,8 +615,8 @@ def apply_liger_kernel_to_mllama(
# instance variables that reference already-instantiated modules

if isinstance(model, MllamaForConditionalGeneration):
language_model: MllamaForCausalLM = model.language_model
vision_model: MllamaVisionModel = model.vision_model
language_model: MllamaForCausalLM = model.model.language_model
vision_model: MllamaVisionModel = model.model.vision_model
if isinstance(language_model, MllamaForCausalLM):
text_model: MllamaTextModel = language_model.model
else:
Expand Down Expand Up @@ -1118,8 +1118,8 @@ def apply_liger_kernel_to_gemma3(
# instance variables that reference already-instantiated modules

if isinstance(model, Gemma3ForConditionalGeneration):
if isinstance(model.vision_tower, SiglipVisionModel):
vision_tower = model.vision_tower
if isinstance(model.model.vision_tower, SiglipVisionModel):
vision_tower = model.model.vision_tower

_patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

Expand All @@ -1132,15 +1132,15 @@ def apply_liger_kernel_to_gemma3(
raise TypeError("The vision tower must be SiglipVisionModel")

if rms_norm:
_patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
_patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)

apply_liger_kernel_to_gemma3_text(
rope=rope,
cross_entropy=False,
fused_linear_cross_entropy=False,
rms_norm=rms_norm,
geglu=geglu,
model=model.language_model,
model=model.model.language_model,
)

else:
Expand Down Expand Up @@ -1228,7 +1228,7 @@ def apply_liger_kernel_to_paligemma(
if not isinstance(model, PaliGemmaForConditionalGeneration):
raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")

vision_tower: SiglipVisionModel = model.vision_tower
vision_tower: SiglipVisionModel = model.model.vision_tower

_patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

Expand All @@ -1238,7 +1238,7 @@ def apply_liger_kernel_to_paligemma(
_patch_layer_norm_module(layer.layer_norm1)
_patch_layer_norm_module(layer.layer_norm2)

language_model = model.language_model
language_model = model.model.language_model

if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
apply_liger_kernel_to_gemma(
Expand Down Expand Up @@ -1525,8 +1525,8 @@ def apply_liger_kernel_to_qwen2_vl(
# Note: language_model and visual properties can be accessed through conditional class for BC.
# Not sure if it is subject to changes in the future.
# Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you help me remove this comment? Thanks!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

text_model: Qwen2VLTextModel = model.language_model
vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
text_model: Qwen2VLTextModel = model.model.language_model
vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
elif isinstance(model, Qwen2VLTextModel):
text_model: Qwen2VLTextModel = model
vision_model = None
Expand Down Expand Up @@ -1616,8 +1616,8 @@ def apply_liger_kernel_to_qwen2_5_vl(
# Note: language_model and visual properties can be accessed through conditional class for BC.
# Not sure if it is subject to changes in the future.
# Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
text_model: Qwen2_5_VLTextModel = model.language_model
vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
text_model: Qwen2_5_VLTextModel = model.model.language_model
vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
elif isinstance(model, Qwen2_5_VLTextModel):
text_model: Qwen2_5_VLTextModel = model
vision_model = None
Expand All @@ -1629,7 +1629,7 @@ def apply_liger_kernel_to_qwen2_5_vl(

if vision_model is not None:
# Patch Qwen2_5_VisionTransformerPretrainedModel
for vision_block in model.visual.blocks:
for vision_block in model.model.visual.blocks:
if rms_norm:
_patch_rms_norm_module(vision_block.norm1)
_patch_rms_norm_module(vision_block.norm2)
Expand Down Expand Up @@ -1699,7 +1699,7 @@ def apply_liger_kernel_to_qwen3_vl(

if model is not None and rms_norm:
if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
text_model: Qwen3VLTextModel = model.language_model
text_model: Qwen3VLTextModel = model.model.language_model
elif isinstance(model, Qwen3VLTextModel):
text_model = model
else:
Expand Down Expand Up @@ -1774,7 +1774,7 @@ def apply_liger_kernel_to_qwen3_vl_moe(

if model is not None and rms_norm:
if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
text_model: Qwen3VLMoeTextModel = model.language_model
text_model: Qwen3VLMoeTextModel = model.model.language_model
elif isinstance(model, Qwen3VLMoeTextModel):
text_model = model
else:
Expand Down Expand Up @@ -2122,8 +2122,8 @@ def apply_liger_kernel_to_glm4v(
# Note: language_model and visual properties can be accessed through conditional class for BC.
# Not sure if it is subject to changes in the future.
# Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
text_model: Glm4vTextModel = model.language_model
vision_model: Glm4vVisionModel = model.visual
text_model: Glm4vTextModel = model.model.language_model
vision_model: Glm4vVisionModel = model.model.visual
elif isinstance(model, Glm4vTextModel):
text_model: Glm4vTextModel = model
vision_model = None
Expand Down Expand Up @@ -2212,8 +2212,8 @@ def apply_liger_kernel_to_glm4v_moe(
# Note: language_model and visual properties can be accessed through conditional class for BC.
# Not sure if it is subject to changes in the future.
# Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
text_model: Glm4vMoeTextModel = model.language_model
vision_model: Glm4vMoeVisionModel = model.visual
text_model: Glm4vMoeTextModel = model.model.language_model
vision_model: Glm4vMoeVisionModel = model.model.visual
Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
elif isinstance(model, Glm4vMoeTextModel):
text_model: Glm4vMoeTextModel = model
Expand Down Expand Up @@ -2316,8 +2316,8 @@ def apply_liger_kernel_to_internvl(
# instance variables that reference already-instantiated modules
if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
# NOTE: language_model and visual properties can be accessed through conditional class.
text_model = model.language_model
vision_model: InternVLVisionModel = model.vision_tower
text_model = model.model.language_model
vision_model: InternVLVisionModel = model.model.vision_tower
else:
raise TypeError(
f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
Expand Down