Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
292818e
moe 4 bit quant for 3d packing , downstream
ved1beta Feb 9, 2026
8f81839
Merge branch 'main' into 3d_oom_tv5
ved1beta Feb 9, 2026
d8c0592
exclude moe_params , + reviews
ved1beta Feb 10, 2026
ce1b473
Merge branch '3d_oom_tv5' of github.com:ved1beta/axolotl into 3d_oom_tv5
ved1beta Feb 10, 2026
c69201e
use targate parameters for moe
ved1beta Feb 11, 2026
38f7987
patch with moe_quant revert
ved1beta Feb 11, 2026
5a81c15
adpter exclude modules
ved1beta Feb 11, 2026
44eaef2
detected_expert_params
ved1beta Feb 11, 2026
4d46469
r".*\.parametrizations\..*"
ved1beta Feb 11, 2026
1bab4c1
comment
ved1beta Feb 12, 2026
e97b14e
config
ved1beta Feb 12, 2026
13d85b6
Update src/axolotl/loaders/adapter.py
ved1beta Feb 13, 2026
013d8f0
Merge branch 'main' into 3d_oom_tv5
ved1beta Feb 13, 2026
c3c6893
lint
ved1beta Feb 13, 2026
d16a853
fix: simplify defaults
NanoCode012 Feb 13, 2026
82dad00
true
ved1beta Feb 13, 2026
4d7da67
Merge branch '3d_oom_tv5' of github.com:ved1beta/axolotl into HEAD
ved1beta Feb 13, 2026
3da012f
used keywords exp_proj, down_proj, gate_proj
ved1beta Feb 16, 2026
41d30ff
Update examples/glm4.7/glm4.7-flash-qlora.yaml
ved1beta Feb 16, 2026
dc7caa1
Merge branch 'main' into 3d_oom_tv5
ved1beta Feb 16, 2026
6c734e9
use lora_target_parameters
ved1beta Feb 16, 2026
004ba8f
Merge branch '3d_oom_tv5' of github.com:ved1beta/axolotl into HEAD
ved1beta Feb 16, 2026
2fe6f40
rmv lora_qkv_kernel: false
ved1beta Feb 18, 2026
91b5aa7
support lora _o_proj
ved1beta Feb 18, 2026
e0b7e93
chore: lint
winglian Feb 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions examples/glm4.7/glm4.7-flash-qlora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
base_model: zai-org/GLM-4.7-Flash

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_4bit: true

datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value

val_set_size: 0.0
output_dir: ./outputs/out

adapter: qlora

sequence_len: 32768
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 32
lora_dropout: 0.0
lora_target_linear: true
# MoE expert weights (gate_up_proj, down_proj) are fused 3D tensors in this
# model and are NOT nn.Linear — target them via lora_target_parameters below.
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj

@NanoCode012 NanoCode012 Feb 16, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should explicitly set the lora target parameters so it's clear that it's being trained on here.

It also doesn't seem possible to not target those layers; they seem to always be on.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done !

lora_mlp_kernel: true
lora_qkv_kernel: false
lora_o_kernel: true

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0001
max_grad_norm: 1.0

bf16: auto

resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 4
save_total_limit: 4

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false

deepspeed:
54 changes: 53 additions & 1 deletion src/axolotl/loaders/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,65 @@ def find_all_linear_names(model):
return list(lora_module_names)


def find_moe_expert_param_names(model: PreTrainedModel) -> list[str]:
    """Detect fused 3D+ MoE expert parameters for PEFT ``target_parameters``.

    In transformers v5, MoE models store expert weights as fused 3D nn.Parameter
    tensors (num_experts, dim1, dim2) instead of individual nn.Linear modules.
    PEFT's target_modules can't target these, but target_parameters can via the
    ParamWrapper class which applies LoRA directly to nn.Parameter tensors.

    Args:
        model: Model whose ``named_parameters()`` are scanned.

    Returns:
        A sorted, deduplicated list of parameter path suffixes (e.g.,
        ["mlp.experts.down_proj", "mlp.experts.gate_up_proj"]) suitable for
        PEFT's LoraConfig target_parameters. When a matching parameter has no
        numeric layer segment in its path (or nothing follows that segment),
        its full dotted name is reported instead of being dropped.
    """
    # Only parameters whose path mentions one of the commonly used expert
    # names are considered; this keeps unrelated 3D+ tensors (e.g. conv
    # weights) out of the LoRA targets.
    expert_keywords = ("experts", "gate_up_proj", "down_proj")

    seen_suffixes: set[str] = set()
    for name, param in model.named_parameters():
        if param.ndim < 3 or not any(kw in name for kw in expert_keywords):
            continue

        parts = name.split(".")
        # Find the layer index (first numeric segment) and extract the
        # repeating suffix after it.
        # e.g. "model.layers.0.mlp.experts.gate_up_proj" -> "mlp.experts.gate_up_proj"
        layer_pos = next(
            (idx for idx, part in enumerate(parts) if part.isdigit()), None
        )
        suffix = "" if layer_pos is None else ".".join(parts[layer_pos + 1 :])

        # Without a usable suffix, fall back to the full name so the parameter
        # is still reported and an empty-string target is never emitted.
        seen_suffixes.add(suffix or name)
    return sorted(seen_suffixes)


