feat-qgalore#3654
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughThis PR adds complete support for the Q-GaLore optimizer to Axolotl by introducing utility functions for bitsandbytes compatibility and parameter grouping, extending the configuration schema with Q-GaLore-specific hyperparameters, validating incompatibilities and constraints, integrating the optimizer into the trainer builder, testing end-to-end training, and documenting the optimizer for users. ChangesQ-GaLore Optimizer Support
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 3 | ❌ 2❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/utils/optimizers/qgalore.py (1)
29-31: ⚡ Quick winConsider unpacking syntax for clearer tuple construction.
Static analysis suggests using unpacking syntax instead of tuple concatenation, which is more idiomatic and readable in Python.
♻️ Proposed refactor
optimizer_update_8bit_blockwise=( lambda *a, **kw: bw( - *(a[:7] + (0.0, 0.0) + a[7:] if len(a) == 15 else a), **kw + *((*a[:7], 0.0, 0.0, *a[7:]) if len(a) == 15 else a), **kw ) ), optimizer_update_32bit=( lambda *a, **kw: fp32( - *(a[:10] + (0.0, 0.0) + a[10:] if len(a) == 13 else a), **kw + *((*a[:10], 0.0, 0.0, *a[10:]) if len(a) == 13 else a), **kw ) ),As per coding guidelines, Ruff static analysis tool flagged RUF005: prefer unpacking over concatenation.
Also applies to: 34-36
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/utils/optimizers/qgalore.py` around lines 29 - 31, Replace the tuple-concatenation used in the wrapper lambda with Python unpacking for readability and to satisfy RUF005: instead of a[:7] + (0.0, 0.0) + a[7:], construct the args as (*a[:7], 0.0, 0.0, *a[7:]) inside the lambda that wraps bw (the anonymous lambda that calls bw(*(…), **kw)); apply the same unpacking change to the other similar wrapper occurrences around the bw calls referenced in the diff.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/axolotl/utils/schemas/validation.py`:
- Around line 916-954: In check_qgalore, add explicit validation to reject
quantized model-loading flags when optimizer == "q_galore_adamw8bit": check
data.get("load_in_8bit") and data.get("load_in_4bit") (and any equivalent keys
used elsewhere, e.g., "bnb_4bit") and raise a ValueError with a clear message
that q_galore_adamw8bit is incompatible with those settings; update the function
(check_qgalore) so these checks occur alongside the existing
adapter/deepspeed/fsdp checks before returning data.
---
Nitpick comments:
In `@src/axolotl/utils/optimizers/qgalore.py`:
- Around line 29-31: Replace the tuple-concatenation used in the wrapper lambda
with Python unpacking for readability and to satisfy RUF005: instead of a[:7] +
(0.0, 0.0) + a[7:], construct the args as (*a[:7], 0.0, 0.0, *a[7:]) inside the
lambda that wraps bw (the anonymous lambda that calls bw(*(…), **kw)); apply the
same unpacking change to the other similar wrapper occurrences around the bw
calls referenced in the diff.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: f9165850-7c40-42c0-9b9a-5f607db7dda9
📒 Files selected for processing (8)
docs/optimizers.qmdpyproject.tomlsrc/axolotl/core/builders/base.pysrc/axolotl/utils/optimizers/qgalore.pysrc/axolotl/utils/schemas/training.pysrc/axolotl/utils/schemas/validation.pytests/e2e/test_optimizers.pytests/utils/schemas/validation/test_qgalore.py
| def check_qgalore(cls, data): | ||
| if data.get("optimizer") != "q_galore_adamw8bit": | ||
| return data | ||
| adapter = data.get("adapter") | ||
| if adapter: | ||
| raise ValueError( | ||
| "q_galore_adamw8bit operates on full-precision parameters and is " | ||
| f"incompatible with adapter='{adapter}'. Remove the adapter setting " | ||
| "or pick a different optimizer." | ||
| ) | ||
| if data.get("deepspeed"): | ||
| raise ValueError( | ||
| "q_galore_adamw8bit is not yet validated with DeepSpeed. " | ||
| "Use DDP or FSDP2 with use_orig_params=True." | ||
| ) | ||
| if data.get("fsdp") or data.get("fsdp_config"): | ||
| fsdp_version = cls._resolve_fsdp_version(data) | ||
| if str(fsdp_version) != "2": | ||
| raise ValueError( | ||
| "q_galore_adamw8bit requires FSDP2. Set fsdp_version: 2." | ||
| ) | ||
| fsdp_config = data.get("fsdp_config") or {} | ||
| if fsdp_config.get("use_orig_params") is not True: | ||
| raise ValueError( | ||
| "q_galore_adamw8bit requires fsdp_config.use_orig_params=True so " | ||
| "that per-parameter projection state survives FSDP sharding." | ||
| ) | ||
| if not (data.get("bf16") or data.get("bfloat16") or data.get("fp16")): | ||
| LOG.warning( | ||
| "q_galore_adamw8bit benefits from mixed-precision (bf16/fp16). " | ||
| "Running in fp32 will negate most of the memory savings." | ||
| ) | ||
| if data.get("optim_target_modules") is None: | ||
| # Match the reference impl's defaults: attention + MLP linears. | ||
| data["optim_target_modules"] = [ | ||
| "attn", | ||
| "mlp", | ||
| ] | ||
| return data |
There was a problem hiding this comment.
Missing validation for incompatible quantization settings.
The documentation (optimizers.qmd:144-146) states that q_galore_adamw8bit is incompatible with load_in_8bit and load_in_4bit, but the validator only checks for adapter and doesn't reject these quantization options. This could lead to runtime errors or undefined behavior.
🛡️ Proposed fix to add validation
`@classmethod`
def check_qgalore(cls, data):
if data.get("optimizer") != "q_galore_adamw8bit":
return data
adapter = data.get("adapter")
if adapter:
raise ValueError(
"q_galore_adamw8bit operates on full-precision parameters and is "
f"incompatible with adapter='{adapter}'. Remove the adapter setting "
"or pick a different optimizer."
)
+ if data.get("load_in_8bit"):
+ raise ValueError(
+ "q_galore_adamw8bit is incompatible with load_in_8bit. "
+ "Use full-precision model loading."
+ )
+ if data.get("load_in_4bit"):
+ raise ValueError(
+ "q_galore_adamw8bit is incompatible with load_in_4bit. "
+ "Use full-precision model loading."
+ )
if data.get("deepspeed"):📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def check_qgalore(cls, data): | |
| if data.get("optimizer") != "q_galore_adamw8bit": | |
| return data | |
| adapter = data.get("adapter") | |
| if adapter: | |
| raise ValueError( | |
| "q_galore_adamw8bit operates on full-precision parameters and is " | |
| f"incompatible with adapter='{adapter}'. Remove the adapter setting " | |
| "or pick a different optimizer." | |
| ) | |
| if data.get("deepspeed"): | |
| raise ValueError( | |
| "q_galore_adamw8bit is not yet validated with DeepSpeed. " | |
| "Use DDP or FSDP2 with use_orig_params=True." | |
| ) | |
| if data.get("fsdp") or data.get("fsdp_config"): | |
| fsdp_version = cls._resolve_fsdp_version(data) | |
| if str(fsdp_version) != "2": | |
| raise ValueError( | |
| "q_galore_adamw8bit requires FSDP2. Set fsdp_version: 2." | |
| ) | |
| fsdp_config = data.get("fsdp_config") or {} | |
| if fsdp_config.get("use_orig_params") is not True: | |
| raise ValueError( | |
| "q_galore_adamw8bit requires fsdp_config.use_orig_params=True so " | |
| "that per-parameter projection state survives FSDP sharding." | |
| ) | |
| if not (data.get("bf16") or data.get("bfloat16") or data.get("fp16")): | |
| LOG.warning( | |
| "q_galore_adamw8bit benefits from mixed-precision (bf16/fp16). " | |
| "Running in fp32 will negate most of the memory savings." | |
| ) | |
| if data.get("optim_target_modules") is None: | |
| # Match the reference impl's defaults: attention + MLP linears. | |
| data["optim_target_modules"] = [ | |
| "attn", | |
| "mlp", | |
| ] | |
| return data | |
| def check_qgalore(cls, data): | |
| if data.get("optimizer") != "q_galore_adamw8bit": | |
| return data | |
| adapter = data.get("adapter") | |
| if adapter: | |
| raise ValueError( | |
| "q_galore_adamw8bit operates on full-precision parameters and is " | |
| f"incompatible with adapter='{adapter}'. Remove the adapter setting " | |
| "or pick a different optimizer." | |
| ) | |
| if data.get("load_in_8bit"): | |
| raise ValueError( | |
| "q_galore_adamw8bit is incompatible with load_in_8bit. " | |
| "Use full-precision model loading." | |
| ) | |
| if data.get("load_in_4bit"): | |
| raise ValueError( | |
| "q_galore_adamw8bit is incompatible with load_in_4bit. " | |
| "Use full-precision model loading." | |
| ) | |
| if data.get("deepspeed"): | |
| raise ValueError( | |
| "q_galore_adamw8bit is not yet validated with DeepSpeed. " | |
| "Use DDP or FSDP2 with use_orig_params=True." | |
| ) | |
| if data.get("fsdp") or data.get("fsdp_config"): | |
| fsdp_version = cls._resolve_fsdp_version(data) | |
| if str(fsdp_version) != "2": | |
| raise ValueError( | |
| "q_galore_adamw8bit requires FSDP2. Set fsdp_version: 2." | |
| ) | |
| fsdp_config = data.get("fsdp_config") or {} | |
| if fsdp_config.get("use_orig_params") is not True: | |
| raise ValueError( | |
| "q_galore_adamw8bit requires fsdp_config.use_orig_params=True so " | |
| "that per-parameter projection state survives FSDP sharding." | |
| ) | |
| if not (data.get("bf16") or data.get("bfloat16") or data.get("fp16")): | |
| LOG.warning( | |
| "q_galore_adamw8bit benefits from mixed-precision (bf16/fp16). " | |
| "Running in fp32 will negate most of the memory savings." | |
| ) | |
| if data.get("optim_target_modules") is None: | |
| # Match the reference impl's defaults: attention + MLP linears. | |
| data["optim_target_modules"] = [ | |
| "attn", | |
| "mlp", | |
| ] | |
| return data |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/axolotl/utils/schemas/validation.py` around lines 916 - 954, In
check_qgalore, add explicit validation to reject quantized model-loading flags
when optimizer == "q_galore_adamw8bit": check data.get("load_in_8bit") and
data.get("load_in_4bit") (and any equivalent keys used elsewhere, e.g.,
"bnb_4bit") and raise a ValueError with a clear message that q_galore_adamw8bit
is incompatible with those settings; update the function (check_qgalore) so
these checks occur alongside the existing adapter/deepspeed/fsdp checks before
returning data.
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
Description
feat-qgalore
https://arxiv.org/pdf/2407.08296
Motivation and Context
#1752
How has this been tested?
unit test + manual run
AI Usage Disclaimer
claude opus helped with ideation and testing
Summary by CodeRabbit
Release Notes
New Features
q_galore_adamw8bit) with configurable rank, projection, and quantization parameters for memory-efficient fine-tuning.Documentation