def find_moe_expert_param_names(model: PreTrainedModel) -> list[str]:
    """Detect 3D+ nn.Parameter tensors for PEFT target_parameters.

    In transformers v5, MoE models store expert weights as fused 3D nn.Parameter
    tensors (num_experts, dim1, dim2) instead of individual nn.Linear modules.
    PEFT's target_modules can't target these, but target_parameters can via the
    ParamWrapper class which applies LoRA directly to nn.Parameter tensors.

    A parameter is selected when it has at least three dimensions and its
    dotted name contains one of ``"experts"``, ``"gate_up_proj"`` or
    ``"down_proj"``. The reported suffix is everything after the first
    all-digit segment of the name (taken to be the layer index), e.g.
    ``"model.layers.0.mlp.experts.gate_up_proj"`` -> ``"mlp.experts.gate_up_proj"``.

    Names without any numeric segment are ignored, as are names whose first
    numeric segment is also the last one (the suffix would be empty and an
    empty string is not a usable parameter path).

    Args:
        model: Model whose ``named_parameters()`` are inspected.

    Returns:
        A sorted, deduplicated list of parameter path suffixes (e.g.,
        ``["mlp.experts.down_proj", "mlp.experts.gate_up_proj"]``) suitable
        for PEFT's LoraConfig ``target_parameters``.
    """
    expert_keywords = ("experts", "gate_up_proj", "down_proj")
    suffixes: set[str] = set()

    for name, param in model.named_parameters():
        if param.ndim < 3 or not any(kw in name for kw in expert_keywords):
            continue

        parts = name.split(".")
        # The first numeric segment is treated as the layer index; whatever
        # follows it repeats across layers and is what gets reported.
        layer_pos = next(
            (idx for idx, part in enumerate(parts) if part.isdigit()), None
        )
        if layer_pos is None:
            continue

        suffix = ".".join(parts[layer_pos + 1 :])
        if suffix:  # never report an empty path
            suffixes.add(suffix)

    return sorted(suffixes)