def load_lora(
model: PreTrainedModel,
cfg: DictDefault,
inference: bool = False,
config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
lora_target_modules = cfg.lora_target_modules or []
lora_target_parameters = cfg.lora_target_parameters or []

# Auto-detect MoE expert params for PEFT target_parameters (v5 3D nn.Parameter).
lora_target_parameters = cfg.lora_target_parameters
if lora_target_parameters is None:
detected_expert_params = getattr(
model, "_moe_expert_param_names", None
) or find_moe_expert_param_names(model)
if detected_expert_params:
LOG.info(
"Auto-detected MoE expert parameters for LoRA target_parameters: %s",
detected_expert_params,
)
lora_target_parameters = detected_expert_params
else:
lora_target_parameters = []
elif isinstance(lora_target_parameters, str):
lora_target_parameters = [lora_target_parameters]

# Exclude ParametrizationList submodules created by replace_parameter_4bit
# from target_modules matching (regex needed — "parametrizations" is mid-path).
exclude_modules = None
if getattr(model, "_moe_experts_quantized", False):
exclude_modules = r".*\.parametrizations\..*"

if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
Expand Down Expand Up @@ -119,6 +170,7 @@ def load_lora(
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
target_parameters=lora_target_parameters,
exclude_modules=exclude_modules,
layers_to_transform=cfg.peft_layers_to_transform,
layers_pattern=cfg.peft_layers_pattern,
lora_dropout=cfg.lora_dropout,
Expand Down
18 changes: 18 additions & 0 deletions src/axolotl/loaders/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,20 @@ def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | Non
PLUGIN_MANAGER.pre_model_load(self.cfg)
self.patch_manager.apply_post_plugin_pre_model_load_patches()
skip_move_to_device = self._build_model()

# Quantize 3D MoE expert nn.Parameter tensors that BnB skips.
# Detect names before quantization (replace_parameter_4bit changes them).
self.model._moe_experts_quantized = False
self.model._moe_expert_param_names = []
if self.cfg.adapter in ("qlora", "lora") and (
self.cfg.load_in_4bit or self.cfg.load_in_8bit
):
from axolotl.loaders.adapter import find_moe_expert_param_names
from axolotl.monkeypatch.moe_quant import quantize_moe_expert_params

self.model._moe_expert_param_names = find_moe_expert_param_names(self.model)
self.model._moe_experts_quantized = quantize_moe_expert_params(self.model)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is called for LoRA as well, even though the inner function specifically calls replace_parameter_4bit. Do we need a dedicated replace_parameter_8bit for LoRA?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

narrowed the guard to load_in_4bit only now


PLUGIN_MANAGER.post_model_build(self.cfg, self.model)

# Post-build model configuration
Expand Down Expand Up @@ -860,6 +874,10 @@ def _prepare_model_for_quantization(self):
# Make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True

if getattr(self.model, "_moe_experts_quantized", False):
# Parametrized expert tensors dequantize on access — would OOM.
skip_prepare_model_for_kbit_training = True

if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]
Expand Down
75 changes: 45 additions & 30 deletions src/axolotl/monkeypatch/lora_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,18 @@ def patch_self_attn_lora(cfg: DictDefault):
attention_cls._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward)

assert any(qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES), (
"Original QKV code not found"
)
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"

for qkv_orig, qkv_patched in QKV_PATCHES:
if qkv_orig in self_attn_forward:
self_attn_forward = self_attn_forward.replace(
qkv_orig,
qkv_patched,
)
break
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
if cfg.lora_qkv_kernel:
assert any(
qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES
), "Original QKV code not found"
for qkv_orig, qkv_patched in QKV_PATCHES:
if qkv_orig in self_attn_forward:
self_attn_forward = self_attn_forward.replace(qkv_orig, qkv_patched)
break

if cfg.lora_o_kernel:
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",
"def axolotl_attn_forward(",
Expand Down Expand Up @@ -249,6 +248,12 @@ def find_self_attn_in_layer(
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
):
yield layer.self_attn
# MLA attention (DeepSeek-V2/V3, GLM-4.7): no q/k/v_proj, but o_proj is standard
elif all(
hasattr(layer.self_attn, proj)
for proj in ["kv_a_proj_with_mqa", "kv_b_proj", "o_proj"]
):
yield layer.self_attn


