From 292818e1ac592b706db3fa0333f34eb96ae7431a Mon Sep 17 00:00:00 2001 From: ved1beta Date: Mon, 9 Feb 2026 22:32:01 +0530 Subject: [PATCH 01/19] moe 4 bit quant for 3d packing , downstream --- src/axolotl/common/architectures.py | 43 +++++++++++++++++++++++++++ src/axolotl/loaders/adapter.py | 12 ++++++++ src/axolotl/loaders/model.py | 25 +++++++++++++++- src/axolotl/monkeypatch/moe_quant.py | 44 ++++++++++++++++++++++++++++ 4 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 src/axolotl/monkeypatch/moe_quant.py diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py index f4d6ca9287..664351b938 100644 --- a/src/axolotl/common/architectures.py +++ b/src/axolotl/common/architectures.py @@ -19,3 +19,46 @@ "lfm2_moe": "Lfm2MoeSparseMoeBlock", "afmoe": "AfmoeMoE", } + +# MoE architectures whose expert weights are 3D nn.Parameter tensors (not nn.Linear). +# BnB 4-bit quantization skips these by default, causing OOM. This mapping provides +# the parameter names needed for `target_parameters` in BitsAndBytesConfig or for +# post-load quantization via bitsandbytes.nn.parametrize. +# Verified against transformers 5.0.0 source. +MOE_EXPERT_PARAMS = { + # gate_up_proj/down_proj pattern: (num_experts, 2*intermediate, hidden) / (num_experts, hidden, intermediate) + "deepseek_v2": ["gate_up_proj", "down_proj"], + "deepseek_v3": ["gate_up_proj", "down_proj"], + "dots1": ["gate_up_proj", "down_proj"], + "ernie4_5_moe": ["gate_up_proj", "down_proj"], + "ernie4_5_vl_moe": ["gate_up_proj", "down_proj"], + "flex_olmo": ["gate_up_proj", "down_proj"], + "glm4_moe": ["gate_up_proj", "down_proj"], + "glm4_moe_lite": ["gate_up_proj", "down_proj"], + "glm4v_moe": ["gate_up_proj", "down_proj"], + "hunyuan_v1_moe": ["gate_up_proj", "down_proj"], + "jamba": ["gate_up_proj", "down_proj"], + "lfm2_moe": ["gate_up_proj", "down_proj"], + "llama4": ["gate_up_proj", "down_proj"], + "longcat_flash": ["gate_up_proj", "down_proj"], + "minimax": ["gate_up_proj", "down_proj"], + "minimax_m2": ["gate_up_proj", "down_proj"], + "mixtral": ["gate_up_proj", "down_proj"], + "olmoe": ["gate_up_proj", "down_proj"], + "phimoe": ["gate_up_proj", "down_proj"], + "qwen2_moe": ["gate_up_proj", "down_proj"], + "qwen3_moe": ["gate_up_proj", "down_proj"], + "qwen3_next": ["gate_up_proj", "down_proj"], + "qwen3_omni_moe": ["gate_up_proj", "down_proj"], + "qwen3_vl_moe": ["gate_up_proj", "down_proj"], + "solar_open": ["gate_up_proj", "down_proj"], + # gate_up_proj/down_proj + bias params + "gpt_oss": ["gate_up_proj", "down_proj"], + # weight-only pattern: (num_experts, output_size, input_size) + "jetmoe": ["weight"], + "granitemoe": ["weight"], + "granitemoehybrid": ["weight"], + "granitemoeshared": ["weight"], + # dbrx uses different param names: w1, v1, w2 (2D packed) + "dbrx": ["w1", "v1", "w2"], +} diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index 3b64b23db9..7273e17633 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -19,6 +19,7 @@ ) from transformers import PreTrainedModel +from axolotl.common.architectures import MOE_EXPERT_PARAMS from axolotl.loaders.utils import get_linear_embedding_layers from axolotl.telemetry.errors import send_errors from axolotl.utils.dict import DictDefault @@ -114,11 +115,22 @@ def load_lora( else: task_type = TaskType.CAUSAL_LM + # Exclude ParametrizationList modules created by MoE expert quantization. + # replace_parameter_4bit wraps quantized params in ParametrizationList child + # modules that PEFT doesn't support as LoRA targets. + exclude_modules = cfg.lora_exclude_modules or [] + if cfg.model_config_type in MOE_EXPERT_PARAMS: + expert_param_names = MOE_EXPERT_PARAMS[cfg.model_config_type] + for name in expert_param_names: + if name not in exclude_modules: + exclude_modules.append(name) + lora_config = LoraConfig( r=cfg.lora_r, lora_alpha=cfg.lora_alpha, target_modules=lora_target_modules, target_parameters=lora_target_parameters, + exclude_modules=exclude_modules if exclude_modules else None, layers_to_transform=cfg.peft_layers_to_transform, layers_pattern=cfg.peft_layers_pattern, lora_dropout=cfg.lora_dropout, diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 0133148ebd..f78d6b2ea6 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -37,7 +37,7 @@ is_deepspeed_zero3_enabled, ) -from axolotl.common.architectures import MOE_ARCH_BLOCK +from axolotl.common.architectures import MOE_ARCH_BLOCK, MOE_EXPERT_PARAMS from axolotl.integrations.base import PluginManager from axolotl.loaders.adapter import load_adapter, load_lora from axolotl.loaders.constants import MULTIMODAL_AUTO_MODEL_MAPPING @@ -173,6 +173,29 @@ def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | Non PLUGIN_MANAGER.pre_model_load(self.cfg) self.patch_manager.apply_post_plugin_pre_model_load_patches() skip_move_to_device = self._build_model() + + # Quantize MoE expert weights immediately after model build. + # In transformers v5, MoE expert weights are 3D nn.Parameter tensors that + # BnB quantization skips (it only handles nn.Linear). This causes OOM because + # expert weights stay in full precision. We quantize them here before any other + # operations that need GPU memory (like prepare_model_for_kbit_training). + if ( + self.cfg.adapter == "qlora" + and self.cfg.load_in_4bit + and self.cfg.model_config_type in MOE_EXPERT_PARAMS + ): + import inspect + + bnb_config_params = inspect.signature( + BitsAndBytesConfig.__init__ + ).parameters + if "target_parameters" not in bnb_config_params: + from axolotl.monkeypatch.moe_quant import ( + quantize_moe_expert_params, + ) + + quantize_moe_expert_params(self.model, self.cfg.model_config_type) + PLUGIN_MANAGER.post_model_build(self.cfg, self.model) # Post-build model configuration diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py new file mode 100644 index 0000000000..df99de5cb1 --- /dev/null +++ b/src/axolotl/monkeypatch/moe_quant.py @@ -0,0 +1,44 @@ +""" +Post-load quantization for MoE expert weights stored as 3D nn.Parameter tensors. + +In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter +tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets +nn.Linear, so these expert weights are skipped during model loading, causing OOM. + +This module provides a post-load fixup that quantizes those skipped parameters using +bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0). +""" + +import torch + +from axolotl.common.architectures import MOE_EXPERT_PARAMS +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def quantize_moe_expert_params( + model, model_config_type, quant_type="nf4", compress_statistics=True +): + """Quantize 3D nn.Parameter expert weights that BnB skips during model loading.""" + from bitsandbytes.nn.parametrize import replace_parameter_4bit + + target_params = MOE_EXPERT_PARAMS.get(model_config_type) + if not target_params: + return + + count = 0 + for module_name, module in model.named_modules(): + for param_name in target_params: + if hasattr(module, param_name): + param = getattr(module, param_name) + if isinstance(param, torch.nn.Parameter) and param.ndim >= 2: + replace_parameter_4bit( + module, + param_name, + compress_statistics=compress_statistics, + quant_type=quant_type, + ) + count += 1 + + LOG.info("Quantized %d MoE expert parameters to 4-bit", count) From d8c059253b4145e11ac29bf6bd4f0a5891f8f23a Mon Sep 17 00:00:00 2001 From: ved1beta Date: Tue, 10 Feb 2026 14:37:50 +0530 Subject: [PATCH 02/19] exclude moe_params , + reviews --- src/axolotl/common/architectures.py | 43 ------------------------- src/axolotl/loaders/adapter.py | 16 +++++----- src/axolotl/loaders/model.py | 10 +++--- src/axolotl/monkeypatch/moe_quant.py | 47 +++++++++++++++++----------- 4 files changed, 41 insertions(+), 75 deletions(-) diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py index 664351b938..f4d6ca9287 100644 --- a/src/axolotl/common/architectures.py +++ b/src/axolotl/common/architectures.py @@ -19,46 +19,3 @@ "lfm2_moe": "Lfm2MoeSparseMoeBlock", "afmoe": "AfmoeMoE", } - -# MoE architectures whose expert weights are 3D nn.Parameter tensors (not nn.Linear). -# BnB 4-bit quantization skips these by default, causing OOM. This mapping provides -# the parameter names needed for `target_parameters` in BitsAndBytesConfig or for -# post-load quantization via bitsandbytes.nn.parametrize. -# Verified against transformers 5.0.0 source. -MOE_EXPERT_PARAMS = { - # gate_up_proj/down_proj pattern: (num_experts, 2*intermediate, hidden) / (num_experts, hidden, intermediate) - "deepseek_v2": ["gate_up_proj", "down_proj"], - "deepseek_v3": ["gate_up_proj", "down_proj"], - "dots1": ["gate_up_proj", "down_proj"], - "ernie4_5_moe": ["gate_up_proj", "down_proj"], - "ernie4_5_vl_moe": ["gate_up_proj", "down_proj"], - "flex_olmo": ["gate_up_proj", "down_proj"], - "glm4_moe": ["gate_up_proj", "down_proj"], - "glm4_moe_lite": ["gate_up_proj", "down_proj"], - "glm4v_moe": ["gate_up_proj", "down_proj"], - "hunyuan_v1_moe": ["gate_up_proj", "down_proj"], - "jamba": ["gate_up_proj", "down_proj"], - "lfm2_moe": ["gate_up_proj", "down_proj"], - "llama4": ["gate_up_proj", "down_proj"], - "longcat_flash": ["gate_up_proj", "down_proj"], - "minimax": ["gate_up_proj", "down_proj"], - "minimax_m2": ["gate_up_proj", "down_proj"], - "mixtral": ["gate_up_proj", "down_proj"], - "olmoe": ["gate_up_proj", "down_proj"], - "phimoe": ["gate_up_proj", "down_proj"], - "qwen2_moe": ["gate_up_proj", "down_proj"], - "qwen3_moe": ["gate_up_proj", "down_proj"], - "qwen3_next": ["gate_up_proj", "down_proj"], - "qwen3_omni_moe": ["gate_up_proj", "down_proj"], - "qwen3_vl_moe": ["gate_up_proj", "down_proj"], - "solar_open": ["gate_up_proj", "down_proj"], - # gate_up_proj/down_proj + bias params - "gpt_oss": ["gate_up_proj", "down_proj"], - # weight-only pattern: (num_experts, output_size, input_size) - "jetmoe": ["weight"], - "granitemoe": ["weight"], - "granitemoehybrid": ["weight"], - "granitemoeshared": ["weight"], - # dbrx uses different param names: w1, v1, w2 (2D packed) - "dbrx": ["w1", "v1", "w2"], -} diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index 7273e17633..8d8d984758 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -19,7 +19,7 @@ ) from transformers import PreTrainedModel -from axolotl.common.architectures import MOE_EXPERT_PARAMS +from axolotl.common.architectures import MOE_ARCH_BLOCK from axolotl.loaders.utils import get_linear_embedding_layers from axolotl.telemetry.errors import send_errors from axolotl.utils.dict import DictDefault @@ -118,19 +118,19 @@ def load_lora( # Exclude ParametrizationList modules created by MoE expert quantization. # replace_parameter_4bit wraps quantized params in ParametrizationList child # modules that PEFT doesn't support as LoRA targets. - exclude_modules = cfg.lora_exclude_modules or [] - if cfg.model_config_type in MOE_EXPERT_PARAMS: - expert_param_names = MOE_EXPERT_PARAMS[cfg.model_config_type] - for name in expert_param_names: - if name not in exclude_modules: - exclude_modules.append(name) + # exclude "parametrizations" to skip all such wrapper modules. + exclude_modules = None + if cfg.model_config_type in MOE_ARCH_BLOCK and ( + cfg.load_in_4bit or cfg.load_in_8bit + ): + exclude_modules = ["parametrizations"] lora_config = LoraConfig( r=cfg.lora_r, lora_alpha=cfg.lora_alpha, target_modules=lora_target_modules, target_parameters=lora_target_parameters, - exclude_modules=exclude_modules if exclude_modules else None, + exclude_modules=exclude_modules, layers_to_transform=cfg.peft_layers_to_transform, layers_pattern=cfg.peft_layers_pattern, lora_dropout=cfg.lora_dropout, diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index f78d6b2ea6..3f496c0fa4 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -37,7 +37,7 @@ is_deepspeed_zero3_enabled, ) -from axolotl.common.architectures import MOE_ARCH_BLOCK, MOE_EXPERT_PARAMS +from axolotl.common.architectures import MOE_ARCH_BLOCK from axolotl.integrations.base import PluginManager from axolotl.loaders.adapter import load_adapter, load_lora from axolotl.loaders.constants import MULTIMODAL_AUTO_MODEL_MAPPING @@ -180,9 +180,9 @@ def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | Non # expert weights stay in full precision. We quantize them here before any other # operations that need GPU memory (like prepare_model_for_kbit_training). if ( - self.cfg.adapter == "qlora" - and self.cfg.load_in_4bit - and self.cfg.model_config_type in MOE_EXPERT_PARAMS + self.cfg.adapter in ("qlora", "lora") + and (self.cfg.load_in_4bit or self.cfg.load_in_8bit) + and self.cfg.model_config_type in MOE_ARCH_BLOCK ): import inspect @@ -194,7 +194,7 @@ def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | Non quantize_moe_expert_params, ) - quantize_moe_expert_params(self.model, self.cfg.model_config_type) + quantize_moe_expert_params(self.model) PLUGIN_MANAGER.post_model_build(self.cfg, self.model) diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py index df99de5cb1..b4f6d50c12 100644 --- a/src/axolotl/monkeypatch/moe_quant.py +++ b/src/axolotl/monkeypatch/moe_quant.py @@ -9,36 +9,45 @@ bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0). """ -import torch +import bitsandbytes as bnb -from axolotl.common.architectures import MOE_EXPERT_PARAMS from axolotl.utils.logging import get_logger LOG = get_logger(__name__) -def quantize_moe_expert_params( - model, model_config_type, quant_type="nf4", compress_statistics=True -): +def find_unquantized_expert_params(model): + """Find 3D+ nn.Parameter tensors that BnB quantization skipped. + + Returns: + List of (module, param_name) tuples to quantize. + """ + params_to_quantize = [] + for _, module in model.named_modules(): + if isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)): + continue + for param_name, param in module.named_parameters(recurse=False): + if param.ndim >= 3: + params_to_quantize.append((module, param_name)) + return params_to_quantize + + +def quantize_moe_expert_params(model, quant_type="nf4", compress_statistics=True): """Quantize 3D nn.Parameter expert weights that BnB skips during model loading.""" from bitsandbytes.nn.parametrize import replace_parameter_4bit - target_params = MOE_EXPERT_PARAMS.get(model_config_type) - if not target_params: + params_to_quantize = find_unquantized_expert_params(model) + if not params_to_quantize: return count = 0 - for module_name, module in model.named_modules(): - for param_name in target_params: - if hasattr(module, param_name): - param = getattr(module, param_name) - if isinstance(param, torch.nn.Parameter) and param.ndim >= 2: - replace_parameter_4bit( - module, - param_name, - compress_statistics=compress_statistics, - quant_type=quant_type, - ) - count += 1 + for module, param_name in params_to_quantize: + replace_parameter_4bit( + module, + param_name, + compress_statistics=compress_statistics, + quant_type=quant_type, + ) + count += 1 LOG.info("Quantized %d MoE expert parameters to 4-bit", count) From c69201efadc9fffacf31863ee6554dafe4cf5b25 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Wed, 11 Feb 2026 19:34:53 +0530 Subject: [PATCH 03/19] use targate parameters for moe --- src/axolotl/loaders/adapter.py | 57 +++++++++++++++++++++------- src/axolotl/loaders/model.py | 23 ----------- src/axolotl/monkeypatch/moe_quant.py | 53 -------------------------- 3 files changed, 44 insertions(+), 89 deletions(-) delete mode 100644 src/axolotl/monkeypatch/moe_quant.py diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index 8d8d984758..18dff05303 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -19,7 +19,6 @@ ) from transformers import PreTrainedModel -from axolotl.common.architectures import MOE_ARCH_BLOCK from axolotl.loaders.utils import get_linear_embedding_layers from axolotl.telemetry.errors import send_errors from axolotl.utils.dict import DictDefault @@ -68,6 +67,33 @@ def find_all_linear_names(model): return list(lora_module_names) +def find_moe_expert_param_names(model): + """Detect 3D+ nn.Parameter tensors for PEFT target_parameters. + + In transformers v5, MoE models store expert weights as fused 3D nn.Parameter + tensors (num_experts, dim1, dim2) instead of individual nn.Linear modules. + PEFT's target_modules can't target these, but target_parameters can via the + ParamWrapper class which applies LoRA directly to nn.Parameter tensors. + + Returns a deduplicated list of parameter path suffixes (e.g., + ["mlp.experts.gate_up_proj", "mlp.experts.down_proj"]) suitable for + PEFT's LoraConfig target_parameters. + """ + seen_suffixes = set() + for name, param in model.named_parameters(): + if param.ndim >= 3: + parts = name.split(".") + # Find the layer index (first numeric segment) and extract the + # repeating suffix after it. + # e.g. "model.layers.0.mlp.experts.gate_up_proj" -> "mlp.experts.gate_up_proj" + for i, part in enumerate(parts): + if part.isdigit(): + suffix = ".".join(parts[i + 1 :]) + seen_suffixes.add(suffix) + break + return sorted(seen_suffixes) + + def load_lora( model: PreTrainedModel, cfg: DictDefault, @@ -75,7 +101,23 @@ def load_lora( config_only: bool = False, ) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]: lora_target_modules = cfg.lora_target_modules or [] - lora_target_parameters = cfg.lora_target_parameters or [] + + # In transformers v5, MoE expert weights are 3D + # nn.Parameter tensors that target_modules can't match. PEFT's + # target_parameters (via ParamWrapper) applies LoRA to these directly. + lora_target_parameters = cfg.lora_target_parameters + if lora_target_parameters is None: + detected_params = find_moe_expert_param_names(model) + if detected_params: + LOG.info( + "Auto-detected MoE expert parameters for LoRA target_parameters: %s", + detected_params, + ) + lora_target_parameters = detected_params + else: + lora_target_parameters = [] + elif isinstance(lora_target_parameters, str): + lora_target_parameters = [lora_target_parameters] if cfg.lora_target_linear: linear_names = find_all_linear_names(model) @@ -115,22 +157,11 @@ def load_lora( else: task_type = TaskType.CAUSAL_LM - # Exclude ParametrizationList modules created by MoE expert quantization. - # replace_parameter_4bit wraps quantized params in ParametrizationList child - # modules that PEFT doesn't support as LoRA targets. - # exclude "parametrizations" to skip all such wrapper modules. - exclude_modules = None - if cfg.model_config_type in MOE_ARCH_BLOCK and ( - cfg.load_in_4bit or cfg.load_in_8bit - ): - exclude_modules = ["parametrizations"] - lora_config = LoraConfig( r=cfg.lora_r, lora_alpha=cfg.lora_alpha, target_modules=lora_target_modules, target_parameters=lora_target_parameters, - exclude_modules=exclude_modules, layers_to_transform=cfg.peft_layers_to_transform, layers_pattern=cfg.peft_layers_pattern, lora_dropout=cfg.lora_dropout, diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 73c5f14788..75684c1ae1 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -173,29 +173,6 @@ def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | Non PLUGIN_MANAGER.pre_model_load(self.cfg) self.patch_manager.apply_post_plugin_pre_model_load_patches() skip_move_to_device = self._build_model() - - # Quantize MoE expert weights immediately after model build. - # In transformers v5, MoE expert weights are 3D nn.Parameter tensors that - # BnB quantization skips (it only handles nn.Linear). This causes OOM because - # expert weights stay in full precision. We quantize them here before any other - # operations that need GPU memory (like prepare_model_for_kbit_training). - if ( - self.cfg.adapter in ("qlora", "lora") - and (self.cfg.load_in_4bit or self.cfg.load_in_8bit) - and self.cfg.model_config_type in MOE_ARCH_BLOCK - ): - import inspect - - bnb_config_params = inspect.signature( - BitsAndBytesConfig.__init__ - ).parameters - if "target_parameters" not in bnb_config_params: - from axolotl.monkeypatch.moe_quant import ( - quantize_moe_expert_params, - ) - - quantize_moe_expert_params(self.model) - PLUGIN_MANAGER.post_model_build(self.cfg, self.model) # Post-build model configuration diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py deleted file mode 100644 index b4f6d50c12..0000000000 --- a/src/axolotl/monkeypatch/moe_quant.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -Post-load quantization for MoE expert weights stored as 3D nn.Parameter tensors. - -In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter -tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets -nn.Linear, so these expert weights are skipped during model loading, causing OOM. - -This module provides a post-load fixup that quantizes those skipped parameters using -bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0). -""" - -import bitsandbytes as bnb - -from axolotl.utils.logging import get_logger - -LOG = get_logger(__name__) - - -def find_unquantized_expert_params(model): - """Find 3D+ nn.Parameter tensors that BnB quantization skipped. - - Returns: - List of (module, param_name) tuples to quantize. - """ - params_to_quantize = [] - for _, module in model.named_modules(): - if isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)): - continue - for param_name, param in module.named_parameters(recurse=False): - if param.ndim >= 3: - params_to_quantize.append((module, param_name)) - return params_to_quantize - - -def quantize_moe_expert_params(model, quant_type="nf4", compress_statistics=True): - """Quantize 3D nn.Parameter expert weights that BnB skips during model loading.""" - from bitsandbytes.nn.parametrize import replace_parameter_4bit - - params_to_quantize = find_unquantized_expert_params(model) - if not params_to_quantize: - return - - count = 0 - for module, param_name in params_to_quantize: - replace_parameter_4bit( - module, - param_name, - compress_statistics=compress_statistics, - quant_type=quant_type, - ) - count += 1 - - LOG.info("Quantized %d MoE expert parameters to 4-bit", count) From 38f79875a1b3028a80788899d6eccbc9c851c889 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Wed, 11 Feb 2026 23:26:18 +0530 Subject: [PATCH 04/19] patch with moe_quant revert --- src/axolotl/loaders/adapter.py | 17 +++--- src/axolotl/loaders/model.py | 21 +++++++ src/axolotl/monkeypatch/moe_quant.py | 84 ++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 7 deletions(-) create mode 100644 src/axolotl/monkeypatch/moe_quant.py diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index 18dff05303..e130b330e4 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -102,18 +102,21 @@ def load_lora( ) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]: lora_target_modules = cfg.lora_target_modules or [] - # In transformers v5, MoE expert weights are 3D - # nn.Parameter tensors that target_modules can't match. PEFT's - # target_parameters (via ParamWrapper) applies LoRA to these directly. + # Auto-detect MoE expert parameters for PEFT target_parameters. + # In transformers v5, MoE expert weights are stored as fused 3D + # nn.Parameter tensors instead of nn.Linear modules. PEFT's + # target_modules can't match these, but target_parameters + ParamWrapper + # can apply LoRA directly -- including when the params have been + # quantized via replace_parameter_4bit (stacked parametrizations). lora_target_parameters = cfg.lora_target_parameters if lora_target_parameters is None: - detected_params = find_moe_expert_param_names(model) - if detected_params: + detected_expert_params = find_moe_expert_param_names(model) + if detected_expert_params: LOG.info( "Auto-detected MoE expert parameters for LoRA target_parameters: %s", - detected_params, + detected_expert_params, ) - lora_target_parameters = detected_params + lora_target_parameters = detected_expert_params else: lora_target_parameters = [] elif isinstance(lora_target_parameters, str): diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 75684c1ae1..8fddc39279 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -173,6 +173,20 @@ def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | Non PLUGIN_MANAGER.pre_model_load(self.cfg) self.patch_manager.apply_post_plugin_pre_model_load_patches() skip_move_to_device = self._build_model() + + # Quantize MoE expert weights that BnB skipped during model loading. + # In transformers v5, MoE expert weights are 3D nn.Parameter tensors + # that BnB quantization skips (it only handles nn.Linear). + # After quantization, PEFT target_parameters applies LoRA on top via + # stacked parametrizations (ParamWrapper). + self.model._moe_experts_quantized = False + if self.cfg.adapter in ("qlora", "lora") and ( + self.cfg.load_in_4bit or self.cfg.load_in_8bit + ): + from axolotl.monkeypatch.moe_quant import quantize_moe_expert_params + + self.model._moe_experts_quantized = quantize_moe_expert_params(self.model) + PLUGIN_MANAGER.post_model_build(self.cfg, self.model) # Post-build model configuration @@ -851,6 +865,13 @@ def _prepare_model_for_quantization(self): # Make sure everything is in the same dtype skip_prepare_model_for_kbit_training = True + if getattr(self.model, "_moe_experts_quantized", False): + # MoE expert weights quantized via replace_parameter_4bit use PyTorch + # parametrize, which causes model.parameters() to return dequantized + # (full-size) tensors. prepare_model_for_kbit_training would OOM trying + # to upcast these to float32. + skip_prepare_model_for_kbit_training = True + if ( not skip_prepare_model_for_kbit_training and self.cfg.adapter in ["lora", "qlora"] diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py new file mode 100644 index 0000000000..a957551212 --- /dev/null +++ b/src/axolotl/monkeypatch/moe_quant.py @@ -0,0 +1,84 @@ +""" +Post-load quantization for MoE expert weights stored as 3D nn.Parameter tensors. + +In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter +tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets +nn.Linear, so these expert weights are skipped during model loading, causing OOM. + +This module provides a post-load fixup that quantizes those skipped parameters using +bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0). +PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized +params via stacked parametrizations. +""" + +import bitsandbytes as bnb +import torch + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def find_unquantized_expert_params(model): + """Find 3D+ nn.Parameter tensors that BnB quantization skipped. + + Returns: + List of (module, param_name) tuples to quantize. + """ + params_to_quantize = [] + for _, module in model.named_modules(): + if isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)): + continue + for param_name, param in module.named_parameters(recurse=False): + if param.ndim >= 3: + params_to_quantize.append((module, param_name)) + return params_to_quantize + + +def quantize_moe_expert_params(model, quant_type=None, compress_statistics=None): + """Quantize 3D nn.Parameter expert weights that BnB skips during model loading. + + Reads quant_type and compress_statistics from the model's quantization_config + when not explicitly provided, so that the same settings used for nn.Linear + quantization are applied to the MoE expert parameters. + """ + from bitsandbytes.nn.parametrize import replace_parameter_4bit + + params_to_quantize = find_unquantized_expert_params(model) + if not params_to_quantize: + return False + + # Derive settings from model's BnB config if not explicitly provided + if quant_type is None or compress_statistics is None: + bnb_config = getattr(model.config, "quantization_config", None) + if bnb_config is not None: + if quant_type is None: + quant_type = getattr(bnb_config, "bnb_4bit_quant_type", "nf4") + if compress_statistics is None: + compress_statistics = getattr( + bnb_config, "bnb_4bit_use_double_quant", True + ) + # Final defaults + if quant_type is None: + quant_type = "nf4" + if compress_statistics is None: + compress_statistics = True + + count = 0 + for module, param_name in params_to_quantize: + replace_parameter_4bit( + module, + param_name, + compress_statistics=compress_statistics, + quant_type=quant_type, + ) + count += 1 + + torch.cuda.empty_cache() + LOG.info( + "Quantized %d MoE expert parameters to 4-bit (quant_type=%s, compress_statistics=%s)", + count, + quant_type, + compress_statistics, + ) + return True From 5a81c15db82b50ca37385d81284157e3ce04753a Mon Sep 17 00:00:00 2001 From: ved1beta Date: Wed, 11 Feb 2026 23:42:19 +0530 Subject: [PATCH 05/19] adpter exclude modules --- src/axolotl/loaders/adapter.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index e130b330e4..ecd991c2af 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -108,6 +108,10 @@ def load_lora( # target_modules can't match these, but target_parameters + ParamWrapper # can apply LoRA directly -- including when the params have been # quantized via replace_parameter_4bit (stacked parametrizations). + # + # When experts are quantized, replace_parameter_4bit creates + # ParametrizationList submodules that target_modules would incorrectly + # match. Exclude them so only target_parameters handles expert params. lora_target_parameters = cfg.lora_target_parameters if lora_target_parameters is None: detected_expert_params = find_moe_expert_param_names(model) @@ -122,6 +126,10 @@ def load_lora( elif isinstance(lora_target_parameters, str): lora_target_parameters = [lora_target_parameters] + exclude_modules = None + if getattr(model, "_moe_experts_quantized", False) and lora_target_parameters: + exclude_modules = ["parametrizations"] + if cfg.lora_target_linear: linear_names = find_all_linear_names(model) LOG.info(f"found linear modules: {repr(sorted(linear_names))}") @@ -165,6 +173,7 @@ def load_lora( lora_alpha=cfg.lora_alpha, target_modules=lora_target_modules, target_parameters=lora_target_parameters, + exclude_modules=exclude_modules, layers_to_transform=cfg.peft_layers_to_transform, layers_pattern=cfg.peft_layers_pattern, lora_dropout=cfg.lora_dropout, From 44eaef20fae8a3ed49764dac29268dc0535d7dd3 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Thu, 12 Feb 2026 00:00:50 +0530 Subject: [PATCH 06/19] detected_expert_params --- src/axolotl/loaders/adapter.py | 9 +++++++-- src/axolotl/loaders/model.py | 3 +++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index ecd991c2af..0bee3a8198 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -114,7 +114,12 @@ def load_lora( # match. Exclude them so only target_parameters handles expert params. lora_target_parameters = cfg.lora_target_parameters if lora_target_parameters is None: - detected_expert_params = find_moe_expert_param_names(model) + # Use pre-quantization names stored by model loader if available, + # since after replace_parameter_4bit the 3D params no longer appear + # as ndim>=3 in named_parameters(). + detected_expert_params = getattr( + model, "_moe_expert_param_names", None + ) or find_moe_expert_param_names(model) if detected_expert_params: LOG.info( "Auto-detected MoE expert parameters for LoRA target_parameters: %s", @@ -127,7 +132,7 @@ def load_lora( lora_target_parameters = [lora_target_parameters] exclude_modules = None - if getattr(model, "_moe_experts_quantized", False) and lora_target_parameters: + if getattr(model, "_moe_experts_quantized", False): exclude_modules = ["parametrizations"] if cfg.lora_target_linear: diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 8fddc39279..69c9493a93 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -180,11 +180,14 @@ def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | Non # After quantization, PEFT target_parameters applies LoRA on top via # stacked parametrizations (ParamWrapper). self.model._moe_experts_quantized = False + self.model._moe_expert_param_names = [] if self.cfg.adapter in ("qlora", "lora") and ( self.cfg.load_in_4bit or self.cfg.load_in_8bit ): + from axolotl.loaders.adapter import find_moe_expert_param_names from axolotl.monkeypatch.moe_quant import quantize_moe_expert_params + self.model._moe_expert_param_names = find_moe_expert_param_names(self.model) self.model._moe_experts_quantized = quantize_moe_expert_params(self.model) PLUGIN_MANAGER.post_model_build(self.cfg, self.model) From 4d46469b4b13131b2ab5dbf27b4b5d232e47fb08 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Thu, 12 Feb 2026 00:08:12 +0530 Subject: [PATCH 07/19] r".*\.parametrizations\..*" --- src/axolotl/loaders/adapter.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index 0bee3a8198..fe2759e86f 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -131,9 +131,14 @@ def load_lora( elif isinstance(lora_target_parameters, str): lora_target_parameters = [lora_target_parameters] + # When experts are quantized via replace_parameter_4bit, it creates + # ParametrizationList submodules that target_modules would incorrectly + # match. Exclude them so only target_parameters handles expert params. + # Uses regex (string, not list) because "parametrizations" appears in the + # middle of the module path, not as a suffix. exclude_modules = None if getattr(model, "_moe_experts_quantized", False): - exclude_modules = ["parametrizations"] + exclude_modules = r".*\.parametrizations\..*" if cfg.lora_target_linear: linear_names = find_all_linear_names(model) From 1bab4c1683e637359886f1a0fc280a9a82914f2b Mon Sep 17 00:00:00 2001 From: ved1beta Date: Thu, 12 Feb 2026 14:32:42 +0530 Subject: [PATCH 08/19] comment --- src/axolotl/loaders/adapter.py | 21 +++------------------ src/axolotl/loaders/model.py | 12 +++--------- 2 files changed, 6 insertions(+), 27 deletions(-) diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index fe2759e86f..f5fa881d02 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -102,21 +102,9 @@ def load_lora( ) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]: lora_target_modules = cfg.lora_target_modules or [] - # Auto-detect MoE expert parameters for PEFT target_parameters. - # In transformers v5, MoE expert weights are stored as fused 3D - # nn.Parameter tensors instead of nn.Linear modules. PEFT's - # target_modules can't match these, but target_parameters + ParamWrapper - # can apply LoRA directly -- including when the params have been - # quantized via replace_parameter_4bit (stacked parametrizations). - # - # When experts are quantized, replace_parameter_4bit creates - # ParametrizationList submodules that target_modules would incorrectly - # match. Exclude them so only target_parameters handles expert params. + # Auto-detect MoE expert params for PEFT target_parameters (v5 3D nn.Parameter). lora_target_parameters = cfg.lora_target_parameters if lora_target_parameters is None: - # Use pre-quantization names stored by model loader if available, - # since after replace_parameter_4bit the 3D params no longer appear - # as ndim>=3 in named_parameters(). detected_expert_params = getattr( model, "_moe_expert_param_names", None ) or find_moe_expert_param_names(model) @@ -131,11 +119,8 @@ def load_lora( elif isinstance(lora_target_parameters, str): lora_target_parameters = [lora_target_parameters] - # When experts are quantized via replace_parameter_4bit, it creates - # ParametrizationList submodules that target_modules would incorrectly - # match. Exclude them so only target_parameters handles expert params. - # Uses regex (string, not list) because "parametrizations" appears in the - # middle of the module path, not as a suffix. + # Exclude ParametrizationList submodules created by replace_parameter_4bit + # from target_modules matching (regex needed — "parametrizations" is mid-path). exclude_modules = None if getattr(model, "_moe_experts_quantized", False): exclude_modules = r".*\.parametrizations\..*" diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 69c9493a93..7a90cd7010 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -174,11 +174,8 @@ def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | Non self.patch_manager.apply_post_plugin_pre_model_load_patches() skip_move_to_device = self._build_model() - # Quantize MoE expert weights that BnB skipped during model loading. - # In transformers v5, MoE expert weights are 3D nn.Parameter tensors - # that BnB quantization skips (it only handles nn.Linear). - # After quantization, PEFT target_parameters applies LoRA on top via - # stacked parametrizations (ParamWrapper). + # Quantize 3D MoE expert nn.Parameter tensors that BnB skips. + # Detect names before quantization (replace_parameter_4bit changes them). self.model._moe_experts_quantized = False self.model._moe_expert_param_names = [] if self.cfg.adapter in ("qlora", "lora") and ( @@ -869,10 +866,7 @@ def _prepare_model_for_quantization(self): skip_prepare_model_for_kbit_training = True if getattr(self.model, "_moe_experts_quantized", False): - # MoE expert weights quantized via replace_parameter_4bit use PyTorch - # parametrize, which causes model.parameters() to return dequantized - # (full-size) tensors. prepare_model_for_kbit_training would OOM trying - # to upcast these to float32. + # Parametrized expert tensors dequantize on access — would OOM. skip_prepare_model_for_kbit_training = True if ( From e97b14eb4565abc16b42aa607ec97ce598d89069 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Thu, 12 Feb 2026 19:33:29 +0530 Subject: [PATCH 09/19] config --- examples/glm4.7/glm4.7-qlora.yaml | 72 +++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 examples/glm4.7/glm4.7-qlora.yaml diff --git a/examples/glm4.7/glm4.7-qlora.yaml b/examples/glm4.7/glm4.7-qlora.yaml new file mode 100644 index 0000000000..feffbb8534 --- /dev/null +++ b/examples/glm4.7/glm4.7-qlora.yaml @@ -0,0 +1,72 @@ +base_model: zai-org/GLM-4.7-Flash + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin +cut_cross_entropy: true + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + + + +adapter: qlora +save_safetensors: true + +sequence_len: 32768 +sample_packing: true +pad_to_sequence_len: true + +lora_r: 32 +lora_alpha: 32 +lora_dropout: 0.0 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +lora_mlp_kernel: true +lora_qkv_kernel: false +lora_o_kernel: false + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: rex +learning_rate: 0.00001 +max_grad_norm: 1.0 + +bf16: auto + +wandb_project: glm-test + +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 0 +saves_per_epoch: 4 +save_total_limit: 4 + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false + +deepspeed: ./deepspeed_configs/zero2.json \ No newline at end of file From 13d85b648f7e35972b50d927e0b22cde4c0192f5 Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Fri, 13 Feb 2026 07:54:35 +0530 Subject: [PATCH 10/19] Update src/axolotl/loaders/adapter.py Co-authored-by: Wing Lian --- src/axolotl/loaders/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index f5fa881d02..e63ab407fe 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -67,7 +67,7 @@ def find_all_linear_names(model): return list(lora_module_names) -def find_moe_expert_param_names(model): +def find_moe_expert_param_names(model: PreTrainedModel) -> list[str]: """Detect 3D+ nn.Parameter tensors for PEFT target_parameters. In transformers v5, MoE models store expert weights as fused 3D nn.Parameter From c3c6893083b22ddd221df2c2afccbc6ed8ecc264 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Fri, 13 Feb 2026 09:28:37 +0530 Subject: [PATCH 11/19] lint --- examples/glm4.7/glm4.7-qlora.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/glm4.7/glm4.7-qlora.yaml b/examples/glm4.7/glm4.7-qlora.yaml index feffbb8534..86e666cc96 100644 --- a/examples/glm4.7/glm4.7-qlora.yaml +++ b/examples/glm4.7/glm4.7-qlora.yaml @@ -69,4 +69,4 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false -deepspeed: ./deepspeed_configs/zero2.json \ No newline at end of file +deepspeed: ./deepspeed_configs/zero2.json From d16a853589a515f370140d76beee1616d3f67206 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 13 Feb 2026 14:57:44 +0700 Subject: [PATCH 12/19] fix: simplify defaults --- .../{glm4.7-qlora.yaml => glm4.7-flash-qlora.yaml} | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) rename examples/glm4.7/{glm4.7-qlora.yaml => glm4.7-flash-qlora.yaml} (85%) diff --git a/examples/glm4.7/glm4.7-qlora.yaml b/examples/glm4.7/glm4.7-flash-qlora.yaml similarity index 85% rename from examples/glm4.7/glm4.7-qlora.yaml rename to examples/glm4.7/glm4.7-flash-qlora.yaml index 86e666cc96..bb969399ea 100644 --- a/examples/glm4.7/glm4.7-qlora.yaml +++ b/examples/glm4.7/glm4.7-flash-qlora.yaml @@ -2,9 +2,7 @@ base_model: zai-org/GLM-4.7-Flash plugins: - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin -cut_cross_entropy: true -load_in_8bit: false load_in_4bit: true datasets: @@ -18,10 +16,7 @@ datasets: val_set_size: 0.0 output_dir: ./outputs/out - - adapter: qlora -save_safetensors: true sequence_len: 32768 sample_packing: true @@ -48,14 +43,12 @@ gradient_accumulation_steps: 4 micro_batch_size: 1 num_epochs: 1 optimizer: adamw_torch_fused -lr_scheduler: rex -learning_rate: 0.00001 +lr_scheduler: cosine +learning_rate: 0.0001 max_grad_norm: 1.0 bf16: auto -wandb_project: glm-test - resume_from_checkpoint: logging_steps: 1 flash_attention: true @@ -69,4 +62,4 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false -deepspeed: ./deepspeed_configs/zero2.json +deepspeed: From 82dad00e4fa884ce25d214d04e607e7e7ba2bd72 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Fri, 13 Feb 2026 15:25:34 +0530 Subject: [PATCH 13/19] true --- examples/glm4.7/glm4.7-qlora.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/glm4.7/glm4.7-qlora.yaml b/examples/glm4.7/glm4.7-qlora.yaml index 86e666cc96..7e6db1a1e8 100644 --- a/examples/glm4.7/glm4.7-qlora.yaml +++ b/examples/glm4.7/glm4.7-qlora.yaml @@ -41,8 +41,8 @@ lora_target_modules: - o_proj lora_mlp_kernel: true -lora_qkv_kernel: false -lora_o_kernel: false +lora_qkv_kernel: true +lora_o_kernel: true gradient_accumulation_steps: 4 micro_batch_size: 1 From 3da012f829589d8c7e92bf6c0e3e6cb73eada3f6 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Mon, 16 Feb 2026 14:31:24 +0530 Subject: [PATCH 14/19] used keywords exp_proj, down_proj, gate_proj --- src/axolotl/loaders/adapter.py | 4 +++- src/axolotl/monkeypatch/moe_quant.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index e63ab407fe..caad2e55e1 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -81,7 +81,9 @@ def find_moe_expert_param_names(model: PreTrainedModel) -> list[str]: """ seen_suffixes = set() for name, param in model.named_parameters(): - if param.ndim >= 3: + if param.ndim >= 3 and any( + kw in name for kw in ("experts", "gate_up_proj", "down_proj") + ): parts = name.split(".") # Find the layer index (first numeric segment) and extract the # repeating suffix after it. diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py index a957551212..c9a03e221e 100644 --- a/src/axolotl/monkeypatch/moe_quant.py +++ b/src/axolotl/monkeypatch/moe_quant.py @@ -30,7 +30,9 @@ def find_unquantized_expert_params(model): if isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)): continue for param_name, param in module.named_parameters(recurse=False): - if param.ndim >= 3: + if param.ndim >= 3 and any( + kw in param_name for kw in ("experts", "gate_up_proj", "down_proj") + ): params_to_quantize.append((module, param_name)) return params_to_quantize From 41d30ffd8e1cb34809bdef55ebc302feb704258c Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Mon, 16 Feb 2026 16:47:19 +0530 Subject: [PATCH 15/19] Update examples/glm4.7/glm4.7-flash-qlora.yaml Co-authored-by: NanoCode012 --- examples/glm4.7/glm4.7-flash-qlora.yaml | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/glm4.7/glm4.7-flash-qlora.yaml b/examples/glm4.7/glm4.7-flash-qlora.yaml index f45f40fd6c..3127681ba8 100644 --- a/examples/glm4.7/glm4.7-flash-qlora.yaml +++ b/examples/glm4.7/glm4.7-flash-qlora.yaml @@ -27,9 +27,6 @@ lora_alpha: 32 lora_dropout: 0.0 lora_target_linear: true lora_target_modules: - - gate_proj - - down_proj - - up_proj - q_proj - v_proj - k_proj From 6c734e95c235c81bd99ebbe823fc51d03d071c56 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Mon, 16 Feb 2026 21:04:23 +0530 Subject: [PATCH 16/19] use lora_target_parameters --- examples/glm4.7/glm4.7-flash-qlora.yaml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/glm4.7/glm4.7-flash-qlora.yaml b/examples/glm4.7/glm4.7-flash-qlora.yaml index f45f40fd6c..c8b44c79f7 100644 --- a/examples/glm4.7/glm4.7-flash-qlora.yaml +++ b/examples/glm4.7/glm4.7-flash-qlora.yaml @@ -26,14 +26,16 @@ lora_r: 32 lora_alpha: 32 lora_dropout: 0.0 lora_target_linear: true +# MoE expert weights (gate_up_proj, down_proj) are fused 3D tensors in this +# model and are NOT nn.Linear — target them via lora_target_parameters below. lora_target_modules: - - gate_proj - - down_proj - - up_proj - q_proj - v_proj - k_proj - o_proj +lora_target_parameters: + - mlp.experts.gate_up_proj + - mlp.experts.down_proj lora_mlp_kernel: true lora_qkv_kernel: true From 2fe6f408ae1e94c06f278e21d1e0ac01783f28b6 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Wed, 18 Feb 2026 09:47:50 +0530 Subject: [PATCH 17/19] rmv lora_qkv_kernel: false lora_o_kernel: false --- examples/glm4.7/glm4.7-flash-qlora.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/glm4.7/glm4.7-flash-qlora.yaml b/examples/glm4.7/glm4.7-flash-qlora.yaml index c8b44c79f7..979463b1f8 100644 --- a/examples/glm4.7/glm4.7-flash-qlora.yaml +++ b/examples/glm4.7/glm4.7-flash-qlora.yaml @@ -38,8 +38,8 @@ lora_target_parameters: - mlp.experts.down_proj lora_mlp_kernel: true -lora_qkv_kernel: true -lora_o_kernel: true +lora_qkv_kernel: false +lora_o_kernel: false gradient_accumulation_steps: 4 micro_batch_size: 1 From 91b5aa7c88e20276f40e85f8293432b282f70a11 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Thu, 19 Feb 2026 00:14:30 +0530 Subject: [PATCH 18/19] support lora _o_proj --- examples/glm4.7/glm4.7-flash-qlora.yaml | 2 +- src/axolotl/monkeypatch/lora_kernels.py | 71 ++++++++++++++----------- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/examples/glm4.7/glm4.7-flash-qlora.yaml b/examples/glm4.7/glm4.7-flash-qlora.yaml index 979463b1f8..d3799e0f69 100644 --- a/examples/glm4.7/glm4.7-flash-qlora.yaml +++ b/examples/glm4.7/glm4.7-flash-qlora.yaml @@ -39,7 +39,7 @@ lora_target_parameters: lora_mlp_kernel: true lora_qkv_kernel: false -lora_o_kernel: false +lora_o_kernel: true gradient_accumulation_steps: 4 micro_batch_size: 1 diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 2972c62851..c56a25226e 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -201,19 +201,18 @@ def patch_self_attn_lora(cfg: DictDefault): attention_cls._original_forward = self_attn_forward self_attn_forward, _ = detab_code(self_attn_forward) - assert any(qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES), ( - "Original QKV code not found" - ) - assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found" - - for qkv_orig, qkv_patched in QKV_PATCHES: - if qkv_orig in self_attn_forward: - self_attn_forward = self_attn_forward.replace( - qkv_orig, - qkv_patched, - ) - break - self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE) + if cfg.lora_qkv_kernel: + assert any(qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES), ( + "Original QKV code not found" + ) + for qkv_orig, qkv_patched in QKV_PATCHES: + if qkv_orig in self_attn_forward: + self_attn_forward = self_attn_forward.replace(qkv_orig, qkv_patched) + break + + if cfg.lora_o_kernel: + assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found" + self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE) self_attn_forward = self_attn_forward.replace( "def forward(", "def axolotl_attn_forward(", @@ -249,6 +248,12 @@ def find_self_attn_in_layer( for proj in ["q_proj", "k_proj", "v_proj", "o_proj"] ): yield layer.self_attn + # MLA attention (DeepSeek-V2/V3, GLM-4.7): no q/k/v_proj, but o_proj is standard + elif all( + hasattr(layer.self_attn, proj) + for proj in ["kv_a_proj_with_mqa", "kv_b_proj", "o_proj"] + ): + yield layer.self_attn def find_mlp_in_layer( @@ -388,25 +393,31 @@ def apply_lora_kernel_patches( self_attn.apply_o = types.MethodType(original_apply_o, self_attn) if cfg.lora_qkv_kernel: - # Query, key, value patching - layer_modules = [ - getattr(self_attn, linear_proj) - for linear_proj in ["q_proj", "k_proj", "v_proj"] - ] - can_patch_qkv = all( - hasattr(module, "lora_A") - and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 - for module in layer_modules - ) - - if can_patch_qkv: - # Add optimized implementation - self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn) - else: + # Query, key, value patching — only for standard QKV models, not MLA + if not all(hasattr(self_attn, p) for p in ["q_proj", "k_proj", "v_proj"]): LOG.warning_once( - "Cannot patch some attention QKV projections - requires LoRA " - "adapters and no lora_magnitude_vector (DoRA)" + "Skipping QKV kernel patch — model uses MLA attention " + "(no q_proj/k_proj/v_proj). Disable lora_qkv_kernel to silence." + ) + else: + layer_modules = [ + getattr(self_attn, linear_proj) + for linear_proj in ["q_proj", "k_proj", "v_proj"] + ] + can_patch_qkv = all( + hasattr(module, "lora_A") + and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 + for module in layer_modules ) + + if can_patch_qkv: + # Add optimized implementation + self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn) + else: + LOG.warning_once( + "Cannot patch some attention QKV projections - requires LoRA " + "adapters and no lora_magnitude_vector (DoRA)" + ) if cfg.lora_o_kernel: # Output patching layer_modules = [ From e0b7e93463858d3eaba0d642acd04ebb44980e81 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 20 Feb 2026 14:31:19 -0500 Subject: [PATCH 19/19] chore: lint --- src/axolotl/monkeypatch/lora_kernels.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index c56a25226e..38cf89403b 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -202,9 +202,9 @@ def patch_self_attn_lora(cfg: DictDefault): self_attn_forward, _ = detab_code(self_attn_forward) if cfg.lora_qkv_kernel: - assert any(qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES), ( - "Original QKV code not found" - ) + assert any( + qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES + ), "Original QKV code not found" for qkv_orig, qkv_patched in QKV_PATCHES: if qkv_orig in self_attn_forward: self_attn_forward = self_attn_forward.replace(qkv_orig, qkv_patched) @@ -394,7 +394,9 @@ def apply_lora_kernel_patches( if cfg.lora_qkv_kernel: # Query, key, value patching — only for standard QKV models, not MLA - if not all(hasattr(self_attn, p) for p in ["q_proj", "k_proj", "v_proj"]): + if not all( + hasattr(self_attn, p) for p in ["q_proj", "k_proj", "v_proj"] + ): LOG.warning_once( "Skipping QKV kernel patch — model uses MLA attention " "(no q_proj/k_proj/v_proj). Disable lora_qkv_kernel to silence." @@ -412,7 +414,9 @@ def apply_lora_kernel_patches( if can_patch_qkv: # Add optimized implementation - self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn) + self_attn.apply_qkv = types.MethodType( + apply_lora_qkv, self_attn + ) else: LOG.warning_once( "Cannot patch some attention QKV projections - requires LoRA "