Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/axolotl/loaders/patch_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,17 @@ def _apply_mistral_cross_entropy_patch(self):
def _apply_self_attention_lora_patch(self):
"""Apply self-attention LoRA patches if configured."""
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
# Only patch if conditions are met
can_patch = (
self.cfg.lora_dropout == 0
if hasattr(self.cfg, "lora_dropout")
else True
) # default to True if lora_dropout is not set

if not can_patch:
LOG.warning("Cannot patch self-attention - requires no dropout")
return

from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora

patch_self_attn_lora(self.cfg)
Comment thread
NanoCode012 marked this conversation as resolved.
Expand Down
35 changes: 24 additions & 11 deletions src/axolotl/monkeypatch/lora_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,29 @@ def find_mlp_in_layer(
)


def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
    """
    Locate the layer stack of a PEFT-wrapped model.

    The wrapped model (``model.model``) is probed for a ``language_model``
    attribute first (multimodal models) and then for a ``model`` attribute
    (text-only models); the ``layers`` of the first match are returned.

    Args:
        model: A PEFT model.

    Returns:
        The ``layers`` attribute of the located inner model.

    Raises:
        NotImplementedError: If the wrapped model exposes neither attribute.
    """
    wrapped = model.model

    # Order matters: the multimodal attribute takes precedence.
    for inner_name in ("language_model", "model"):
        if hasattr(wrapped, inner_name):
            return getattr(wrapped, inner_name).layers

    raise NotImplementedError(
        f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
    )


def apply_lora_kernel_patches(
model: PeftModelForCausalLM, cfg: DictDefault
) -> PeftModelForCausalLM:
Expand Down Expand Up @@ -340,17 +363,7 @@ def apply_lora_kernel_patches(
if activation not in SUPPORTED_ACTIVATIONS:
raise NotImplementedError(f"Activation {activation} is not supported")

layers = []
# check for multimodal models first
pretrained_model = model.model
if hasattr(pretrained_model, "language_model"):
layers = pretrained_model.language_model.layers
elif hasattr(pretrained_model, "model"):
layers = pretrained_model.model.layers
else:
raise NotImplementedError(
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
)
layers = get_layers(model)

# Patch each layer
for layer in layers:
Expand Down
62 changes: 62 additions & 0 deletions tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.monkeypatch.lora_kernels import (
apply_lora_kernel_patches,
find_self_attn_in_layer,
get_attention_cls_from_config,
get_layers,
patch_self_attn_lora,
)
from axolotl.utils.dict import DictDefault
Expand Down Expand Up @@ -501,3 +503,63 @@ def test_kernel_training_integration_auto_enable(temp_dir):
break

assert found_patched_attn


def test_kernel_training_integration_dropout_non_zero():
    """Check that nothing gets patched when ``lora_dropout`` is non-zero."""

    from axolotl.cli.utils import load_model_and_tokenizer

    # NOTE(review): this config sets none of the ``lora_*_kernel`` flags;
    # confirm they get enabled during loading, otherwise the dropout guard
    # is not what keeps the patches from being applied here.
    model_name = "HuggingFaceTB/SmolLM2-135M"
    dataset_entry = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    cfg = DictDefault(
        {
            "base_model": model_name,
            "tokenizer_config": model_name,
            "learning_rate": 0.000001,
            "datasets": [dataset_entry],
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "adapter": "lora",
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.1,
            "lora_target_linear": True,
            "sequence_len": 1024,
        }
    )

    # Snapshot the attention class' forward before any patching can happen
    attention_cls = get_attention_cls_from_config(cfg)
    forward_before = attention_cls.forward

    model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)

    # The loader is built only to reach its patch manager; the model
    # itself was already loaded above.
    loader = ModelLoader(cfg, tokenizer)

    # Step 1: the self-attention patch must leave ``forward`` alone
    loader.patch_manager._apply_self_attention_lora_patch()  # pylint: disable=protected-access
    assert attention_cls.forward == forward_before

    # Step 2: the LoRA kernel patch must not attach its helpers
    loader.patch_manager._apply_lora_kernel_patch(  # pylint: disable=protected-access
        model
    )

    attn_modules = (
        attn
        for layer in get_layers(model)
        for attn in find_self_attn_in_layer(layer)
    )
    for attn in attn_modules:
        assert not hasattr(attn, "apply_qkv")
        assert not hasattr(attn, "apply_o")