def find_mlp_in_layer(
Expand Down Expand Up @@ -388,25 +393,35 @@ def apply_lora_kernel_patches(
self_attn.apply_o = types.MethodType(original_apply_o, self_attn)

if cfg.lora_qkv_kernel:
# Query, key, value patching
layer_modules = [
getattr(self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)

if can_patch_qkv:
# Add optimized implementation
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
else:
# Query, key, value patching — only for standard QKV models, not MLA
if not all(
hasattr(self_attn, p) for p in ["q_proj", "k_proj", "v_proj"]
):
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
"Skipping QKV kernel patch — model uses MLA attention "
"(no q_proj/k_proj/v_proj). Disable lora_qkv_kernel to silence."
)
else:
layer_modules = [
getattr(self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)

if can_patch_qkv:
# Add optimized implementation
self_attn.apply_qkv = types.MethodType(
apply_lora_qkv, self_attn
)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
Expand Down
86 changes: 86 additions & 0 deletions src/axolotl/monkeypatch/moe_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""
Post-load quantization for MoE expert weights stored as 3D nn.Parameter tensors.

In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter
tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets
nn.Linear, so these expert weights are skipped during model loading, causing OOM.

This module provides a post-load fixup that quantizes those skipped parameters using
bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0).
PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized
params via stacked parametrizations.
"""

import bitsandbytes as bnb
import torch

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def find_unquantized_expert_params(model):
    """Find 3D+ nn.Parameter tensors that BnB quantization skipped.

    Walks every module of ``model`` and inspects the parameters it owns
    directly. Modules that are already ``bnb.nn.Linear4bit`` or
    ``bnb.nn.Linear8bitLt`` are ignored. A parameter is selected when it has
    at least three dimensions and its fully qualified name
    (``<module path>.<param name>``) contains one of the expert keywords.

    The qualified name is used rather than the local parameter name so that a
    tensor living under an ``experts`` module is found even when its own name
    carries no keyword.

    Args:
        model: Module tree to scan (anything exposing ``named_modules()``).

    Returns:
        List of (module, param_name) tuples to quantize, where ``param_name``
        is the parameter's local name on ``module``.
    """
    expert_keywords = ("experts", "gate_up_proj", "down_proj")
    quantized_linear_types = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)

    params_to_quantize = []
    for module_name, module in model.named_modules():
        if isinstance(module, quantized_linear_types):
            continue
        for param_name, param in module.named_parameters(recurse=False):
            if param.ndim < 3:
                continue
            # The root module has an empty name; avoid a leading dot there.
            qualified_name = (
                f"{module_name}.{param_name}" if module_name else param_name
            )
            if any(kw in qualified_name for kw in expert_keywords):
                params_to_quantize.append((module, param_name))
    return params_to_quantize


def quantize_moe_expert_params(model, quant_type=None, compress_statistics=None):
    """Quantize 3D nn.Parameter expert weights that BnB skips during model loading.

    Reads quant_type and compress_statistics from the model's quantization_config
    when not explicitly provided, so that the same settings used for nn.Linear
    quantization are applied to the MoE expert parameters. The config may be
    either an object with attributes or a plain dict; both are honoured.

    Args:
        model: Model whose skipped expert parameters should be quantized; its
            ``config.quantization_config`` is consulted for missing settings.
        quant_type: 4-bit quantization type passed to ``replace_parameter_4bit``.
            Falls back to the config's ``bnb_4bit_quant_type``, then ``"nf4"``.
        compress_statistics: Passed to ``replace_parameter_4bit``. Falls back to
            the config's ``bnb_4bit_use_double_quant``, then ``True``.

    Returns:
        True if at least one parameter was quantized, False if none were found.
    """
    # Imported lazily: the parametrize API only exists in newer bitsandbytes.
    from bitsandbytes.nn.parametrize import replace_parameter_4bit

    params_to_quantize = find_unquantized_expert_params(model)
    if not params_to_quantize:
        return False

    # Derive settings from model's BnB config if not explicitly provided
    if quant_type is None or compress_statistics is None:
        bnb_config = getattr(model.config, "quantization_config", None)
        if bnb_config is not None:

            def _config_value(key, default):
                # quantization_config may be a config object or a raw dict.
                if isinstance(bnb_config, dict):
                    return bnb_config.get(key, default)
                return getattr(bnb_config, key, default)

            if quant_type is None:
                quant_type = _config_value("bnb_4bit_quant_type", "nf4")
            if compress_statistics is None:
                compress_statistics = _config_value(
                    "bnb_4bit_use_double_quant", True
                )

    # Final defaults (also covers a config that stores an explicit None)
    if quant_type is None:
        quant_type = "nf4"
    if compress_statistics is None:
        compress_statistics = True

    for module, param_name in params_to_quantize:
        replace_parameter_4bit(
            module,
            param_name,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )

    # Release the cached blocks of the replaced full-precision tensors.
    torch.cuda.empty_cache()
    LOG.info(
        "Quantized %d MoE expert parameters to 4-bit (quant_type=%s, compress_statistics=%s)",
        len(params_to_quantize),
        quant_type,
        compress_statistics,
    )
    return True