-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
FIX: monkey patch bitsandbytes oom on v5 #3395
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
292818e
8f81839
d8c0592
ce1b473
c69201e
38f7987
5a81c15
44eaef2
4d46469
1bab4c1
e97b14e
13d85b6
013d8f0
c3c6893
d16a853
82dad00
4d7da67
3da012f
41d30ff
dc7caa1
6c734e9
004ba8f
2fe6f40
91b5aa7
e0b7e93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| base_model: zai-org/GLM-4.7-Flash | ||
|
|
||
| plugins: | ||
| - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin | ||
|
|
||
| load_in_4bit: true | ||
|
|
||
| datasets: | ||
| - path: cgato/SlimOrcaDedupCleaned | ||
| type: chat_template | ||
| field_messages: conversations | ||
| message_property_mappings: | ||
| role: from | ||
| content: value | ||
|
|
||
| val_set_size: 0.0 | ||
| output_dir: ./outputs/out | ||
|
|
||
| adapter: qlora | ||
|
|
||
| sequence_len: 32768 | ||
| sample_packing: true | ||
| pad_to_sequence_len: true | ||
|
|
||
| lora_r: 32 | ||
| lora_alpha: 32 | ||
| lora_dropout: 0.0 | ||
| lora_target_linear: true | ||
| # MoE expert weights (gate_up_proj, down_proj) are fused 3D tensors in this | ||
| # model and are NOT nn.Linear — target them via lora_target_parameters below. | ||
| lora_target_modules: | ||
| - q_proj | ||
| - v_proj | ||
| - k_proj | ||
| - o_proj | ||
| lora_target_parameters: | ||
| - mlp.experts.gate_up_proj | ||
| - mlp.experts.down_proj | ||
|
|
||
| lora_mlp_kernel: true | ||
| lora_qkv_kernel: false | ||
| lora_o_kernel: true | ||
|
|
||
| gradient_accumulation_steps: 4 | ||
| micro_batch_size: 1 | ||
| num_epochs: 1 | ||
| optimizer: adamw_torch_fused | ||
| lr_scheduler: cosine | ||
| learning_rate: 0.0001 | ||
| max_grad_norm: 1.0 | ||
|
|
||
| bf16: auto | ||
|
|
||
| resume_from_checkpoint: | ||
| logging_steps: 1 | ||
| flash_attention: true | ||
|
|
||
| warmup_ratio: 0.1 | ||
| evals_per_epoch: 0 | ||
| saves_per_epoch: 4 | ||
| save_total_limit: 4 | ||
|
|
||
| gradient_checkpointing: true | ||
| gradient_checkpointing_kwargs: | ||
| use_reentrant: false | ||
|
|
||
| deepspeed: | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,14 +67,65 @@ def find_all_linear_names(model): | |
| return list(lora_module_names) | ||
|
|
||
|
|
||
| def find_moe_expert_param_names(model: PreTrainedModel) -> list[str]: | ||
| """Detect 3D+ nn.Parameter tensors for PEFT target_parameters. | ||
|
|
||
| In transformers v5, MoE models store expert weights as fused 3D nn.Parameter | ||
| tensors (num_experts, dim1, dim2) instead of individual nn.Linear modules. | ||
| PEFT's target_modules can't target these, but target_parameters can via the | ||
| ParamWrapper class which applies LoRA directly to nn.Parameter tensors. | ||
|
|
||
| Returns a deduplicated list of parameter path suffixes (e.g., | ||
| ["mlp.experts.gate_up_proj", "mlp.experts.down_proj"]) suitable for | ||
| PEFT's LoraConfig target_parameters. | ||
| """ | ||
| seen_suffixes = set() | ||
| for name, param in model.named_parameters(): | ||
| if param.ndim >= 3 and any( | ||
| kw in name for kw in ("experts", "gate_up_proj", "down_proj") | ||
| ): | ||
| parts = name.split(".") | ||
| # Find the layer index (first numeric segment) and extract the | ||
| # repeating suffix after it. | ||
| # e.g. "model.layers.0.mlp.experts.gate_up_proj" -> "mlp.experts.gate_up_proj" | ||
|
Comment on lines
+87
to
+90
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe we should add some checks for the word "experts" or for gate_up_proj / down_proj, since those seem to be the common names used. |
||
| for i, part in enumerate(parts): | ||
| if part.isdigit(): | ||
| suffix = ".".join(parts[i + 1 :]) | ||
| seen_suffixes.add(suffix) | ||
| break | ||
| return sorted(seen_suffixes) | ||
|
|
||
|
|
||
| def load_lora( | ||
| model: PreTrainedModel, | ||
| cfg: DictDefault, | ||
| inference: bool = False, | ||
| config_only: bool = False, | ||
| ) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]: | ||
| lora_target_modules = cfg.lora_target_modules or [] | ||
| lora_target_parameters = cfg.lora_target_parameters or [] | ||
|
|
||
| # Auto-detect MoE expert params for PEFT target_parameters (v5 3D nn.Parameter). | ||
| lora_target_parameters = cfg.lora_target_parameters | ||
| if lora_target_parameters is None: | ||
| detected_expert_params = getattr( | ||
| model, "_moe_expert_param_names", None | ||
| ) or find_moe_expert_param_names(model) | ||
| if detected_expert_params: | ||
| LOG.info( | ||
| "Auto-detected MoE expert parameters for LoRA target_parameters: %s", | ||
| detected_expert_params, | ||
| ) | ||
| lora_target_parameters = detected_expert_params | ||
| else: | ||
| lora_target_parameters = [] | ||
| elif isinstance(lora_target_parameters, str): | ||
| lora_target_parameters = [lora_target_parameters] | ||
|
|
||
| # Exclude ParametrizationList submodules created by replace_parameter_4bit | ||
| # from target_modules matching (regex needed — "parametrizations" is mid-path). | ||
| exclude_modules = None | ||
| if getattr(model, "_moe_experts_quantized", False): | ||
| exclude_modules = r".*\.parametrizations\..*" | ||
|
|
||
| if cfg.lora_target_linear: | ||
| linear_names = find_all_linear_names(model) | ||
|
|
@@ -119,6 +170,7 @@ def load_lora( | |
| lora_alpha=cfg.lora_alpha, | ||
| target_modules=lora_target_modules, | ||
| target_parameters=lora_target_parameters, | ||
| exclude_modules=exclude_modules, | ||
| layers_to_transform=cfg.peft_layers_to_transform, | ||
| layers_pattern=cfg.peft_layers_pattern, | ||
| lora_dropout=cfg.lora_dropout, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -173,6 +173,20 @@ def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | Non | |
| PLUGIN_MANAGER.pre_model_load(self.cfg) | ||
| self.patch_manager.apply_post_plugin_pre_model_load_patches() | ||
| skip_move_to_device = self._build_model() | ||
|
|
||
| # Quantize 3D MoE expert nn.Parameter tensors that BnB skips. | ||
| # Detect names before quantization (replace_parameter_4bit changes them). | ||
| self.model._moe_experts_quantized = False | ||
| self.model._moe_expert_param_names = [] | ||
| if self.cfg.adapter in ("qlora", "lora") and ( | ||
| self.cfg.load_in_4bit or self.cfg.load_in_8bit | ||
| ): | ||
| from axolotl.loaders.adapter import find_moe_expert_param_names | ||
| from axolotl.monkeypatch.moe_quant import quantize_moe_expert_params | ||
|
|
||
| self.model._moe_expert_param_names = find_moe_expert_param_names(self.model) | ||
| self.model._moe_experts_quantized = quantize_moe_expert_params(self.model) | ||
|
Collaborator
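There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is called for LoRA as well, despite the inner function calling [inline code reference lost in extraction; presumably replace_parameter_4bit, which is 4-bit only].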
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Narrowed the guard to load_in_4bit only. |
||
|
|
||
| PLUGIN_MANAGER.post_model_build(self.cfg, self.model) | ||
|
|
||
| # Post-build model configuration | ||
|
|
@@ -860,6 +874,10 @@ def _prepare_model_for_quantization(self): | |
| # Make sure everything is in the same dtype | ||
| skip_prepare_model_for_kbit_training = True | ||
|
|
||
| if getattr(self.model, "_moe_experts_quantized", False): | ||
| # Parametrized expert tensors dequantize on access — would OOM. | ||
| skip_prepare_model_for_kbit_training = True | ||
|
|
||
| if ( | ||
| not skip_prepare_model_for_kbit_training | ||
| and self.cfg.adapter in ["lora", "qlora"] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| """ | ||
| Post-load quantization for MoE expert weights stored as 3D nn.Parameter tensors. | ||
|
|
||
| In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter | ||
| tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets | ||
| nn.Linear, so these expert weights are skipped during model loading, causing OOM. | ||
|
|
||
| This module provides a post-load fixup that quantizes those skipped parameters using | ||
| bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0). | ||
| PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized | ||
| params via stacked parametrizations. | ||
| """ | ||
|
|
||
| import bitsandbytes as bnb | ||
| import torch | ||
|
|
||
| from axolotl.utils.logging import get_logger | ||
|
|
||
| LOG = get_logger(__name__) | ||
|
|
||
|
|
||
def find_unquantized_expert_params(model):
    """Find 3D+ nn.Parameter tensors that BnB quantization skipped.

    Walks every submodule of ``model`` except BnB ``Linear4bit`` /
    ``Linear8bitLt`` layers and collects the parameters owned directly by that
    submodule which have three or more dimensions and whose qualified name
    contains one of the expert keywords ("experts", "gate_up_proj",
    "down_proj").

    The keyword test is applied to the qualified name (module path plus
    parameter name) rather than the bare parameter name, so that a parameter
    living under an ``experts`` module is recognised even when its own name
    carries no keyword. This matches the full-name test used when selecting
    MoE expert parameters as LoRA ``target_parameters``, so every parameter
    selected there is also quantized here.

    Args:
        model: Module tree to scan; must provide ``named_modules()``.

    Returns:
        List of (module, param_name) tuples to quantize, where ``param_name``
        is the attribute name of the parameter on ``module``.
    """
    expert_keywords = ("experts", "gate_up_proj", "down_proj")
    params_to_quantize = []
    for module_name, module in model.named_modules():
        # Layers already handled by bitsandbytes must not be touched again.
        if isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
            continue
        # recurse=False: each parameter is visited once, on its owning module.
        for param_name, param in module.named_parameters(recurse=False):
            if param.ndim < 3:
                continue
            # The root module has an empty name; avoid a leading dot there.
            qualified_name = (
                f"{module_name}.{param_name}" if module_name else param_name
            )
            if any(kw in qualified_name for kw in expert_keywords):
                params_to_quantize.append((module, param_name))
    return params_to_quantize
|
|
||
|
|
||
def quantize_moe_expert_params(model, quant_type=None, compress_statistics=None):
    """Quantize 3D nn.Parameter expert weights that BnB skips during model loading.

    Reads quant_type and compress_statistics from the model's quantization_config
    when not explicitly provided, so that the same settings used for nn.Linear
    quantization are applied to the MoE expert parameters.

    Args:
        model: Loaded model whose expert parameters should be quantized in place.
        quant_type: 4-bit quantization type passed to ``replace_parameter_4bit``.
            When ``None`` it is taken from ``bnb_4bit_quant_type`` on the model's
            quantization config, falling back to ``"nf4"``.
        compress_statistics: Passed to ``replace_parameter_4bit``. When ``None``
            it is taken from ``bnb_4bit_use_double_quant`` on the model's
            quantization config, falling back to ``True``.

    Returns:
        ``True`` if at least one parameter was quantized, ``False`` if there was
        nothing to quantize.

    Raises:
        ImportError: If parameters need quantizing but the installed bitsandbytes
            does not provide ``bitsandbytes.nn.parametrize`` (needs >= 0.48.0).
    """
    params_to_quantize = find_unquantized_expert_params(model)
    if not params_to_quantize:
        return False

    # Imported only once there is real work to do: models without 3D expert
    # parameters must keep loading on bitsandbytes versions that predate
    # bitsandbytes.nn.parametrize.
    try:
        from bitsandbytes.nn.parametrize import replace_parameter_4bit
    except ImportError as err:
        raise ImportError(
            "Quantizing MoE expert parameters requires bitsandbytes >= 0.48.0 "
            "(bitsandbytes.nn.parametrize.replace_parameter_4bit is unavailable)."
        ) from err

    # Derive settings from model's BnB config if not explicitly provided
    if quant_type is None or compress_statistics is None:
        bnb_config = getattr(model.config, "quantization_config", None)
        if bnb_config is not None:
            if quant_type is None:
                quant_type = getattr(bnb_config, "bnb_4bit_quant_type", "nf4")
            if compress_statistics is None:
                compress_statistics = getattr(
                    bnb_config, "bnb_4bit_use_double_quant", True
                )
    # Final defaults (also covers a config attribute that is explicitly None)
    if quant_type is None:
        quant_type = "nf4"
    if compress_statistics is None:
        compress_statistics = True

    for module, param_name in params_to_quantize:
        replace_parameter_4bit(
            module,
            param_name,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )

    # Release the cached blocks that held the full-precision expert weights.
    torch.cuda.empty_cache()
    LOG.info(
        "Quantized %d MoE expert parameters to 4-bit (quant_type=%s, compress_statistics=%s)",
        len(params_to_quantize),
        quant_type,
        compress_statistics,
    )
    return True
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should explicitly set the LoRA target parameters so it's clear that they are being trained here.
It also doesn't seem possible to not target those layers; it seems to always be on.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done !