+ lora_target_parameters = cfg.lora_target_parameters + if lora_target_parameters is None: + detected_expert_params = getattr( + model, "_moe_expert_param_names", None + ) or find_moe_expert_param_names(model) + if detected_expert_params: + LOG.info( + "Auto-detected MoE expert parameters for LoRA target_parameters: %s", + detected_expert_params, + ) + lora_target_parameters = detected_expert_params + else: + lora_target_parameters = [] + elif isinstance(lora_target_parameters, str): + lora_target_parameters = [lora_target_parameters] + + # Exclude ParametrizationList submodules created by replace_parameter_4bit + # from target_modules matching (regex needed — "parametrizations" is mid-path). + exclude_modules = None + if getattr(model, "_moe_experts_quantized", False): + exclude_modules = r".*\.parametrizations\..*" if cfg.lora_target_linear: linear_names = find_all_linear_names(model) @@ -119,6 +170,7 @@ def load_lora( lora_alpha=cfg.lora_alpha, target_modules=lora_target_modules, target_parameters=lora_target_parameters, + exclude_modules=exclude_modules, layers_to_transform=cfg.peft_layers_to_transform, layers_pattern=cfg.peft_layers_pattern, lora_dropout=cfg.lora_dropout, diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 6c88855268..b8186fb686 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -173,6 +173,20 @@ def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | Non PLUGIN_MANAGER.pre_model_load(self.cfg) self.patch_manager.apply_post_plugin_pre_model_load_patches() skip_move_to_device = self._build_model() + + # Quantize 3D MoE expert nn.Parameter tensors that BnB skips. + # Detect names before quantization (replace_parameter_4bit changes them). 
+ self.model._moe_experts_quantized = False + self.model._moe_expert_param_names = [] + if self.cfg.adapter in ("qlora", "lora") and ( + self.cfg.load_in_4bit or self.cfg.load_in_8bit + ): + from axolotl.loaders.adapter import find_moe_expert_param_names + from axolotl.monkeypatch.moe_quant import quantize_moe_expert_params + + self.model._moe_expert_param_names = find_moe_expert_param_names(self.model) + self.model._moe_experts_quantized = quantize_moe_expert_params(self.model) + PLUGIN_MANAGER.post_model_build(self.cfg, self.model) # Post-build model configuration @@ -860,6 +874,10 @@ def _prepare_model_for_quantization(self): # Make sure everything is in the same dtype skip_prepare_model_for_kbit_training = True + if getattr(self.model, "_moe_experts_quantized", False): + # Parametrized expert tensors dequantize on access — would OOM. + skip_prepare_model_for_kbit_training = True + if ( not skip_prepare_model_for_kbit_training and self.cfg.adapter in ["lora", "qlora"] diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 2972c62851..38cf89403b 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -201,19 +201,18 @@ def patch_self_attn_lora(cfg: DictDefault): attention_cls._original_forward = self_attn_forward self_attn_forward, _ = detab_code(self_attn_forward) - assert any(qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES), ( - "Original QKV code not found" - ) - assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found" - - for qkv_orig, qkv_patched in QKV_PATCHES: - if qkv_orig in self_attn_forward: - self_attn_forward = self_attn_forward.replace( - qkv_orig, - qkv_patched, - ) - break - self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE) + if cfg.lora_qkv_kernel: + assert any( + qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES + ), "Original QKV code not found" + for qkv_orig, qkv_patched 
in QKV_PATCHES: + if qkv_orig in self_attn_forward: + self_attn_forward = self_attn_forward.replace(qkv_orig, qkv_patched) + break + + if cfg.lora_o_kernel: + assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found" + self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE) self_attn_forward = self_attn_forward.replace( "def forward(", "def axolotl_attn_forward(", @@ -249,6 +248,12 @@ def find_self_attn_in_layer( for proj in ["q_proj", "k_proj", "v_proj", "o_proj"] ): yield layer.self_attn + # MLA attention (DeepSeek-V2/V3, GLM-4.7): no q/k/v_proj, but o_proj is standard + elif all( + hasattr(layer.self_attn, proj) + for proj in ["kv_a_proj_with_mqa", "kv_b_proj", "o_proj"] + ): + yield layer.self_attn def find_mlp_in_layer( @@ -388,25 +393,35 @@ def apply_lora_kernel_patches( self_attn.apply_o = types.MethodType(original_apply_o, self_attn) if cfg.lora_qkv_kernel: - # Query, key, value patching - layer_modules = [ - getattr(self_attn, linear_proj) - for linear_proj in ["q_proj", "k_proj", "v_proj"] - ] - can_patch_qkv = all( - hasattr(module, "lora_A") - and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 - for module in layer_modules - ) - - if can_patch_qkv: - # Add optimized implementation - self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn) - else: + # Query, key, value patching — only for standard QKV models, not MLA + if not all( + hasattr(self_attn, p) for p in ["q_proj", "k_proj", "v_proj"] + ): LOG.warning_once( - "Cannot patch some attention QKV projections - requires LoRA " - "adapters and no lora_magnitude_vector (DoRA)" + "Skipping QKV kernel patch — model uses MLA attention " + "(no q_proj/k_proj/v_proj). Disable lora_qkv_kernel to silence." 
# Substrings that mark a parameter path as belonging to fused MoE expert weights.
_EXPERT_PARAM_KEYWORDS = ("experts", "gate_up_proj", "down_proj")


def _bnb_setting(bnb_config, key, default):
    """Read ``key`` from a quantization config given as an object or a dict.

    Returns ``default`` when the config is missing, lacks ``key``, or stores
    ``None`` for it.
    """
    if bnb_config is None:
        return default
    if isinstance(bnb_config, dict):
        value = bnb_config.get(key)
    else:
        value = getattr(bnb_config, key, None)
    return default if value is None else value


def find_unquantized_expert_params(model):
    """Find 3D+ nn.Parameter tensors that BnB quantization skipped.

    Walks every submodule of ``model`` except bitsandbytes ``Linear4bit`` /
    ``Linear8bitLt`` layers and collects the parameters owned directly by each
    module that have at least three dimensions and whose dotted path (module
    path plus parameter name) contains one of ``"experts"``,
    ``"gate_up_proj"`` or ``"down_proj"``.

    The keyword test uses the full dotted path rather than only the local
    attribute name, so it agrees with ``find_moe_expert_param_names`` in
    ``axolotl.loaders.adapter``, which filters on the full parameter name.

    Args:
        model: Module tree to scan (anything exposing ``named_modules()``).

    Returns:
        List of (module, param_name) tuples to quantize.
    """
    params_to_quantize = []
    for module_name, module in model.named_modules():
        if isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
            continue
        for param_name, param in module.named_parameters(recurse=False):
            if param.ndim < 3:
                continue
            qualified = f"{module_name}.{param_name}" if module_name else param_name
            # Tensors already living under a parametrization container have
            # been handled; do not pick them up again on a repeated call.
            if "parametrizations" in qualified.split("."):
                continue
            if any(kw in qualified for kw in _EXPERT_PARAM_KEYWORDS):
                params_to_quantize.append((module, param_name))
    return params_to_quantize


def quantize_moe_expert_params(model, quant_type=None, compress_statistics=None):
    """Quantize 3D nn.Parameter expert weights that BnB skips during model loading.

    Reads quant_type and compress_statistics from the model's quantization_config
    when not explicitly provided, so that the same settings used for nn.Linear
    quantization are applied to the MoE expert parameters. The config may be an
    object with attributes or a plain dict; missing values fall back to
    ``"nf4"`` and ``True`` respectively.

    NOTE(review): quantization here is always 4-bit, including when the caller
    loaded the rest of the model in 8-bit -- confirm that is intended.

    Args:
        model: Model whose expert parameters should be quantized in place.
        quant_type: 4-bit quantization type passed to
            ``replace_parameter_4bit``; derived from the model config if None.
        compress_statistics: Whether to compress quantization statistics;
            derived from the model config if None.

    Returns:
        True if at least one parameter was quantized, False if none matched.

    Raises:
        ImportError: If parameters need quantizing but
            ``bitsandbytes.nn.parametrize`` is unavailable.
    """
    params_to_quantize = find_unquantized_expert_params(model)
    if not params_to_quantize:
        return False

    # Imported only once there is work to do, so models without fused expert
    # tensors still load on bitsandbytes builds that lack this module.
    try:
        from bitsandbytes.nn.parametrize import replace_parameter_4bit
    except ImportError as err:
        raise ImportError(
            "Quantizing fused MoE expert parameters requires "
            "bitsandbytes >= 0.48.0 (bitsandbytes.nn.parametrize)."
        ) from err

    # Derive settings from model's BnB config if not explicitly provided
    if quant_type is None or compress_statistics is None:
        model_config = getattr(model, "config", None)
        bnb_config = getattr(model_config, "quantization_config", None)
        if quant_type is None:
            quant_type = _bnb_setting(bnb_config, "bnb_4bit_quant_type", "nf4")
        if compress_statistics is None:
            compress_statistics = _bnb_setting(
                bnb_config, "bnb_4bit_use_double_quant", True
            )

    for module, param_name in params_to_quantize:
        replace_parameter_4bit(
            module,
            param_name,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )

    # Hand the memory of the replaced full-precision tensors back to the allocator.
    torch.cuda.empty_cache()
    LOG.info(
        "Quantized %d MoE expert parameters to 4-bit (quant_type=%s, compress_statistics=%s)",
        len(params_to_quantize),
        quant_type,
        compress_statistics,
    )
    return True