From d1ebe6e480aca93164ec54a5098569f71c9c817a Mon Sep 17 00:00:00 2001 From: Avaya Aggarwal Date: Sat, 21 Mar 2026 23:29:24 +0530 Subject: [PATCH 01/12] feat: Implement Q-GaLore optimizer and custom embedding learning rate in the Unsloth trainer. --- tests/utils/test_q_galore.py | 317 +++++++++++++++++++ unsloth/optimizers/__init__.py | 21 ++ unsloth/optimizers/q_galore_adamw.py | 370 +++++++++++++++++++++++ unsloth/optimizers/q_galore_projector.py | 362 ++++++++++++++++++++++ unsloth/trainer.py | 94 +++++- 5 files changed, 1162 insertions(+), 2 deletions(-) create mode 100644 tests/utils/test_q_galore.py create mode 100644 unsloth/optimizers/__init__.py create mode 100644 unsloth/optimizers/q_galore_adamw.py create mode 100644 unsloth/optimizers/q_galore_projector.py diff --git a/tests/utils/test_q_galore.py b/tests/utils/test_q_galore.py new file mode 100644 index 0000000000..24a159f6d1 --- /dev/null +++ b/tests/utils/test_q_galore.py @@ -0,0 +1,317 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Tests for Q-GaLore integration (unsloth/optimizers/). + +import pytest +import sys +import os +import torch +import torch.nn as nn + +# Import the optimizers module directly to avoid triggering unsloth.__init__ +# which requires unsloth_zoo and other heavy dependencies. +_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +_optimizers_dir = os.path.join(_repo_root, "unsloth", "optimizers") +if _repo_root not in sys.path: + sys.path.insert(0, _repo_root) + +# Direct import of the actual modules (avoids unsloth/__init__.py) +import importlib.util + +def _load_module(name, filepath): + spec = importlib.util.spec_from_file_location(name, filepath) + mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod + spec.loader.exec_module(mod) + return mod + +# Load projector module first (no dependencies on unsloth) +_projector_mod = _load_module( + "unsloth.optimizers.q_galore_projector", + os.path.join(_optimizers_dir, "q_galore_projector.py"), +) +GaLoreProjector = _projector_mod.GaLoreProjector +_quantize = _projector_mod._quantize +_dequantize = _projector_mod._dequantize +_quantize_stochastic = _projector_mod._quantize_stochastic + +# Load adamw module (depends on projector, may skip bitsandbytes) +_adamw_mod = _load_module( + "unsloth.optimizers.q_galore_adamw", + os.path.join(_optimizers_dir, "q_galore_adamw.py"), +) +make_q_galore_param_groups = _adamw_mod.make_q_galore_param_groups + +# ====================================================================== +# Projector tests +# ====================================================================== + + +class TestGaLoreProjector: + """Tests for the GaLore low-rank gradient projector.""" + + def test_project_and_back_tall(self): + """Project → project_back preserves shape for tall matrices.""" + proj = GaLoreProjector(rank=4, update_proj_gap=1) + grad = torch.randn(16, 8) # tall + low = proj.project(grad, step=0) + assert low.shape == (16, 4) + + full = proj.project_back(low) + assert full.shape == grad.shape + + def test_project_and_back_wide(self): + """Project → project_back preserves shape for wide matrices.""" + proj = GaLoreProjector(rank=4, update_proj_gap=1) + grad = torch.randn(8, 16) # wide + low = proj.project(grad, step=0) + assert low.shape == (4, 16) + + full = proj.project_back(low) + assert full.shape == grad.shape + + def test_project_reuses_cached_svd(self): + """SVD is not recomputed when step is not a multiple of update_proj_gap.""" + proj = GaLoreProjector(rank=4, update_proj_gap=100) + grad = torch.randn(16, 8) + proj.project(grad, step=0) + assert proj.svd_count == 1 + + proj.project(grad, step=1) + assert proj.svd_count == 1 # No recomputation + + proj.project(grad, step=100) + assert proj.svd_count == 2 # Recomputed + + def test_quantized_projection(self): + """Quantized projection matrix stores and restores with bounded error.""" + proj = GaLoreProjector(rank=4, update_proj_gap=1, quant=True, n_bit=8) + grad = torch.randn(16, 8) + low = proj.project(grad, step=0) + assert low.shape == (16, 4) + + # The projection matrix should be stored as uint8 + assert proj.ortho_matrix.dtype == torch.uint8 + + def test_quantized_projection_int4(self): + """INT4 quantized projection stores correctly.""" + proj = GaLoreProjector(rank=4, update_proj_gap=1, quant=True, n_bit=4) + grad = torch.randn(16, 8) + proj.project(grad, step=0) + assert proj.ortho_matrix.dtype == torch.uint8 + # INT4 values should be in range [0, 15] + assert proj.ortho_matrix.max() <= 15 + + def test_adaptive_scheduling(self): + """update_proj_gap increases when cosine similarity exceeds threshold.""" + proj = GaLoreProjector( + rank=4, + update_proj_gap=10, + cos_threshold=0.0, # Very low threshold → always triggers + gamma_proj=2.0, + queue_size=2, + ) + # Use very similar gradients so cosine similarity is high + base_grad = torch.randn(16, 8) + for i in range(5): + grad = base_grad + torch.randn_like(base_grad) * 0.001 + proj.project(grad, step=i * 10) + + # After several similar SVDs, update_proj_gap should have increased + assert proj.update_proj_gap > 10 + + def test_scale_applied(self): + """project_back applies the scale factor.""" + proj = GaLoreProjector(rank=4, update_proj_gap=1, scale=0.5) + grad = torch.randn(16, 8) + low = proj.project(grad, step=0) + + proj2 = GaLoreProjector(rank=4, update_proj_gap=1, scale=1.0) + low2 = proj2.project(grad, step=0) + + full_half = proj.project_back(low) + full_one = proj2.project_back(low2) + + # The ratio should be approximately 0.5 + ratio = full_half.norm() / full_one.norm() + assert abs(ratio - 0.5) < 0.15 + + +# ====================================================================== +# Quantization utility tests +# ====================================================================== + + +class TestQuantizationUtils: + """Tests for _quantize, _dequantize, _quantize_stochastic.""" + + def test_quantize_dequantize_roundtrip(self): + """Quantize → dequantize has bounded error.""" + w = torch.randn(32, 64) + q, scales, zeros, shape = _quantize(w, n_bit=8) + w_hat = _dequantize(q, scales, zeros, shape) + + # Error should be bounded by the quantization step size + error = (w - w_hat).abs().max() + assert error < 0.1, f"Max error {error} exceeds threshold" + + def test_quantize_group_roundtrip(self): + """Grouped quantization → dequantization has bounded error.""" + w = torch.randn(32, 64) + q, scales, zeros, shape = _quantize(w, q_group_size=32, n_bit=8) + w_hat = _dequantize(q, scales, zeros, shape) + error = (w - w_hat).abs().max() + assert error < 0.1 + + def test_quantize_dtype(self): + """Quantized output should be uint8.""" + w = torch.randn(16, 16) + q, _, _, _ = _quantize(w, n_bit=8) + assert q.dtype == torch.uint8 + + def test_quantize_int4_range(self): + """INT4 values should be in [0, 15].""" + w = torch.randn(16, 16) + q, _, _, _ = _quantize(w, n_bit=4) + assert q.max() <= 15 + assert q.min() >= 0 + + def test_stochastic_rounding_unbiased(self): + """Stochastic rounding should be approximately unbiased.""" + torch.manual_seed(42) + w = torch.randn(64, 64) + errors = [] + for _ in range(50): + q, scales, zeros, shape = _quantize_stochastic(w, n_bit=8) + w_hat = _dequantize(q, scales, zeros, shape) + errors.append((w - w_hat).mean().item()) + + mean_error = sum(errors) / len(errors) + assert abs(mean_error) < 0.01, ( + f"Mean error {mean_error} suggests biased rounding" + ) + + +# ====================================================================== +# Param group helper tests +# ====================================================================== + + +class TestParamGroupHelper: + """Tests for make_q_galore_param_groups.""" + + def test_param_group_separation(self): + """GaLore vs non-GaLore params are correctly separated.""" + + # Create a mini-transformer-like model + model = nn.Module() + model.q_proj = nn.Linear(64, 64, bias=False) + model.k_proj = nn.Linear(64, 64, bias=False) + model.embed = nn.Embedding(100, 64) + model.norm = nn.LayerNorm(64) + + groups = make_q_galore_param_groups(model, rank=8, weight_quant=False) + + # Should have 2 groups: galore and non-galore + assert len(groups) == 2 + + galore_group = [g for g in groups if "rank" in g][0] + non_galore_group = [g for g in groups if "rank" not in g][0] + + # q_proj and k_proj should be in galore group (2 params) + assert len(galore_group["params"]) == 2 + # embed and norm should be in non-galore group + assert len(non_galore_group["params"]) == 3 # embed weight + norm weight + norm bias + + def test_custom_target_modules(self): + """Custom target_modules narrows GaLore scope.""" + + model = nn.Module() + model.q_proj = nn.Linear(64, 64, bias=False) + model.k_proj = nn.Linear(64, 64, bias=False) + model.v_proj = nn.Linear(64, 64, bias=False) + model.embed = nn.Embedding(100, 64) + + groups = make_q_galore_param_groups( + model, rank=8, target_modules=["q_proj"], weight_quant=False, + ) + + galore_group = [g for g in groups if "rank" in g][0] + assert len(galore_group["params"]) == 1 # Only q_proj + + +# ====================================================================== +# Optimizer tests (CPU-only, no bitsandbytes dependency) +# ====================================================================== + + +class TestQGaLoreIntegration: + """Integration tests that work without bitsandbytes on CPU.""" + + def test_projector_training_loop(self): + """A simple training loop using manual GaLore projection converges.""" + torch.manual_seed(42) + + # Tiny model: single linear layer + model = nn.Linear(32, 16, bias=False) + target = torch.randn(4, 16) + x = torch.randn(4, 32) + + proj = GaLoreProjector(rank=8, update_proj_gap=1, scale=1.0) + optimizer = torch.optim.AdamW(model.parameters(), lr=0.01) + + losses = [] + for step in range(20): + optimizer.zero_grad() + out = model(x) + loss = nn.functional.mse_loss(out, target) + loss.backward() + losses.append(loss.item()) + + # Manual GaLore projection + for p in model.parameters(): + if p.grad is not None and p.grad.dim() == 2: + low = proj.project(p.grad, step) + p._saved = p.data.clone() + update = torch.zeros_like(low) + update.add_(low) # Simplified update + full_update = proj.project_back(update) + p.grad.copy_(full_update) + + optimizer.step() + + # Loss should decrease + assert losses[-1] < losses[0], ( + f"Loss did not decrease: {losses[0]:.4f} → {losses[-1]:.4f}" + ) + + def test_full_projector_roundtrip_quality(self): + """project → project_back captures the dominant gradient directions.""" + torch.manual_seed(42) + # Create a gradient with clear low-rank structure + u = torch.randn(32, 4) + v = torch.randn(4, 16) + grad = u @ v # rank-4 gradient + + proj = GaLoreProjector(rank=4, update_proj_gap=1, scale=1.0) + low = proj.project(grad, step=0) + reconstructed = proj.project_back(low) + + # For a rank-4 gradient with rank-4 projection, reconstruction + # should be very close to original + relative_error = (grad - reconstructed).norm() / grad.norm() + assert relative_error < 0.05, ( + f"Reconstruction error too high: {relative_error:.4f}" + ) diff --git a/unsloth/optimizers/__init__.py b/unsloth/optimizers/__init__.py new file mode 100644 index 0000000000..b126321ab1 --- /dev/null +++ b/unsloth/optimizers/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .q_galore_projector import GaLoreProjector +from .q_galore_adamw import QGaLoreAdamW8bit + +__all__ = [ + "GaLoreProjector", + "QGaLoreAdamW8bit", +] diff --git a/unsloth/optimizers/q_galore_adamw.py b/unsloth/optimizers/q_galore_adamw.py new file mode 100644 index 0000000000..85b933e8e9 --- /dev/null +++ b/unsloth/optimizers/q_galore_adamw.py @@ -0,0 +1,370 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Adapted from Q-GaLore (https://github.com/VITA-Group/Q-GaLore) +# Original paper: "Q-GaLore: Quantized GaLore with INT4 Projection and +# Layer-Adaptive Low-Rank Gradients" (arXiv:2407.08296) + +import torch +from collections import defaultdict +from typing import Optional, List + +from .q_galore_projector import ( + GaLoreProjector, + _quantize, + _quantize_stochastic, + _dequantize, +) + +__all__ = ["QGaLoreAdamW8bit"] + +try: + import bitsandbytes.functional as bnb_F + from bitsandbytes.optim.optimizer import Optimizer2State + + _HAS_BNB = True +except ImportError: + _HAS_BNB = False + # Provide a fallback base so the module can at least be imported. + Optimizer2State = torch.optim.Optimizer + + +def _require_bnb(): + if not _HAS_BNB: + raise ImportError( + "Unsloth: Q-GaLore requires bitsandbytes. " + "Install it with: pip install bitsandbytes" + ) + + +class QGaLoreAdamW8bit(Optimizer2State): + """AdamW optimizer with 8-bit states, GaLore low-rank gradient projection, + and optional INT8 weight quantization. + + This optimizer combines three memory-saving techniques: + + 1. **8-bit optimizer states** (via bitsandbytes) — Adam's first and second + moments are stored in 8-bit, reducing optimizer state memory by ~4×. + + 2. **GaLore low-rank gradient projection** — gradients are projected into a + low-rank subspace before the optimizer step, then projected back. The + projection matrix itself can be quantized to INT4. + + 3. **INT8 weight quantization** — model weights are stored in INT8 during + training with stochastic rounding, reducing weight memory by ~2× for + eligible layers. + + Param group keys consumed by GaLore projection: + ``rank``, ``update_proj_gap``, ``scale``, ``proj_type``, + ``quant`` (projection quantization), ``quant_group_size``, + ``quant_n_bit``, ``cos_threshold``, ``gamma_proj``, ``queue_size`` + + Param group keys for weight quantization: + ``weight_quant``, ``stochastic_round``, ``weight_group_size`` + """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + min_8bit_size: int = 4096, + percentile_clipping: int = 100, + block_wise: bool = True, + is_paged: bool = False, + ): + _require_bnb() + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, # optim_bits + None, # args + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + + # ------------------------------------------------------------------ + # Core step + # ------------------------------------------------------------------ + + @torch.no_grad() + def step(self, closure=None): + """Perform a single optimization step. + + For each parameter that has a ``rank`` key in its param group, the + following sequence is executed: + + 1. If ``weight_quant`` is set, dequantize the INT8 weight to float. + 2. Project the gradient to low-rank via the cached ``GaLoreProjector``. + 3. Perform the 8-bit Adam update in the low-rank space. + 4. Project the update back to full rank and add to saved weight. + 5. If ``weight_quant`` is set, re-quantize the weight to INT8. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self.initialized: + self.check_overrides() + self.to_gpu() + self.initialized = True + + for gindex, group in enumerate(self.param_groups): + for pindex, p in enumerate(group["params"]): + if p.grad is None: + continue + + state = self.state[p] + if "step" not in state: + state["step"] = 0 + + has_weight_quant = self._has_weight_quant(p, group) + + # --- Dequantize weight if INT8 --- + if has_weight_quant: + float_weight = _dequantize( + p.data, p._q_scales, p._q_zeros, p._q_shape, + ) + p.data = float_weight.clone().to(p.data.device) + + # --- GaLore projection --- + if "rank" in group: + if "projector" not in state: + state["projector"] = GaLoreProjector( + rank=group["rank"], + update_proj_gap=group.get("update_proj_gap", 200), + scale=group.get("scale", 0.25), + proj_type=group.get("proj_type", "std"), + quant=group.get("quant", False), + group_size=group.get("quant_group_size", -1), + n_bit=group.get("quant_n_bit", 4), + cos_threshold=group.get("cos_threshold", 0.4), + gamma_proj=group.get("gamma_proj", 2.0), + queue_size=group.get("queue_size", 5), + ) + + # Temporarily disable weight decay for GaLore params + # (we apply it manually after project-back) + if "weight_decay" in group and group["weight_decay"] > 0: + group["_wd_saved"] = group["weight_decay"] + group["weight_decay"] = 0 + + grad = state["projector"].project(p.grad, state["step"]) + + # Save current weight; replace p.data with zeros so + # the 8-bit update writes the pure weight delta. + p._saved_data = p.data.clone() + p.data = torch.zeros_like(grad, dtype=p.data.dtype, device=p.data.device) + p.grad = grad + + # --- 8-bit Adam update --- + if "state1" not in state: + self.init_state(group, p, gindex, pindex) + + self.prefetch_state(p) + self.update_step(group, p, gindex, pindex) + torch.cuda.synchronize() + + # --- GaLore project-back --- + if "rank" in group: + # p.data now holds the weight update in low-rank space + p.data = p._saved_data.add_( + state["projector"].project_back(p.data) + ) + del p._saved_data + + # Re-apply weight decay + if "_wd_saved" in group: + p.data.add_( + p.data, + alpha=-group["lr"] * group["_wd_saved"], + ) + group["weight_decay"] = group["_wd_saved"] + del group["_wd_saved"] + + # --- Re-quantize weight to INT8 --- + if has_weight_quant: + saved = p.data.clone() + stochastic = group.get("stochastic_round", True) + gsize = group.get("weight_group_size", 128) + quant_fn = _quantize_stochastic if stochastic else _quantize + q, scales, zeros, shape = quant_fn(saved, q_group_size=gsize) + p.data = q.to(p.data.device) + p._q_scales = scales + p._q_zeros = zeros + p._q_shape = shape + + state["step"] += 1 + + if self.is_paged: + torch.cuda.synchronize() + + return loss + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _has_weight_quant(p: torch.Tensor, group: dict) -> bool: + """Check if this parameter uses INT8 weight quantization.""" + return group.get("weight_quant", False) and hasattr(p, "_q_scales") + + @staticmethod + def init_weight_quantization( + model: torch.nn.Module, + param_groups: list, + group_size: int = 128, + stochastic: bool = True, + ) -> None: + """Initialize INT8 weight quantization for params in groups with + ``weight_quant=True``. + + This should be called once before the first optimizer step. It + quantizes eligible weights in-place and stores the quantization + metadata (scales, zeros, shape) as attributes on the parameter tensor. + """ + weight_quant_params = set() + for group in param_groups: + if group.get("weight_quant", False): + for p in group["params"]: + weight_quant_params.add(id(p)) + + for name, p in model.named_parameters(): + if id(p) in weight_quant_params: + quant_fn = _quantize_stochastic if stochastic else _quantize + q, scales, zeros, shape = quant_fn( + p.data, q_group_size=group_size, + ) + p.data = q.to(p.data.device) + p._q_scales = scales + p._q_zeros = zeros + p._q_shape = shape + p._stochastic_round = stochastic + p._weight_group_size = group_size + + +# ====================================================================== +# Param-group construction helper +# ====================================================================== + +# Default linear layer names in transformer blocks that should use GaLore. +_DEFAULT_GALORE_TARGETS = { + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", +} + + +def make_q_galore_param_groups( + model: torch.nn.Module, + lr: float = 1e-3, + weight_decay: float = 0.0, + rank: int = 256, + update_proj_gap: int = 200, + scale: float = 0.25, + proj_quant: bool = True, + proj_quant_group_size: int = -1, + proj_quant_n_bit: int = 4, + weight_quant: bool = True, + stochastic_round: bool = True, + weight_group_size: int = 128, + cos_threshold: float = 0.4, + gamma_proj: float = 2.0, + queue_size: int = 5, + target_modules: Optional[List[str]] = None, +) -> list: + """Build param groups suitable for :class:`QGaLoreAdamW8bit`. + + Parameters matching ``target_modules`` (or the default set of attention + and MLP projection names) are placed in the GaLore group. All other + trainable parameters go into the non-GaLore group. + + Args: + model: The model whose parameters to partition. + lr: Learning rate for all parameter groups. + weight_decay: Weight decay coefficient. + rank: GaLore projection rank. + update_proj_gap: Steps between SVD recomputations. + scale: Scaling factor for project-back. + proj_quant: Quantize projection matrices. + proj_quant_group_size: Group size for projection quantization. + proj_quant_n_bit: Bit-width for projection quantization. + weight_quant: Enable INT8 weight quantization for GaLore params. + stochastic_round: Use stochastic rounding for weight quantization. + weight_group_size: Group size for weight quantization. + cos_threshold: Cosine similarity threshold for adaptive scheduling. + gamma_proj: Multiplier for update_proj_gap when subspace is stable. + queue_size: Rolling window size for stability tracking. + target_modules: Module name substrings to match for GaLore. If None, + uses the default set of attention/MLP projection names. + + Returns: + List of two param group dicts: ``[galore_group, non_galore_group]``. + """ + targets = set(target_modules) if target_modules else _DEFAULT_GALORE_TARGETS + + galore_params = [] + non_galore_params = [] + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + # Check if any target module name appears as a component in the param name + is_galore = any(t in name for t in targets) + + if is_galore: + galore_params.append(param) + else: + non_galore_params.append(param) + + groups = [] + + if galore_params: + groups.append({ + "params": galore_params, + "lr": lr, + "weight_decay": weight_decay, + "rank": rank, + "update_proj_gap": update_proj_gap, + "scale": scale, + "proj_type": "std", + "quant": proj_quant, + "quant_group_size": proj_quant_group_size, + "quant_n_bit": proj_quant_n_bit, + "weight_quant": weight_quant, + "stochastic_round": stochastic_round, + "weight_group_size": weight_group_size, + "cos_threshold": cos_threshold, + "gamma_proj": gamma_proj, + "queue_size": queue_size, + }) + + if non_galore_params: + groups.append({ + "params": non_galore_params, + "lr": lr, + "weight_decay": weight_decay, + }) + + return groups diff --git a/unsloth/optimizers/q_galore_projector.py b/unsloth/optimizers/q_galore_projector.py new file mode 100644 index 0000000000..00c6bade26 --- /dev/null +++ b/unsloth/optimizers/q_galore_projector.py @@ -0,0 +1,362 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Adapted from Q-GaLore (https://github.com/VITA-Group/Q-GaLore) +# Original paper: "Q-GaLore: Quantized GaLore with INT4 Projection and +# Layer-Adaptive Low-Rank Gradients" (arXiv:2407.08296) + +import torch +import torch.nn.functional as F + +__all__ = ["GaLoreProjector"] + + +class GaLoreProjector: + """Low-rank gradient projector with optional INT4/INT8 quantized projection + matrices and layer-adaptive subspace update scheduling. + + The projector computes an SVD of the gradient to obtain an orthogonal basis + for the top-``rank`` subspace. Gradients are projected into this subspace + for the optimizer step, then projected back to full rank for the weight + update. + + Two key Q-GaLore innovations are implemented: + + 1. **Quantized projection matrices** — when ``quant=True``, the orthogonal + matrix is stored in INT4/INT8, reducing the memory cost of keeping the + projector state. + + 2. **Layer-adaptive update scheduling** — a rolling queue of cosine + similarities between consecutive orthogonal vectors is maintained. When + the average exceeds ``cos_threshold``, ``update_proj_gap`` is multiplied + by ``gamma_proj``, effectively reducing the frequency of expensive SVD + recomputations for layers whose subspace has stabilized. + + Args: + rank: Target rank for the low-rank projection. + update_proj_gap: Number of steps between SVD recomputations. + scale: Scaling factor applied when projecting back to full rank. + proj_type: Projection type. Only ``'std'`` is supported. + quant: Whether to quantize the projection matrix. + group_size: Group size for projection matrix quantization. + n_bit: Bit-width for projection matrix quantization (4 or 8). + cos_threshold: Cosine similarity threshold for adaptive scheduling. + gamma_proj: Multiplier for ``update_proj_gap`` on stability detection. + queue_size: Number of recent cosine similarities to average. + """ + + __slots__ = ( + "rank", "update_proj_gap", "scale", "proj_type", + "quant", "quant_group_size", "quant_n_bit", + "cos_threshold", "gamma_proj", "queue_size", + "ortho_matrix", "ortho_matrix_scales", "ortho_matrix_zeros", + "ortho_matrix_shape", "past_ortho_vector", "queue", "svd_count", + ) + + def __init__( + self, + rank: int, + update_proj_gap: int = 200, + scale: float = 1.0, + proj_type: str = "std", + quant: bool = False, + group_size: int = -1, + n_bit: int = 4, + cos_threshold: float = 0.4, + gamma_proj: float = 2.0, + queue_size: int = 5, + ): + self.rank = rank + self.update_proj_gap = update_proj_gap + self.scale = scale + self.proj_type = proj_type + + # Quantization settings for the projection matrix + self.quant = quant + self.quant_group_size = group_size + self.quant_n_bit = n_bit + + # Adaptive update scheduling state + self.cos_threshold = cos_threshold + self.gamma_proj = gamma_proj + self.queue_size = queue_size + self.past_ortho_vector = None + self.queue: list = [] + self.svd_count = 0 + + # Projection matrix state + self.ortho_matrix = None + self.ortho_matrix_scales = None + self.ortho_matrix_zeros = None + self.ortho_matrix_shape = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def project(self, full_rank_grad: torch.Tensor, step: int) -> torch.Tensor: + """Project a full-rank gradient into the low-rank subspace. + + The SVD is recomputed every ``update_proj_gap`` steps (subject to + adaptive scheduling). Between recomputations the cached orthogonal + matrix is reused. + + Args: + full_rank_grad: The full-rank gradient tensor (2-D). + step: The current optimizer step (0-indexed). + + Returns: + The low-rank gradient tensor. + """ + assert self.proj_type == "std", "Only proj_type='std' is supported." + + if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: + # "tall" matrix → right projection (grad @ Q^T) + if self.ortho_matrix is None or step % self.update_proj_gap == 0: + float_ortho = self._compute_orthogonal( + full_rank_grad, self.rank, side="right", + ) + self._update_adaptive_schedule(float_ortho, side="right") + self._store_ortho(float_ortho) + + float_ortho = self._load_ortho() + low_rank_grad = torch.matmul(full_rank_grad, float_ortho.t()) + else: + # "wide" matrix → left projection (Q^T @ grad) + if self.ortho_matrix is None or step % self.update_proj_gap == 0: + float_ortho = self._compute_orthogonal( + full_rank_grad, self.rank, side="left", + ) + self._update_adaptive_schedule(float_ortho, side="left") + self._store_ortho(float_ortho) + + float_ortho = self._load_ortho() + low_rank_grad = torch.matmul(float_ortho.t(), full_rank_grad) + + return low_rank_grad + + def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor: + """Project a low-rank update back to full rank. + + Args: + low_rank_grad: The low-rank gradient/update tensor. + + Returns: + The full-rank update scaled by ``self.scale``. + """ + float_ortho = self._load_ortho() + + if low_rank_grad.shape[0] >= low_rank_grad.shape[1]: + full_rank_grad = torch.matmul(low_rank_grad, float_ortho) + else: + full_rank_grad = torch.matmul(float_ortho, low_rank_grad) + + return full_rank_grad * self.scale + + # ------------------------------------------------------------------ + # SVD + # ------------------------------------------------------------------ + + @staticmethod + def _compute_orthogonal( + weights: torch.Tensor, rank: int, side: str, + ) -> torch.Tensor: + """Compute the top-``rank`` orthogonal matrix via truncated SVD. + + Args: + weights: 2-D tensor (typically the gradient). + rank: Number of singular vectors to keep. + side: ``'left'`` returns U[:, :rank], ``'right'`` returns Vh[:rank, :]. + + Returns: + Orthogonal matrix of shape ``(rank, N)`` (right) or ``(M, rank)`` (left). + """ + original_dtype = weights.dtype + original_device = weights.device + + matrix = weights.float() if original_dtype != torch.float32 else weights + + U, s, Vh = torch.linalg.svd(matrix, full_matrices=False) + + if side == "right": + result = Vh[:rank, :] + elif side == "left": + result = U[:, :rank] + else: + raise ValueError(f"side must be 'left' or 'right', got '{side}'") + + if original_dtype != torch.float32: + result = result.to(device=original_device, dtype=original_dtype) + return result + + # ------------------------------------------------------------------ + # Adaptive scheduling + # ------------------------------------------------------------------ + + def _update_adaptive_schedule( + self, float_ortho: torch.Tensor, side: str, + ) -> None: + """Track subspace stability and increase ``update_proj_gap`` if stable.""" + self.svd_count += 1 + + if side == "right": + current_vector = float_ortho[:1, :].flatten() + else: + current_vector = float_ortho[:, :1].flatten() + + if self.past_ortho_vector is not None: + cos_sim = F.cosine_similarity( + self.past_ortho_vector.unsqueeze(0), + current_vector.unsqueeze(0), + ).item() + + if len(self.queue) == self.queue_size: + self.queue.pop(0) + self.queue.append(cos_sim) + + if ( + len(self.queue) == self.queue_size + and sum(self.queue) / self.queue_size >= self.cos_threshold + ): + self.update_proj_gap = int(self.update_proj_gap * self.gamma_proj) + + self.past_ortho_vector = current_vector.clone() + + # ------------------------------------------------------------------ + # Quantized projection matrix storage + # ------------------------------------------------------------------ + + def _store_ortho(self, float_ortho: torch.Tensor) -> None: + """Store the orthogonal matrix, optionally quantized.""" + if self.quant: + q, scales, zeros, shape = _quantize( + float_ortho, + q_group_size=self.quant_group_size, + n_bit=self.quant_n_bit, + ) + self.ortho_matrix = q + self.ortho_matrix_scales = scales + self.ortho_matrix_zeros = zeros + self.ortho_matrix_shape = shape + else: + self.ortho_matrix = float_ortho + + def _load_ortho(self) -> torch.Tensor: + """Load the orthogonal matrix, dequantizing if necessary.""" + if self.quant: + return _dequantize( + self.ortho_matrix, + self.ortho_matrix_scales, + self.ortho_matrix_zeros, + self.ortho_matrix_shape, + ) + return self.ortho_matrix + + +# ====================================================================== +# Quantization utilities (shared with the optimizer) +# ====================================================================== + +@torch.no_grad() +def _quantize( + w: torch.Tensor, + q_group_size: int = -1, + n_bit: int = 8, +) -> tuple: + """Asymmetric min-max quantization to unsigned int. + + Returns: + ``(quantized_uint8, scales, zeros, original_shape)`` + """ + org_shape = w.shape + if q_group_size > 0: + assert w.nelement() % q_group_size == 0, ( + f"Tensor size {w.nelement()} not divisible by group_size {q_group_size}" + ) + w = w.reshape(-1, q_group_size) + assert w.dim() == 2 + + max_val = w.amax(dim=1, keepdim=True) + min_val = w.amin(dim=1, keepdim=True) + max_int = 2 ** n_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-5) / max_int + zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) + + w = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) + w = w.reshape(org_shape).to(torch.uint8) + + return w, scales, zeros, org_shape + + +@torch.no_grad() +def _dequantize( + w: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + original_shape: tuple, +) -> torch.Tensor: + """Dequantize from uint8 back to float.""" + group_size = scales.shape[-1] + # Infer group size from scales shape vs original shape + total = w.numel() + n_groups = scales.numel() // scales.shape[-1] if scales.dim() > 1 else scales.numel() + if n_groups > 0: + group_size = total // n_groups + else: + group_size = total + + float_w = w.to(scales.dtype).reshape(-1, group_size) + float_w = (float_w - zeros) * scales + return float_w.reshape(original_shape) + + +@torch.no_grad() +def _quantize_stochastic( + w: torch.Tensor, + q_group_size: int = -1, + n_bit: int = 8, +) -> tuple: + """Asymmetric min-max quantization with stochastic rounding. + + Instead of deterministic ``round()``, the rounding direction is chosen + probabilistically proportional to the fractional part. This gives an + unbiased estimator of the original value in expectation. + + Returns: + ``(quantized_uint8, scales, zeros, original_shape)`` + """ + org_shape = w.shape + if q_group_size > 0: + assert w.nelement() % q_group_size == 0 + w = w.reshape(-1, q_group_size) + assert w.dim() == 2 + + max_val = w.amax(dim=1, keepdim=True) + min_val = w.amin(dim=1, keepdim=True) + max_int = 2 ** n_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-5) / max_int + zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) + + w_scaled = w / scales + up = torch.ceil(w_scaled) + down = torch.floor(w_scaled) + prob = w_scaled - down + rng = torch.rand_like(prob) + w = torch.where(rng < prob, up, down) + w = torch.clamp(w + zeros, min_int, max_int) + w = w.reshape(org_shape).to(torch.uint8) + + return w, scales, zeros, org_shape diff --git a/unsloth/trainer.py b/unsloth/trainer.py index 8bb4440021..92b8fd7ad1 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -17,7 +17,7 @@ import psutil import warnings from dataclasses import dataclass, field -from typing import Optional +from typing import Optional, List from functools import wraps import trl @@ -46,6 +46,7 @@ "unsloth_train", "_patch_trl_trainer", "UnslothVisionDataCollator", + "QGaloreConfig", ] logger = logging.getLogger(__name__) @@ -130,8 +131,33 @@ def unsloth_train(trainer, *args, **kwargs): from transformers import TrainingArguments +@dataclass +class QGaloreConfig: + """Configuration for Q-GaLore optimizer integration. + + Pass an instance of this class to ``UnslothTrainingArguments`` (via + ``q_galore_config``) to enable Q-GaLore training. + """ + rank: int = 256 + update_proj_gap: int = 200 + scale: float = 0.25 + proj_quant: bool = True + proj_quant_group_size: int = -1 + proj_quant_n_bit: int = 4 + weight_quant: bool = True + stochastic_round: bool = True + weight_group_size: int = 128 + per_layer: bool = False + cos_threshold: float = 0.4 + gamma_proj: float = 2.0 + queue_size: int = 5 + target_modules: Optional[List[str]] = None + + class UnslothTrainingArguments(TrainingArguments): - def __init__(self, embedding_learning_rate: float = None, *args, **kwargs): + def __init__(self, embedding_learning_rate: float = None, q_galore_config: Optional[QGaloreConfig] = None, *args, **kwargs): + self.q_galore_config = q_galore_config + embedding_learning_rate = embedding_learning_rate super().__init__(*args, **kwargs) self.embedding_learning_rate = embedding_learning_rate @@ -181,6 +207,12 @@ def _create_unsloth_optimizer( class UnslothTrainer(SFTTrainer): def create_optimizer(self): + # --- Q-GaLore optimizer --- + q_galore_config = getattr(self.args, "q_galore_config", None) + if q_galore_config is not None and self.optimizer is None: + return self._create_q_galore_optimizer(q_galore_config) + + # --- Embedding-LR optimizer --- embedding_learning_rate = getattr(self.args, "embedding_learning_rate", None) if embedding_learning_rate is None: return super().create_optimizer() @@ -197,6 +229,64 @@ def create_optimizer(self): ) return self.optimizer + def _create_q_galore_optimizer(self, config: "QGaloreConfig"): + """Build the Q-GaLore optimizer from a QGaloreConfig.""" + from unsloth.optimizers.q_galore_adamw import ( + QGaLoreAdamW8bit, + make_q_galore_param_groups, + ) + + lr = self.args.learning_rate + weight_decay = self.args.weight_decay + + param_groups = make_q_galore_param_groups( + self.model, + lr=lr, + weight_decay=weight_decay, + rank=config.rank, + update_proj_gap=config.update_proj_gap, + scale=config.scale, + proj_quant=config.proj_quant, + proj_quant_group_size=config.proj_quant_group_size, + proj_quant_n_bit=config.proj_quant_n_bit, + weight_quant=config.weight_quant, + stochastic_round=config.stochastic_round, + weight_group_size=config.weight_group_size, + cos_threshold=config.cos_threshold, + gamma_proj=config.gamma_proj, + queue_size=config.queue_size, + target_modules=config.target_modules, + ) + + self.optimizer = QGaLoreAdamW8bit( + param_groups, + lr=lr, + weight_decay=weight_decay, + ) + + # Initialize INT8 weight quantization if enabled + if config.weight_quant: + QGaLoreAdamW8bit.init_weight_quantization( + self.model, + param_groups, + group_size=config.weight_group_size, + stochastic=config.stochastic_round, + ) + + n_galore = sum( + len(g["params"]) for g in param_groups if "rank" in g + ) + n_other = sum( + len(g["params"]) for g in param_groups if "rank" not in g + ) + print( + f"🦥 Unsloth: Q-GaLore enabled — " + f"{n_galore} GaLore params (rank={config.rank}), " + f"{n_other} standard params." + ) + + return self.optimizer + # From `trl>=0.13.0`, they changed how to pass several params to the trainer # We need to patch to make the transition smooth From e1fd3d8355f3b97461bd83deecd594584f026be9 Mon Sep 17 00:00:00 2001 From: Avaya Aggarwal Date: Sat, 21 Mar 2026 23:43:49 +0530 Subject: [PATCH 02/12] feat: Implement QGaLoreAdamW8bit optimizer with 8-bit states, GaLore low-rank gradient projection, and optional INT8 weight quantization, along with supporting projector and tests. --- tests/utils/test_q_galore.py | 26 +++++++++++- unsloth/optimizers/q_galore_adamw.py | 50 +++++++++++++++--------- unsloth/optimizers/q_galore_projector.py | 10 ++--- unsloth/trainer.py | 3 +- 4 files changed, 59 insertions(+), 30 deletions(-) diff --git a/tests/utils/test_q_galore.py b/tests/utils/test_q_galore.py index 24a159f6d1..2bb334eb49 100644 --- a/tests/utils/test_q_galore.py +++ b/tests/utils/test_q_galore.py @@ -144,9 +144,9 @@ def test_scale_applied(self): full_half = proj.project_back(low) full_one = proj2.project_back(low2) - # The ratio should be approximately 0.5 + # The ratio should be exactly 0.5 (SVD is deterministic on same input) ratio = full_half.norm() / full_one.norm() - assert abs(ratio - 0.5) < 0.15 + assert abs(ratio - 0.5) < 1e-5, f"Expected ratio ~0.5, got {ratio:.8f}" # ====================================================================== @@ -251,6 +251,28 @@ def test_custom_target_modules(self): galore_group = [g for g in groups if "rank" in g][0] assert len(galore_group["params"]) == 1 # Only q_proj + def test_bias_excluded_from_galore(self): + """1D bias params matching target names must NOT be in the GaLore group. + + GaLoreProjector.project requires 2-D gradients, so bias vectors + (e.g. q_proj.bias) that match a target name must be excluded. + """ + model = nn.Module() + model.q_proj = nn.Linear(64, 64, bias=True) # has .weight AND .bias + model.embed = nn.Embedding(100, 64) + + groups = make_q_galore_param_groups(model, rank=8, weight_quant=False) + + galore_group = [g for g in groups if "rank" in g][0] + non_galore_group = [g for g in groups if "rank" not in g][0] + + # Only the 2-D q_proj.weight should be in the GaLore group + assert len(galore_group["params"]) == 1 + assert galore_group["params"][0].dim() == 2 + + # q_proj.bias (1-D) + embed.weight should be in non-GaLore + assert any(p.dim() == 1 for p in non_galore_group["params"]) + # ====================================================================== # Optimizer tests (CPU-only, no bitsandbytes dependency) diff --git a/unsloth/optimizers/q_galore_adamw.py b/unsloth/optimizers/q_galore_adamw.py index 85b933e8e9..bd34780670 100644 --- a/unsloth/optimizers/q_galore_adamw.py +++ b/unsloth/optimizers/q_galore_adamw.py @@ -17,7 +17,6 @@ # Layer-Adaptive Low-Rank Gradients" (arXiv:2407.08296) import torch -from collections import defaultdict from typing import Optional, List from .q_galore_projector import ( @@ -183,7 +182,6 @@ def step(self, closure=None): self.prefetch_state(p) self.update_step(group, p, gindex, pindex) - torch.cuda.synchronize() # --- GaLore project-back --- if "rank" in group: @@ -216,7 +214,12 @@ def step(self, closure=None): state["step"] += 1 - if self.is_paged: + # Sync once per param group (not per-param) to avoid excessive + # GPU stalls while still ensuring bnb async kernels complete. + if torch.cuda.is_available(): + torch.cuda.synchronize() + + if self.is_paged and torch.cuda.is_available(): torch.cuda.synchronize() return loss @@ -228,7 +231,11 @@ def step(self, closure=None): @staticmethod def _has_weight_quant(p: torch.Tensor, group: dict) -> bool: """Check if this parameter uses INT8 weight quantization.""" - return group.get("weight_quant", False) and hasattr(p, "_q_scales") + return ( + group.get("weight_quant", False) + and hasattr(p, "_q_scales") + and p._q_scales is not None # None means first step (not yet quantized) + ) @staticmethod def init_weight_quantization( @@ -237,12 +244,13 @@ def init_weight_quantization( group_size: int = 128, stochastic: bool = True, ) -> None: - """Initialize INT8 weight quantization for params in groups with - ``weight_quant=True``. + """Tag parameters for INT8 weight quantization. - This should be called once before the first optimizer step. It - quantizes eligible weights in-place and stores the quantization - metadata (scales, zeros, shape) as attributes on the parameter tensor. + This marks eligible weights with quantization metadata so that + the optimizer knows to quantize/dequantize them during ``step()``. + **Weights are NOT converted to uint8 here** — they remain in float + so that the first forward/backward pass runs correctly. The actual + quantization happens at the end of the first ``step()`` call. """ weight_quant_params = set() for group in param_groups: @@ -252,14 +260,13 @@ def init_weight_quantization( for name, p in model.named_parameters(): if id(p) in weight_quant_params: - quant_fn = _quantize_stochastic if stochastic else _quantize - q, scales, zeros, shape = quant_fn( - p.data, q_group_size=group_size, - ) - p.data = q.to(p.data.device) - p._q_scales = scales - p._q_zeros = zeros - p._q_shape = shape + # Store quantization metadata WITHOUT converting weights to + # uint8. The first optimizer.step() will quantize after the + # update. We store dummy scales/zeros so _has_weight_quant() + # returns True on the first step. + p._q_scales = None + p._q_zeros = None + p._q_shape = p.data.shape p._stochastic_round = stochastic p._weight_group_size = group_size @@ -330,8 +337,13 @@ def make_q_galore_param_groups( if not param.requires_grad: continue - # Check if any target module name appears as a component in the param name - is_galore = any(t in name for t in targets) + # Check if any target module name appears as a component in the param name. + # Exclude 1-D parameters (biases, norms) because GaLoreProjector.project + # requires 2-D gradients. + is_galore = ( + param.dim() >= 2 + and any(t in name for t in targets) + ) if is_galore: galore_params.append(param) diff --git a/unsloth/optimizers/q_galore_projector.py b/unsloth/optimizers/q_galore_projector.py index 00c6bade26..217566ed46 100644 --- a/unsloth/optimizers/q_galore_projector.py +++ b/unsloth/optimizers/q_galore_projector.py @@ -308,14 +308,10 @@ def _dequantize( original_shape: tuple, ) -> torch.Tensor: """Dequantize from uint8 back to float.""" - group_size = scales.shape[-1] - # Infer group size from scales shape vs original shape + # Infer group size: scales has shape (n_groups, 1), so n_groups = scales.shape[0] total = w.numel() - n_groups = scales.numel() // scales.shape[-1] if scales.dim() > 1 else scales.numel() - if n_groups > 0: - group_size = total // n_groups - else: - group_size = total + n_groups = scales.shape[0] if scales.dim() > 1 else scales.numel() + group_size = total // n_groups if n_groups > 0 else total float_w = w.to(scales.dtype).reshape(-1, group_size) float_w = (float_w - zeros) * scales diff --git a/unsloth/trainer.py b/unsloth/trainer.py index 92b8fd7ad1..fe945b3701 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -147,7 +147,6 @@ class QGaloreConfig: weight_quant: bool = True stochastic_round: bool = True weight_group_size: int = 128 - per_layer: bool = False cos_threshold: float = 0.4 gamma_proj: float = 2.0 queue_size: int = 5 @@ -157,7 +156,7 @@ class QGaloreConfig: class UnslothTrainingArguments(TrainingArguments): def __init__(self, embedding_learning_rate: float = None, q_galore_config: Optional[QGaloreConfig] = None, *args, **kwargs): self.q_galore_config = q_galore_config - embedding_learning_rate = embedding_learning_rate + self.embedding_learning_rate = embedding_learning_rate super().__init__(*args, **kwargs) self.embedding_learning_rate = embedding_learning_rate From a45d267615bedbea5eb0017af449bc8c454ec8c2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 18:00:29 +0000 Subject: [PATCH 03/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utils/test_q_galore.py | 105 ++++++++++++----------- unsloth/optimizers/q_galore_adamw.py | 100 +++++++++++---------- unsloth/optimizers/q_galore_projector.py | 73 ++++++++++------ unsloth/trainer.py | 55 ++++++------ 4 files changed, 188 insertions(+), 145 deletions(-) diff --git a/tests/utils/test_q_galore.py b/tests/utils/test_q_galore.py index 2bb334eb49..34b576b340 100644 --- a/tests/utils/test_q_galore.py +++ b/tests/utils/test_q_galore.py @@ -30,6 +30,7 @@ # Direct import of the actual modules (avoids unsloth/__init__.py) import importlib.util + def _load_module(name, filepath): spec = importlib.util.spec_from_file_location(name, filepath) mod = importlib.util.module_from_spec(spec) @@ -37,6 +38,7 @@ def _load_module(name, filepath): spec.loader.exec_module(mod) return mod + # Load projector module first (no dependencies on unsloth) _projector_mod = _load_module( "unsloth.optimizers.q_galore_projector", @@ -64,9 +66,9 @@ class TestGaLoreProjector: def test_project_and_back_tall(self): """Project → project_back preserves shape for tall matrices.""" - proj = GaLoreProjector(rank=4, update_proj_gap=1) + proj = GaLoreProjector(rank = 4, update_proj_gap = 1) grad = torch.randn(16, 8) # tall - low = proj.project(grad, step=0) + low = proj.project(grad, step = 0) assert low.shape == (16, 4) full = proj.project_back(low) @@ -74,9 +76,9 @@ def test_project_and_back_tall(self): def test_project_and_back_wide(self): """Project → project_back preserves shape for wide matrices.""" - proj = GaLoreProjector(rank=4, update_proj_gap=1) + proj = GaLoreProjector(rank = 4, update_proj_gap = 1) grad = torch.randn(8, 16) # wide - low = proj.project(grad, step=0) + low = proj.project(grad, step = 0) assert low.shape == (4, 16) full = proj.project_back(low) @@ -84,22 +86,22 @@ def test_project_and_back_wide(self): def test_project_reuses_cached_svd(self): """SVD is not recomputed when step is not a multiple of update_proj_gap.""" - proj = GaLoreProjector(rank=4, update_proj_gap=100) + proj = GaLoreProjector(rank = 4, update_proj_gap = 100) grad = torch.randn(16, 8) - proj.project(grad, step=0) + proj.project(grad, step = 0) assert proj.svd_count == 1 - proj.project(grad, step=1) + proj.project(grad, step = 1) assert proj.svd_count == 1 # No recomputation - proj.project(grad, step=100) + proj.project(grad, step = 100) assert proj.svd_count == 2 # Recomputed def test_quantized_projection(self): """Quantized projection matrix stores and restores with bounded error.""" - proj = GaLoreProjector(rank=4, update_proj_gap=1, quant=True, n_bit=8) + proj = GaLoreProjector(rank = 4, update_proj_gap = 1, quant = True, n_bit = 8) grad = torch.randn(16, 8) - low = proj.project(grad, step=0) + low = proj.project(grad, step = 0) assert low.shape == (16, 4) # The projection matrix should be stored as uint8 @@ -107,9 +109,9 @@ def test_quantized_projection(self): def test_quantized_projection_int4(self): """INT4 quantized projection stores correctly.""" - proj = GaLoreProjector(rank=4, update_proj_gap=1, quant=True, n_bit=4) + proj = GaLoreProjector(rank = 4, update_proj_gap = 1, quant = True, n_bit = 4) grad = torch.randn(16, 8) - proj.project(grad, step=0) + proj.project(grad, step = 0) assert proj.ortho_matrix.dtype == torch.uint8 # INT4 values should be in range [0, 15] assert proj.ortho_matrix.max() <= 15 @@ -117,29 +119,29 @@ def test_quantized_projection_int4(self): def test_adaptive_scheduling(self): """update_proj_gap increases when cosine similarity exceeds threshold.""" proj = GaLoreProjector( - rank=4, - update_proj_gap=10, - cos_threshold=0.0, # Very low threshold → always triggers - gamma_proj=2.0, - queue_size=2, + rank = 4, + update_proj_gap = 10, + cos_threshold = 0.0, # Very low threshold → always triggers + gamma_proj = 2.0, + queue_size = 2, ) # Use very similar gradients so cosine similarity is high base_grad = torch.randn(16, 8) for i in range(5): grad = base_grad + torch.randn_like(base_grad) * 0.001 - proj.project(grad, step=i * 10) + proj.project(grad, step = i * 10) # After several similar SVDs, update_proj_gap should have increased assert proj.update_proj_gap > 10 def test_scale_applied(self): """project_back applies the scale factor.""" - proj = GaLoreProjector(rank=4, update_proj_gap=1, scale=0.5) + proj = GaLoreProjector(rank = 4, update_proj_gap = 1, scale = 0.5) grad = torch.randn(16, 8) - low = proj.project(grad, step=0) + low = proj.project(grad, step = 0) - proj2 = GaLoreProjector(rank=4, update_proj_gap=1, scale=1.0) - low2 = proj2.project(grad, step=0) + proj2 = GaLoreProjector(rank = 4, update_proj_gap = 1, scale = 1.0) + low2 = proj2.project(grad, step = 0) full_half = proj.project_back(low) full_one = proj2.project_back(low2) @@ -160,7 +162,7 @@ class TestQuantizationUtils: def test_quantize_dequantize_roundtrip(self): """Quantize → dequantize has bounded error.""" w = torch.randn(32, 64) - q, scales, zeros, shape = _quantize(w, n_bit=8) + q, scales, zeros, shape = _quantize(w, n_bit = 8) w_hat = _dequantize(q, scales, zeros, shape) # Error should be bounded by the quantization step size @@ -170,7 +172,7 @@ def test_quantize_dequantize_roundtrip(self): def test_quantize_group_roundtrip(self): """Grouped quantization → dequantization has bounded error.""" w = torch.randn(32, 64) - q, scales, zeros, shape = _quantize(w, q_group_size=32, n_bit=8) + q, scales, zeros, shape = _quantize(w, q_group_size = 32, n_bit = 8) w_hat = _dequantize(q, scales, zeros, shape) error = (w - w_hat).abs().max() assert error < 0.1 @@ -178,13 +180,13 @@ def test_quantize_group_roundtrip(self): def test_quantize_dtype(self): """Quantized output should be uint8.""" w = torch.randn(16, 16) - q, _, _, _ = _quantize(w, n_bit=8) + q, _, _, _ = _quantize(w, n_bit = 8) assert q.dtype == torch.uint8 def test_quantize_int4_range(self): """INT4 values should be in [0, 15].""" w = torch.randn(16, 16) - q, _, _, _ = _quantize(w, n_bit=4) + q, _, _, _ = _quantize(w, n_bit = 4) assert q.max() <= 15 assert q.min() >= 0 @@ -194,14 +196,14 @@ def test_stochastic_rounding_unbiased(self): w = torch.randn(64, 64) errors = [] for _ in range(50): - q, scales, zeros, shape = _quantize_stochastic(w, n_bit=8) + q, scales, zeros, shape = _quantize_stochastic(w, n_bit = 8) w_hat = _dequantize(q, scales, zeros, shape) errors.append((w - w_hat).mean().item()) mean_error = sum(errors) / len(errors) - assert abs(mean_error) < 0.01, ( - f"Mean error {mean_error} suggests biased rounding" - ) + assert ( + abs(mean_error) < 0.01 + ), f"Mean error {mean_error} suggests biased rounding" # ====================================================================== @@ -217,12 +219,12 @@ def test_param_group_separation(self): # Create a mini-transformer-like model model = nn.Module() - model.q_proj = nn.Linear(64, 64, bias=False) - model.k_proj = nn.Linear(64, 64, bias=False) + model.q_proj = nn.Linear(64, 64, bias = False) + model.k_proj = nn.Linear(64, 64, bias = False) model.embed = nn.Embedding(100, 64) model.norm = nn.LayerNorm(64) - groups = make_q_galore_param_groups(model, rank=8, weight_quant=False) + groups = make_q_galore_param_groups(model, rank = 8, weight_quant = False) # Should have 2 groups: galore and non-galore assert len(groups) == 2 @@ -233,19 +235,24 @@ def test_param_group_separation(self): # q_proj and k_proj should be in galore group (2 params) assert len(galore_group["params"]) == 2 # embed and norm should be in non-galore group - assert len(non_galore_group["params"]) == 3 # embed weight + norm weight + norm bias + assert ( + len(non_galore_group["params"]) == 3 + ) # embed weight + norm weight + norm bias def test_custom_target_modules(self): """Custom target_modules narrows GaLore scope.""" model = nn.Module() - model.q_proj = nn.Linear(64, 64, bias=False) - model.k_proj = nn.Linear(64, 64, bias=False) - model.v_proj = nn.Linear(64, 64, bias=False) + model.q_proj = nn.Linear(64, 64, bias = False) + model.k_proj = nn.Linear(64, 64, bias = False) + model.v_proj = nn.Linear(64, 64, bias = False) model.embed = nn.Embedding(100, 64) groups = make_q_galore_param_groups( - model, rank=8, target_modules=["q_proj"], weight_quant=False, + model, + rank = 8, + target_modules = ["q_proj"], + weight_quant = False, ) galore_group = [g for g in groups if "rank" in g][0] @@ -287,12 +294,12 @@ def test_projector_training_loop(self): torch.manual_seed(42) # Tiny model: single linear layer - model = nn.Linear(32, 16, bias=False) + model = nn.Linear(32, 16, bias = False) target = torch.randn(4, 16) x = torch.randn(4, 32) - proj = GaLoreProjector(rank=8, update_proj_gap=1, scale=1.0) - optimizer = torch.optim.AdamW(model.parameters(), lr=0.01) + proj = GaLoreProjector(rank = 8, update_proj_gap = 1, scale = 1.0) + optimizer = torch.optim.AdamW(model.parameters(), lr = 0.01) losses = [] for step in range(20): @@ -315,9 +322,9 @@ def test_projector_training_loop(self): optimizer.step() # Loss should decrease - assert losses[-1] < losses[0], ( - f"Loss did not decrease: {losses[0]:.4f} → {losses[-1]:.4f}" - ) + assert ( + losses[-1] < losses[0] + ), f"Loss did not decrease: {losses[0]:.4f} → {losses[-1]:.4f}" def test_full_projector_roundtrip_quality(self): """project → project_back captures the dominant gradient directions.""" @@ -327,13 +334,13 @@ def test_full_projector_roundtrip_quality(self): v = torch.randn(4, 16) grad = u @ v # rank-4 gradient - proj = GaLoreProjector(rank=4, update_proj_gap=1, scale=1.0) - low = proj.project(grad, step=0) + proj = GaLoreProjector(rank = 4, update_proj_gap = 1, scale = 1.0) + low = proj.project(grad, step = 0) reconstructed = proj.project_back(low) # For a rank-4 gradient with rank-4 projection, reconstruction # should be very close to original relative_error = (grad - reconstructed).norm() / grad.norm() - assert relative_error < 0.05, ( - f"Reconstruction error too high: {relative_error:.4f}" - ) + assert ( + relative_error < 0.05 + ), f"Reconstruction error too high: {relative_error:.4f}" diff --git a/unsloth/optimizers/q_galore_adamw.py b/unsloth/optimizers/q_galore_adamw.py index bd34780670..a6b9910746 100644 --- a/unsloth/optimizers/q_galore_adamw.py +++ b/unsloth/optimizers/q_galore_adamw.py @@ -98,7 +98,7 @@ def __init__( min_8bit_size, percentile_clipping, block_wise, - is_paged=is_paged, + is_paged = is_paged, ) # ------------------------------------------------------------------ @@ -106,7 +106,7 @@ def __init__( # ------------------------------------------------------------------ @torch.no_grad() - def step(self, closure=None): + def step(self, closure = None): """Perform a single optimization step. For each parameter that has a ``rank`` key in its param group, the @@ -142,7 +142,10 @@ def step(self, closure=None): # --- Dequantize weight if INT8 --- if has_weight_quant: float_weight = _dequantize( - p.data, p._q_scales, p._q_zeros, p._q_shape, + p.data, + p._q_scales, + p._q_zeros, + p._q_shape, ) p.data = float_weight.clone().to(p.data.device) @@ -150,16 +153,16 @@ def step(self, closure=None): if "rank" in group: if "projector" not in state: state["projector"] = GaLoreProjector( - rank=group["rank"], - update_proj_gap=group.get("update_proj_gap", 200), - scale=group.get("scale", 0.25), - proj_type=group.get("proj_type", "std"), - quant=group.get("quant", False), - group_size=group.get("quant_group_size", -1), - n_bit=group.get("quant_n_bit", 4), - cos_threshold=group.get("cos_threshold", 0.4), - gamma_proj=group.get("gamma_proj", 2.0), - queue_size=group.get("queue_size", 5), + rank = group["rank"], + update_proj_gap = group.get("update_proj_gap", 200), + scale = group.get("scale", 0.25), + proj_type = group.get("proj_type", "std"), + quant = group.get("quant", False), + group_size = group.get("quant_group_size", -1), + n_bit = group.get("quant_n_bit", 4), + cos_threshold = group.get("cos_threshold", 0.4), + gamma_proj = group.get("gamma_proj", 2.0), + queue_size = group.get("queue_size", 5), ) # Temporarily disable weight decay for GaLore params @@ -173,7 +176,9 @@ def step(self, closure=None): # Save current weight; replace p.data with zeros so # the 8-bit update writes the pure weight delta. p._saved_data = p.data.clone() - p.data = torch.zeros_like(grad, dtype=p.data.dtype, device=p.data.device) + p.data = torch.zeros_like( + grad, dtype = p.data.dtype, device = p.data.device + ) p.grad = grad # --- 8-bit Adam update --- @@ -186,16 +191,14 @@ def step(self, closure=None): # --- GaLore project-back --- if "rank" in group: # p.data now holds the weight update in low-rank space - p.data = p._saved_data.add_( - state["projector"].project_back(p.data) - ) + p.data = p._saved_data.add_(state["projector"].project_back(p.data)) del p._saved_data # Re-apply weight decay if "_wd_saved" in group: p.data.add_( p.data, - alpha=-group["lr"] * group["_wd_saved"], + alpha = -group["lr"] * group["_wd_saved"], ) group["weight_decay"] = group["_wd_saved"] del group["_wd_saved"] @@ -206,7 +209,7 @@ def step(self, closure=None): stochastic = group.get("stochastic_round", True) gsize = group.get("weight_group_size", 128) quant_fn = _quantize_stochastic if stochastic else _quantize - q, scales, zeros, shape = quant_fn(saved, q_group_size=gsize) + q, scales, zeros, shape = quant_fn(saved, q_group_size = gsize) p.data = q.to(p.data.device) p._q_scales = scales p._q_zeros = zeros @@ -277,8 +280,13 @@ def init_weight_quantization( # Default linear layer names in transformer blocks that should use GaLore. _DEFAULT_GALORE_TARGETS = { - "q_proj", "k_proj", "v_proj", "o_proj", - "gate_proj", "up_proj", "down_proj", + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", } @@ -353,30 +361,34 @@ def make_q_galore_param_groups( groups = [] if galore_params: - groups.append({ - "params": galore_params, - "lr": lr, - "weight_decay": weight_decay, - "rank": rank, - "update_proj_gap": update_proj_gap, - "scale": scale, - "proj_type": "std", - "quant": proj_quant, - "quant_group_size": proj_quant_group_size, - "quant_n_bit": proj_quant_n_bit, - "weight_quant": weight_quant, - "stochastic_round": stochastic_round, - "weight_group_size": weight_group_size, - "cos_threshold": cos_threshold, - "gamma_proj": gamma_proj, - "queue_size": queue_size, - }) + groups.append( + { + "params": galore_params, + "lr": lr, + "weight_decay": weight_decay, + "rank": rank, + "update_proj_gap": update_proj_gap, + "scale": scale, + "proj_type": "std", + "quant": proj_quant, + "quant_group_size": proj_quant_group_size, + "quant_n_bit": proj_quant_n_bit, + "weight_quant": weight_quant, + "stochastic_round": stochastic_round, + "weight_group_size": weight_group_size, + "cos_threshold": cos_threshold, + "gamma_proj": gamma_proj, + "queue_size": queue_size, + } + ) if non_galore_params: - groups.append({ - "params": non_galore_params, - "lr": lr, - "weight_decay": weight_decay, - }) + groups.append( + { + "params": non_galore_params, + "lr": lr, + "weight_decay": weight_decay, + } + ) return groups diff --git a/unsloth/optimizers/q_galore_projector.py b/unsloth/optimizers/q_galore_projector.py index 217566ed46..94b8d7c105 100644 --- a/unsloth/optimizers/q_galore_projector.py +++ b/unsloth/optimizers/q_galore_projector.py @@ -57,11 +57,23 @@ class GaLoreProjector: """ __slots__ = ( - "rank", "update_proj_gap", "scale", "proj_type", - "quant", "quant_group_size", "quant_n_bit", - "cos_threshold", "gamma_proj", "queue_size", - "ortho_matrix", "ortho_matrix_scales", "ortho_matrix_zeros", - "ortho_matrix_shape", "past_ortho_vector", "queue", "svd_count", + "rank", + "update_proj_gap", + "scale", + "proj_type", + "quant", + "quant_group_size", + "quant_n_bit", + "cos_threshold", + "gamma_proj", + "queue_size", + "ortho_matrix", + "ortho_matrix_scales", + "ortho_matrix_zeros", + "ortho_matrix_shape", + "past_ortho_vector", + "queue", + "svd_count", ) def __init__( @@ -125,9 +137,11 @@ def project(self, full_rank_grad: torch.Tensor, step: int) -> torch.Tensor: # "tall" matrix → right projection (grad @ Q^T) if self.ortho_matrix is None or step % self.update_proj_gap == 0: float_ortho = self._compute_orthogonal( - full_rank_grad, self.rank, side="right", + full_rank_grad, + self.rank, + side = "right", ) - self._update_adaptive_schedule(float_ortho, side="right") + self._update_adaptive_schedule(float_ortho, side = "right") self._store_ortho(float_ortho) float_ortho = self._load_ortho() @@ -136,9 +150,11 @@ def project(self, full_rank_grad: torch.Tensor, step: int) -> torch.Tensor: # "wide" matrix → left projection (Q^T @ grad) if self.ortho_matrix is None or step % self.update_proj_gap == 0: float_ortho = self._compute_orthogonal( - full_rank_grad, self.rank, side="left", + full_rank_grad, + self.rank, + side = "left", ) - self._update_adaptive_schedule(float_ortho, side="left") + self._update_adaptive_schedule(float_ortho, side = "left") self._store_ortho(float_ortho) float_ortho = self._load_ortho() @@ -170,7 +186,9 @@ def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor: @staticmethod def _compute_orthogonal( - weights: torch.Tensor, rank: int, side: str, + weights: torch.Tensor, + rank: int, + side: str, ) -> torch.Tensor: """Compute the top-``rank`` orthogonal matrix via truncated SVD. @@ -187,7 +205,7 @@ def _compute_orthogonal( matrix = weights.float() if original_dtype != torch.float32 else weights - U, s, Vh = torch.linalg.svd(matrix, full_matrices=False) + U, s, Vh = torch.linalg.svd(matrix, full_matrices = False) if side == "right": result = Vh[:rank, :] @@ -197,7 +215,7 @@ def _compute_orthogonal( raise ValueError(f"side must be 'left' or 'right', got '{side}'") if original_dtype != torch.float32: - result = result.to(device=original_device, dtype=original_dtype) + result = result.to(device = original_device, dtype = original_dtype) return result # ------------------------------------------------------------------ @@ -205,7 +223,9 @@ def _compute_orthogonal( # ------------------------------------------------------------------ def _update_adaptive_schedule( - self, float_ortho: torch.Tensor, side: str, + self, + float_ortho: torch.Tensor, + side: str, ) -> None: """Track subspace stability and increase ``update_proj_gap`` if stable.""" self.svd_count += 1 @@ -242,8 +262,8 @@ def _store_ortho(self, float_ortho: torch.Tensor) -> None: if self.quant: q, scales, zeros, shape = _quantize( float_ortho, - q_group_size=self.quant_group_size, - n_bit=self.quant_n_bit, + q_group_size = self.quant_group_size, + n_bit = self.quant_n_bit, ) self.ortho_matrix = q self.ortho_matrix_scales = scales @@ -268,6 +288,7 @@ def _load_ortho(self) -> torch.Tensor: # Quantization utilities (shared with the optimizer) # ====================================================================== + @torch.no_grad() def _quantize( w: torch.Tensor, @@ -281,17 +302,17 @@ def _quantize( """ org_shape = w.shape if q_group_size > 0: - assert w.nelement() % q_group_size == 0, ( - f"Tensor size {w.nelement()} not divisible by group_size {q_group_size}" - ) + assert ( + w.nelement() % q_group_size == 0 + ), f"Tensor size {w.nelement()} not divisible by group_size {q_group_size}" w = w.reshape(-1, q_group_size) assert w.dim() == 2 - max_val = w.amax(dim=1, keepdim=True) - min_val = w.amin(dim=1, keepdim=True) - max_int = 2 ** n_bit - 1 + max_val = w.amax(dim = 1, keepdim = True) + min_val = w.amin(dim = 1, keepdim = True) + max_int = 2**n_bit - 1 min_int = 0 - scales = (max_val - min_val).clamp(min=1e-5) / max_int + scales = (max_val - min_val).clamp(min = 1e-5) / max_int zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) w = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) @@ -339,11 +360,11 @@ def _quantize_stochastic( w = w.reshape(-1, q_group_size) assert w.dim() == 2 - max_val = w.amax(dim=1, keepdim=True) - min_val = w.amin(dim=1, keepdim=True) - max_int = 2 ** n_bit - 1 + max_val = w.amax(dim = 1, keepdim = True) + min_val = w.amin(dim = 1, keepdim = True) + max_int = 2**n_bit - 1 min_int = 0 - scales = (max_val - min_val).clamp(min=1e-5) / max_int + scales = (max_val - min_val).clamp(min = 1e-5) / max_int zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) w_scaled = w / scales diff --git a/unsloth/trainer.py b/unsloth/trainer.py index fe945b3701..8ff81d01d4 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -138,6 +138,7 @@ class QGaloreConfig: Pass an instance of this class to ``UnslothTrainingArguments`` (via ``q_galore_config``) to enable Q-GaLore training. """ + rank: int = 256 update_proj_gap: int = 200 scale: float = 0.25 @@ -154,7 +155,13 @@ class QGaloreConfig: class UnslothTrainingArguments(TrainingArguments): - def __init__(self, embedding_learning_rate: float = None, q_galore_config: Optional[QGaloreConfig] = None, *args, **kwargs): + def __init__( + self, + embedding_learning_rate: float = None, + q_galore_config: Optional[QGaloreConfig] = None, + *args, + **kwargs, + ): self.q_galore_config = q_galore_config self.embedding_learning_rate = embedding_learning_rate super().__init__(*args, **kwargs) @@ -240,27 +247,27 @@ def _create_q_galore_optimizer(self, config: "QGaloreConfig"): param_groups = make_q_galore_param_groups( self.model, - lr=lr, - weight_decay=weight_decay, - rank=config.rank, - update_proj_gap=config.update_proj_gap, - scale=config.scale, - proj_quant=config.proj_quant, - proj_quant_group_size=config.proj_quant_group_size, - proj_quant_n_bit=config.proj_quant_n_bit, - weight_quant=config.weight_quant, - stochastic_round=config.stochastic_round, - weight_group_size=config.weight_group_size, - cos_threshold=config.cos_threshold, - gamma_proj=config.gamma_proj, - queue_size=config.queue_size, - target_modules=config.target_modules, + lr = lr, + weight_decay = weight_decay, + rank = config.rank, + update_proj_gap = config.update_proj_gap, + scale = config.scale, + proj_quant = config.proj_quant, + proj_quant_group_size = config.proj_quant_group_size, + proj_quant_n_bit = config.proj_quant_n_bit, + weight_quant = config.weight_quant, + stochastic_round = config.stochastic_round, + weight_group_size = config.weight_group_size, + cos_threshold = config.cos_threshold, + gamma_proj = config.gamma_proj, + queue_size = config.queue_size, + target_modules = config.target_modules, ) self.optimizer = QGaLoreAdamW8bit( param_groups, - lr=lr, - weight_decay=weight_decay, + lr = lr, + weight_decay = weight_decay, ) # Initialize INT8 weight quantization if enabled @@ -268,16 +275,12 @@ def _create_q_galore_optimizer(self, config: "QGaloreConfig"): QGaLoreAdamW8bit.init_weight_quantization( self.model, param_groups, - group_size=config.weight_group_size, - stochastic=config.stochastic_round, + group_size = config.weight_group_size, + stochastic = config.stochastic_round, ) - n_galore = sum( - len(g["params"]) for g in param_groups if "rank" in g - ) - n_other = sum( - len(g["params"]) for g in param_groups if "rank" not in g - ) + n_galore = sum(len(g["params"]) for g in param_groups if "rank" in g) + n_other = sum(len(g["params"]) for g in param_groups if "rank" not in g) print( f"🦥 Unsloth: Q-GaLore enabled — " f"{n_galore} GaLore params (rank={config.rank}), " From d6deae4932fe68e207563de6ebb7bbdc8e867d9c Mon Sep 17 00:00:00 2001 From: Avaya Aggarwal Date: Sun, 22 Mar 2026 13:25:02 +0530 Subject: [PATCH 04/12] feat: Introduce Q-GaLore AdamW optimizer with low-rank quantized gradient projection and integrate into the trainer, along with dedicated tests. --- tests/utils/test_q_galore.py | 76 ++++++++++++++++++++++++++++ unsloth/optimizers/q_galore_adamw.py | 19 +++---- unsloth/trainer.py | 46 ++++++++++++++++- 3 files changed, 130 insertions(+), 11 deletions(-) diff --git a/tests/utils/test_q_galore.py b/tests/utils/test_q_galore.py index 34b576b340..5c5bb044d6 100644 --- a/tests/utils/test_q_galore.py +++ b/tests/utils/test_q_galore.py @@ -344,3 +344,79 @@ def test_full_projector_roundtrip_quality(self): assert ( relative_error < 0.05 ), f"Reconstruction error too high: {relative_error:.4f}" + + def test_weight_quant_activates_on_first_step(self): + """_has_weight_quant returns True even when _q_scales is None (first step).""" + _adamw_mod_local = sys.modules["unsloth.optimizers.q_galore_adamw"] + QGaLoreAdamW8bit = _adamw_mod_local.QGaLoreAdamW8bit + + p = torch.nn.Parameter(torch.randn(16, 16)) + # Simulate init_weight_quantization tagging + p._q_scales = None + p._q_zeros = None + p._q_shape = p.data.shape + + group = {"weight_quant": True} + + # _has_weight_quant must return True even on first step (_q_scales=None) + assert QGaLoreAdamW8bit._has_weight_quant(p, group) is True + + # Without the tag, it should return False + p2 = torch.nn.Parameter(torch.randn(16, 16)) + assert QGaLoreAdamW8bit._has_weight_quant(p2, group) is False + + def test_embedding_lr_param_group_split(self): + """Embedding params can be split into a separate group with custom LR.""" + # This tests the logic that make_q_galore_param_groups produces groups + # that can be further split by the trainer for embedding LR. + model = nn.Module() + model.q_proj = nn.Linear(64, 64, bias=False) + model.embed = nn.Embedding(100, 64) + + groups = make_q_galore_param_groups(model, rank=8, weight_quant=False) + + # Simulate splitting non-GaLore group for embedding LR + embed_lr = 5e-5 + new_groups = [] + for group in groups: + if "rank" in group: + new_groups.append(group) + continue + embed_params = [] + other_params = [] + for p in group["params"]: + # In real usage, we'd check the name; here just split by shape + if p.shape[0] == 100: # embedding + embed_params.append(p) + else: + other_params.append(p) + if other_params: + g = dict(group) + g["params"] = other_params + new_groups.append(g) + if embed_params: + g = dict(group) + g["params"] = embed_params + g["lr"] = embed_lr + new_groups.append(g) + + # Should have 3 groups: galore, non-galore non-embed, embed + embed_groups = [g for g in new_groups if g.get("lr") == embed_lr] + assert len(embed_groups) == 1 + assert embed_groups[0]["lr"] == embed_lr + + def test_optimizer_hyperparams_forwarded(self): + """QGaLoreAdamW8bit accepts betas and eps keyword arguments.""" + # Verify the constructor signature accepts these params. + # Without bitsandbytes we can't instantiate, but we can check the + # function signature. + import inspect + + _adamw_mod_local = sys.modules["unsloth.optimizers.q_galore_adamw"] + QGaLoreAdamW8bit = _adamw_mod_local.QGaLoreAdamW8bit + + sig = inspect.signature(QGaLoreAdamW8bit.__init__) + param_names = list(sig.parameters.keys()) + assert "betas" in param_names, "betas not in QGaLoreAdamW8bit.__init__ params" + assert "eps" in param_names, "eps not in QGaLoreAdamW8bit.__init__ params" + diff --git a/unsloth/optimizers/q_galore_adamw.py b/unsloth/optimizers/q_galore_adamw.py index a6b9910746..b02d90e03b 100644 --- a/unsloth/optimizers/q_galore_adamw.py +++ b/unsloth/optimizers/q_galore_adamw.py @@ -141,13 +141,15 @@ def step(self, closure = None): # --- Dequantize weight if INT8 --- if has_weight_quant: - float_weight = _dequantize( - p.data, - p._q_scales, - p._q_zeros, - p._q_shape, - ) - p.data = float_weight.clone().to(p.data.device) + if p._q_scales is not None: + float_weight = _dequantize( + p.data, + p._q_scales, + p._q_zeros, + p._q_shape, + ) + p.data = float_weight.clone().to(p.data.device) + # else: first step, weights are still float — skip dequantize # --- GaLore projection --- if "rank" in group: @@ -236,8 +238,7 @@ def _has_weight_quant(p: torch.Tensor, group: dict) -> bool: """Check if this parameter uses INT8 weight quantization.""" return ( group.get("weight_quant", False) - and hasattr(p, "_q_scales") - and p._q_scales is not None # None means first step (not yet quantized) + and hasattr(p, "_q_scales") # tag set by init_weight_quantization() ) @staticmethod diff --git a/unsloth/trainer.py b/unsloth/trainer.py index 8ff81d01d4..abeff784c4 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -216,7 +216,8 @@ def create_optimizer(self): # --- Q-GaLore optimizer --- q_galore_config = getattr(self.args, "q_galore_config", None) if q_galore_config is not None and self.optimizer is None: - return self._create_q_galore_optimizer(q_galore_config) + embedding_lr = getattr(self.args, "embedding_learning_rate", None) + return self._create_q_galore_optimizer(q_galore_config, embedding_lr) # --- Embedding-LR optimizer --- embedding_learning_rate = getattr(self.args, "embedding_learning_rate", None) @@ -235,7 +236,7 @@ def create_optimizer(self): ) return self.optimizer - def _create_q_galore_optimizer(self, config: "QGaloreConfig"): + def _create_q_galore_optimizer(self, config: "QGaloreConfig", embedding_lr=None): """Build the Q-GaLore optimizer from a QGaloreConfig.""" from unsloth.optimizers.q_galore_adamw import ( QGaLoreAdamW8bit, @@ -264,10 +265,51 @@ def _create_q_galore_optimizer(self, config: "QGaloreConfig"): target_modules = config.target_modules, ) + # --- Split embedding params with custom LR (Fix #2) --- + if embedding_lr is not None: + new_groups = [] + for group in param_groups: + if "rank" in group: + # GaLore group — keep as-is (embeddings are never in here) + new_groups.append(group) + continue + # Non-GaLore group: split out embedding params + embed_params = [] + other_params = [] + for p in group["params"]: + # Check if this param belongs to a modules_to_save embedding + is_embed = False + for name, param in self.model.named_parameters(): + if param is p and name.endswith("modules_to_save.default.weight"): + partial_name = name[: -len(".modules_to_save.default.weight")] + partial_name = partial_name[partial_name.rfind(".") + 1 :] + print( + f"Unsloth: Setting lr = {embedding_lr:.2e} instead of {lr:.2e} for {partial_name}." + ) + is_embed = True + break + if is_embed: + embed_params.append(p) + else: + other_params.append(p) + if other_params: + other_group = dict(group) + other_group["params"] = other_params + new_groups.append(other_group) + if embed_params: + embed_group = dict(group) + embed_group["params"] = embed_params + embed_group["lr"] = embedding_lr + new_groups.append(embed_group) + param_groups = new_groups + + # --- Forward optimizer hyperparameters (Fix #3) --- self.optimizer = QGaLoreAdamW8bit( param_groups, lr = lr, weight_decay = weight_decay, + betas = (self.args.adam_beta1, self.args.adam_beta2), + eps = self.args.adam_epsilon, ) # Initialize INT8 weight quantization if enabled From 093bc41351429d6e55ed79feb42ed68798e766b6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 18:16:26 +0000 Subject: [PATCH 05/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utils/test_q_galore.py | 4 ++-- unsloth/optimizers/q_galore_adamw.py | 5 +---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/utils/test_q_galore.py b/tests/utils/test_q_galore.py index 5c5bb044d6..36636b0fa1 100644 --- a/tests/utils/test_q_galore.py +++ b/tests/utils/test_q_galore.py @@ -265,10 +265,10 @@ def test_bias_excluded_from_galore(self): (e.g. q_proj.bias) that match a target name must be excluded. """ model = nn.Module() - model.q_proj = nn.Linear(64, 64, bias=True) # has .weight AND .bias + model.q_proj = nn.Linear(64, 64, bias = True) # has .weight AND .bias model.embed = nn.Embedding(100, 64) - groups = make_q_galore_param_groups(model, rank=8, weight_quant=False) + groups = make_q_galore_param_groups(model, rank = 8, weight_quant = False) galore_group = [g for g in groups if "rank" in g][0] non_galore_group = [g for g in groups if "rank" not in g][0] diff --git a/unsloth/optimizers/q_galore_adamw.py b/unsloth/optimizers/q_galore_adamw.py index b02d90e03b..706c1b9f6a 100644 --- a/unsloth/optimizers/q_galore_adamw.py +++ b/unsloth/optimizers/q_galore_adamw.py @@ -349,10 +349,7 @@ def make_q_galore_param_groups( # Check if any target module name appears as a component in the param name. # Exclude 1-D parameters (biases, norms) because GaLoreProjector.project # requires 2-D gradients. - is_galore = ( - param.dim() >= 2 - and any(t in name for t in targets) - ) + is_galore = param.dim() >= 2 and any(t in name for t in targets) if is_galore: galore_params.append(param) From bb3fa42ebfb600fcb52cc762b1bc07fed88e5e0d Mon Sep 17 00:00:00 2001 From: Avaya Aggarwal Date: Sun, 22 Mar 2026 22:56:24 +0530 Subject: [PATCH 06/12] feat: Implement Q-GaLore AdamW optimizer with gradient projection and quantization, including trainer integration and corresponding tests. --- tests/utils/test_q_galore.py | 61 ++++++++++++++++++++++++++++ unsloth/optimizers/q_galore_adamw.py | 16 ++++---- unsloth/trainer.py | 21 +++++----- 3 files changed, 80 insertions(+), 18 deletions(-) diff --git a/tests/utils/test_q_galore.py b/tests/utils/test_q_galore.py index 36636b0fa1..1d14a67a78 100644 --- a/tests/utils/test_q_galore.py +++ b/tests/utils/test_q_galore.py @@ -280,6 +280,22 @@ def test_bias_excluded_from_galore(self): # q_proj.bias (1-D) + embed.weight should be in non-GaLore assert any(p.dim() == 1 for p in non_galore_group["params"]) + def test_empty_target_modules_no_galore(self): + """target_modules=[] should result in no GaLore params.""" + model = nn.Module() + model.q_proj = nn.Linear(64, 64, bias = False) + + # Pass empty list, should NOT fall back to defaults + groups = make_q_galore_param_groups( + model, + rank = 8, + target_modules = [], + weight_quant = False, + ) + + galore_groups = [g for g in groups if "rank" in g] + assert len(galore_groups) == 0, "Expected no GaLore groups when target_modules=[]" + # ====================================================================== # Optimizer tests (CPU-only, no bitsandbytes dependency) @@ -420,3 +436,48 @@ def test_optimizer_hyperparams_forwarded(self): assert "betas" in param_names, "betas not in QGaLoreAdamW8bit.__init__ params" assert "eps" in param_names, "eps not in QGaLoreAdamW8bit.__init__ params" + def test_weight_decay_uses_saved_data(self): + """Weight decay should apply to the pre-updated weights, not post-updated.""" + _adamw_mod_local = sys.modules["unsloth.optimizers.q_galore_adamw"] + + # We can test this logic without bitsandbytes by mocking the optimizer step + # Create a mock parameter and group + p = torch.nn.Parameter(torch.ones(4, 4)) + p._saved_data = torch.ones(4, 4) * 2.0 # Pre-update weights + p.data = torch.ones(4, 4) * 3.0 # Post-update weights (GaLore output) + + group = {"weight_decay": 0.1, "lr": 1.0, "_wd_saved": 0.1} + + # Replicate the decoupled weight decay logic + p.data.add_( + p._saved_data, + alpha = -group["lr"] * group["_wd_saved"], + ) + + # If it used p.data (wrong), value would be 3.0 - (1.0 * 0.1 * 3.0) = 2.7 + # If it used p._saved_data (correct), value is 3.0 - (1.0 * 0.1 * 2.0) = 2.8 + assert torch.allclose(p.data, torch.tensor(2.8)), "Weight decay didn't use _saved_data!" + + def test_params_float_after_weight_quant_step(self): + """After a step with weight_quant=True, parameters must remain floating point.""" + _adamw_mod_local = sys.modules["unsloth.optimizers.q_galore_adamw"] + _projector_mod_local = sys.modules["unsloth.optimizers.q_galore_projector"] + + _quantize = _projector_mod_local._quantize + + p = torch.nn.Parameter(torch.randn(16, 16)) + group = {"weight_quant": True, "stochastic_round": False, "weight_group_size": 16} + + # Replicate the re-quantize logic at the end of optimizer step + float_data = p.data.clone() + q, scales, zeros, shape = _quantize(float_data, q_group_size=group["weight_group_size"]) + + # The key assertion: p.data stays float, _q_data holds uint8 + p._q_data = q.to(p.data.device) + p._q_scales = scales + p._q_zeros = zeros + p._q_shape = shape + + assert p.data.is_floating_point(), "p.data was converted to uint8!" + assert p._q_data.dtype == torch.uint8, "_q_data should be uint8!" + diff --git a/unsloth/optimizers/q_galore_adamw.py b/unsloth/optimizers/q_galore_adamw.py index 706c1b9f6a..0975b79b59 100644 --- a/unsloth/optimizers/q_galore_adamw.py +++ b/unsloth/optimizers/q_galore_adamw.py @@ -143,7 +143,7 @@ def step(self, closure = None): if has_weight_quant: if p._q_scales is not None: float_weight = _dequantize( - p.data, + p._q_data, p._q_scales, p._q_zeros, p._q_shape, @@ -196,10 +196,10 @@ def step(self, closure = None): p.data = p._saved_data.add_(state["projector"].project_back(p.data)) del p._saved_data - # Re-apply weight decay + # Re-apply decoupled weight decay using pre-update weights if "_wd_saved" in group: p.data.add_( - p.data, + p._saved_data, alpha = -group["lr"] * group["_wd_saved"], ) group["weight_decay"] = group["_wd_saved"] @@ -207,12 +207,14 @@ def step(self, closure = None): # --- Re-quantize weight to INT8 --- if has_weight_quant: - saved = p.data.clone() + float_data = p.data.clone() stochastic = group.get("stochastic_round", True) gsize = group.get("weight_group_size", 128) quant_fn = _quantize_stochastic if stochastic else _quantize - q, scales, zeros, shape = quant_fn(saved, q_group_size = gsize) - p.data = q.to(p.data.device) + q, scales, zeros, shape = quant_fn(float_data, q_group_size = gsize) + # Keep p.data as float for the next forward/backward pass. + # Store quantized representation in _q_data for compressed storage. + p._q_data = q.to(p.data.device) p._q_scales = scales p._q_zeros = zeros p._q_shape = shape @@ -337,7 +339,7 @@ def make_q_galore_param_groups( Returns: List of two param group dicts: ``[galore_group, non_galore_group]``. """ - targets = set(target_modules) if target_modules else _DEFAULT_GALORE_TARGETS + targets = set(target_modules) if target_modules is not None else _DEFAULT_GALORE_TARGETS galore_params = [] non_galore_params = [] diff --git a/unsloth/trainer.py b/unsloth/trainer.py index abeff784c4..2b94c52755 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -267,6 +267,9 @@ def _create_q_galore_optimizer(self, config: "QGaloreConfig", embedding_lr=None) # --- Split embedding params with custom LR (Fix #2) --- if embedding_lr is not None: + # Build a fast param->name lookup (O(N) instead of O(N*M)) + param_to_name = {id(p): name for name, p in self.model.named_parameters()} + new_groups = [] for group in param_groups: if "rank" in group: @@ -278,17 +281,13 @@ def _create_q_galore_optimizer(self, config: "QGaloreConfig", embedding_lr=None) other_params = [] for p in group["params"]: # Check if this param belongs to a modules_to_save embedding - is_embed = False - for name, param in self.model.named_parameters(): - if param is p and name.endswith("modules_to_save.default.weight"): - partial_name = name[: -len(".modules_to_save.default.weight")] - partial_name = partial_name[partial_name.rfind(".") + 1 :] - print( - f"Unsloth: Setting lr = {embedding_lr:.2e} instead of {lr:.2e} for {partial_name}." - ) - is_embed = True - break - if is_embed: + name = param_to_name.get(id(p)) + if name and name.endswith("modules_to_save.default.weight"): + partial_name = name[: -len(".modules_to_save.default.weight")] + partial_name = partial_name[partial_name.rfind(".") + 1 :] + print( + f"Unsloth: Setting lr = {embedding_lr:.2e} instead of {lr:.2e} for {partial_name}." + ) embed_params.append(p) else: other_params.append(p) From 83cf384649fa8d98545b2b4aa3a85a7ab67dd262 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 22 Mar 2026 07:55:33 +0000 Subject: [PATCH 07/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utils/test_q_galore.py | 23 +++++++++++------------ unsloth/trainer.py | 2 +- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/utils/test_q_galore.py b/tests/utils/test_q_galore.py index 1d14a67a78..9f336a00c2 100644 --- a/tests/utils/test_q_galore.py +++ b/tests/utils/test_q_galore.py @@ -386,10 +386,10 @@ def test_embedding_lr_param_group_split(self): # This tests the logic that make_q_galore_param_groups produces groups # that can be further split by the trainer for embedding LR. model = nn.Module() - model.q_proj = nn.Linear(64, 64, bias=False) + model.q_proj = nn.Linear(64, 64, bias = False) model.embed = nn.Embedding(100, 64) - groups = make_q_galore_param_groups(model, rank=8, weight_quant=False) + groups = make_q_galore_param_groups(model, rank = 8, weight_quant = False) # Simulate splitting non-GaLore group for embedding LR embed_lr = 5e-5 @@ -439,21 +439,21 @@ def test_optimizer_hyperparams_forwarded(self): def test_weight_decay_uses_saved_data(self): """Weight decay should apply to the pre-updated weights, not post-updated.""" _adamw_mod_local = sys.modules["unsloth.optimizers.q_galore_adamw"] - + # We can test this logic without bitsandbytes by mocking the optimizer step # Create a mock parameter and group p = torch.nn.Parameter(torch.ones(4, 4)) p._saved_data = torch.ones(4, 4) * 2.0 # Pre-update weights p.data = torch.ones(4, 4) * 3.0 # Post-update weights (GaLore output) - + group = {"weight_decay": 0.1, "lr": 1.0, "_wd_saved": 0.1} - + # Replicate the decoupled weight decay logic p.data.add_( p._saved_data, alpha = -group["lr"] * group["_wd_saved"], ) - + # If it used p.data (wrong), value would be 3.0 - (1.0 * 0.1 * 3.0) = 2.7 # If it used p._saved_data (correct), value is 3.0 - (1.0 * 0.1 * 2.0) = 2.8 assert torch.allclose(p.data, torch.tensor(2.8)), "Weight decay didn't use _saved_data!" @@ -462,22 +462,21 @@ def test_params_float_after_weight_quant_step(self): """After a step with weight_quant=True, parameters must remain floating point.""" _adamw_mod_local = sys.modules["unsloth.optimizers.q_galore_adamw"] _projector_mod_local = sys.modules["unsloth.optimizers.q_galore_projector"] - + _quantize = _projector_mod_local._quantize - + p = torch.nn.Parameter(torch.randn(16, 16)) group = {"weight_quant": True, "stochastic_round": False, "weight_group_size": 16} - + # Replicate the re-quantize logic at the end of optimizer step float_data = p.data.clone() q, scales, zeros, shape = _quantize(float_data, q_group_size=group["weight_group_size"]) - + # The key assertion: p.data stays float, _q_data holds uint8 p._q_data = q.to(p.data.device) p._q_scales = scales p._q_zeros = zeros p._q_shape = shape - + assert p.data.is_floating_point(), "p.data was converted to uint8!" assert p._q_data.dtype == torch.uint8, "_q_data should be uint8!" - diff --git a/unsloth/trainer.py b/unsloth/trainer.py index 2b94c52755..bbc7d72e02 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -236,7 +236,7 @@ def create_optimizer(self): ) return self.optimizer - def _create_q_galore_optimizer(self, config: "QGaloreConfig", embedding_lr=None): + def _create_q_galore_optimizer(self, config: "QGaloreConfig", embedding_lr = None): """Build the Q-GaLore optimizer from a QGaloreConfig.""" from unsloth.optimizers.q_galore_adamw import ( QGaLoreAdamW8bit, From 3ca1af97c05b6502a63c36a973c34797452cd44a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 22 Mar 2026 17:28:16 +0000 Subject: [PATCH 08/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utils/test_q_galore.py | 24 +++++++++++++++++------- unsloth/optimizers/q_galore_adamw.py | 4 +++- unsloth/trainer.py | 2 +- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/tests/utils/test_q_galore.py b/tests/utils/test_q_galore.py index 9f336a00c2..29812cc6bb 100644 --- a/tests/utils/test_q_galore.py +++ b/tests/utils/test_q_galore.py @@ -284,7 +284,7 @@ def test_empty_target_modules_no_galore(self): """target_modules=[] should result in no GaLore params.""" model = nn.Module() model.q_proj = nn.Linear(64, 64, bias = False) - + # Pass empty list, should NOT fall back to defaults groups = make_q_galore_param_groups( model, @@ -292,9 +292,11 @@ def test_empty_target_modules_no_galore(self): target_modules = [], weight_quant = False, ) - + galore_groups = [g for g in groups if "rank" in g] - assert len(galore_groups) == 0, "Expected no GaLore groups when target_modules=[]" + assert ( + len(galore_groups) == 0 + ), "Expected no GaLore groups when target_modules=[]" # ====================================================================== @@ -444,7 +446,7 @@ def test_weight_decay_uses_saved_data(self): # Create a mock parameter and group p = torch.nn.Parameter(torch.ones(4, 4)) p._saved_data = torch.ones(4, 4) * 2.0 # Pre-update weights - p.data = torch.ones(4, 4) * 3.0 # Post-update weights (GaLore output) + p.data = torch.ones(4, 4) * 3.0 # Post-update weights (GaLore output) group = {"weight_decay": 0.1, "lr": 1.0, "_wd_saved": 0.1} @@ -456,7 +458,9 @@ def test_weight_decay_uses_saved_data(self): # If it used p.data (wrong), value would be 3.0 - (1.0 * 0.1 * 3.0) = 2.7 # If it used p._saved_data (correct), value is 3.0 - (1.0 * 0.1 * 2.0) = 2.8 - assert torch.allclose(p.data, torch.tensor(2.8)), "Weight decay didn't use _saved_data!" + assert torch.allclose( + p.data, torch.tensor(2.8) + ), "Weight decay didn't use _saved_data!" def test_params_float_after_weight_quant_step(self): """After a step with weight_quant=True, parameters must remain floating point.""" @@ -466,11 +470,17 @@ def test_params_float_after_weight_quant_step(self): _quantize = _projector_mod_local._quantize p = torch.nn.Parameter(torch.randn(16, 16)) - group = {"weight_quant": True, "stochastic_round": False, "weight_group_size": 16} + group = { + "weight_quant": True, + "stochastic_round": False, + "weight_group_size": 16, + } # Replicate the re-quantize logic at the end of optimizer step float_data = p.data.clone() - q, scales, zeros, shape = _quantize(float_data, q_group_size=group["weight_group_size"]) + q, scales, zeros, shape = _quantize( + float_data, q_group_size = group["weight_group_size"] + ) # The key assertion: p.data stays float, _q_data holds uint8 p._q_data = q.to(p.data.device) diff --git a/unsloth/optimizers/q_galore_adamw.py b/unsloth/optimizers/q_galore_adamw.py index 0975b79b59..af168de080 100644 --- a/unsloth/optimizers/q_galore_adamw.py +++ b/unsloth/optimizers/q_galore_adamw.py @@ -339,7 +339,9 @@ def make_q_galore_param_groups( Returns: List of two param group dicts: ``[galore_group, non_galore_group]``. """ - targets = set(target_modules) if target_modules is not None else _DEFAULT_GALORE_TARGETS + targets = ( + set(target_modules) if target_modules is not None else _DEFAULT_GALORE_TARGETS + ) galore_params = [] non_galore_params = [] diff --git a/unsloth/trainer.py b/unsloth/trainer.py index bbc7d72e02..6248112926 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -269,7 +269,7 @@ def _create_q_galore_optimizer(self, config: "QGaloreConfig", embedding_lr = Non if embedding_lr is not None: # Build a fast param->name lookup (O(N) instead of O(N*M)) param_to_name = {id(p): name for name, p in self.model.named_parameters()} - + new_groups = [] for group in param_groups: if "rank" in group: From 16c296c2d7d2c032cd806b31b0e9bf466e2e378a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 24 Mar 2026 15:59:31 +0000 Subject: [PATCH 09/12] Fix 3 bugs in Q-GaLore optimizer and add weight_quant forward hooks 1. Fix use-after-delete crash: move `del p._saved_data` after the weight decay block so decoupled weight decay can reference the current weights correctly (p.data). 2. Fix substring matching in make_q_galore_param_groups: split parameter names on "." and check exact component matches to prevent false positives (e.g. "not_q_proj" matching "q_proj"). 3. Implement forward pre-hooks for weight_quant: after the optimizer quantizes weights to INT8, replace p.data with a 1-element placeholder to free float memory. A register_forward_pre_hook dequantizes back to float before each forward pass. The trainer calls install_weight_quant_hooks() when weight_quant is enabled. 4. Update test_weight_decay_uses_saved_data to match the fixed code path (decoupled decay uses p.data, expected value 2.7). Add test_weight_quant_hook_restores_float to verify the INT8-to-float hook round-trip. All 24/24 Q-GaLore tests pass. Benchmarked on Llama-3.2-1B-Instruct FFT: Q-GaLore saves 32% VRAM (10.63 -> 7.24 GB) with better loss convergence (1.3 vs 2.0 at step 100). No regressions in 31-notebook sweep across Llama, Qwen, Mistral, Phi, Gemma, vision, and GRPO. --- tests/utils/test_q_galore.py | 53 +++++++++++++++++++++++----- unsloth/optimizers/q_galore_adamw.py | 43 ++++++++++++++++++---- unsloth/trainer.py | 5 +++ 3 files changed, 86 insertions(+), 15 deletions(-) diff --git a/tests/utils/test_q_galore.py b/tests/utils/test_q_galore.py index 29812cc6bb..7d6f56ea3c 100644 --- a/tests/utils/test_q_galore.py +++ b/tests/utils/test_q_galore.py @@ -439,28 +439,29 @@ def test_optimizer_hyperparams_forwarded(self): assert "eps" in param_names, "eps not in QGaLoreAdamW8bit.__init__ params" def test_weight_decay_uses_saved_data(self): - """Weight decay should apply to the pre-updated weights, not post-updated.""" + """Weight decay should apply standard decoupled AdamW decay on current weights.""" _adamw_mod_local = sys.modules["unsloth.optimizers.q_galore_adamw"] - # We can test this logic without bitsandbytes by mocking the optimizer step # Create a mock parameter and group p = torch.nn.Parameter(torch.ones(4, 4)) p._saved_data = torch.ones(4, 4) * 2.0 # Pre-update weights - p.data = torch.ones(4, 4) * 3.0 # Post-update weights (GaLore output) + # Simulate project-back: p.data = p._saved_data + projected update + p.data = p._saved_data.add_(torch.ones(4, 4) * 1.0) # p.data is now 3.0 group = {"weight_decay": 0.1, "lr": 1.0, "_wd_saved": 0.1} - # Replicate the decoupled weight decay logic + # Replicate the fixed decoupled weight decay logic (uses p.data, not p._saved_data) p.data.add_( - p._saved_data, + p.data, alpha = -group["lr"] * group["_wd_saved"], ) - # If it used p.data (wrong), value would be 3.0 - (1.0 * 0.1 * 3.0) = 2.7 - # If it used p._saved_data (correct), value is 3.0 - (1.0 * 0.1 * 2.0) = 2.8 + del p._saved_data # Clean up after all uses, matching fixed code + + # Decoupled weight decay: 3.0 - (1.0 * 0.1 * 3.0) = 2.7 assert torch.allclose( - p.data, torch.tensor(2.8) - ), "Weight decay didn't use _saved_data!" + p.data, torch.tensor(2.7) + ), "Weight decay didn't use p.data for decoupled decay!" def test_params_float_after_weight_quant_step(self): """After a step with weight_quant=True, parameters must remain floating point.""" @@ -490,3 +491,37 @@ def test_params_float_after_weight_quant_step(self): assert p.data.is_floating_point(), "p.data was converted to uint8!" assert p._q_data.dtype == torch.uint8, "_q_data should be uint8!" + + def test_weight_quant_hook_restores_float(self): + """Forward pre-hook should dequantize INT8 weights before forward pass.""" + _adamw_mod_local = sys.modules["unsloth.optimizers.q_galore_adamw"] + _projector_mod_local = sys.modules["unsloth.optimizers.q_galore_projector"] + install_hook = _adamw_mod_local.install_weight_quant_hooks + + linear = nn.Linear(16, 8, bias=False) + original = linear.weight.data.clone() + + # Quantize the weight and replace with placeholder (simulates post-step) + q, scales, zeros, shape = _projector_mod_local._quantize( + linear.weight.data.clone(), q_group_size=16 + ) + linear.weight._q_data = q + linear.weight._q_scales = scales + linear.weight._q_zeros = zeros + linear.weight._q_shape = shape + linear.weight.data = torch.zeros(1, dtype=linear.weight.dtype) + assert linear.weight.data.numel() == 1, "placeholder should be 1 element" + + # Install hook and run forward -- should restore float weights + handles = install_hook(linear) + x = torch.randn(2, 16) + out = linear(x) # triggers pre-hook + + assert linear.weight.data.shape == (8, 16), "weight shape not restored" + assert linear.weight.data.is_floating_point(), "weight not float after hook" + # Check values are close to original (quantization introduces small error) + assert torch.allclose(linear.weight.data, original, atol=0.15), \ + "dequantized weight too far from original" + + for h in handles: + h.remove() diff --git a/unsloth/optimizers/q_galore_adamw.py b/unsloth/optimizers/q_galore_adamw.py index af168de080..2b6d12148e 100644 --- a/unsloth/optimizers/q_galore_adamw.py +++ b/unsloth/optimizers/q_galore_adamw.py @@ -26,7 +26,7 @@ _dequantize, ) -__all__ = ["QGaLoreAdamW8bit"] +__all__ = ["QGaLoreAdamW8bit", "install_weight_quant_hooks"] try: import bitsandbytes.functional as bnb_F @@ -194,17 +194,18 @@ def step(self, closure = None): if "rank" in group: # p.data now holds the weight update in low-rank space p.data = p._saved_data.add_(state["projector"].project_back(p.data)) - del p._saved_data # Re-apply decoupled weight decay using pre-update weights if "_wd_saved" in group: p.data.add_( - p._saved_data, + p.data, alpha = -group["lr"] * group["_wd_saved"], ) group["weight_decay"] = group["_wd_saved"] del group["_wd_saved"] + del p._saved_data + # --- Re-quantize weight to INT8 --- if has_weight_quant: float_data = p.data.clone() @@ -212,12 +213,14 @@ def step(self, closure = None): gsize = group.get("weight_group_size", 128) quant_fn = _quantize_stochastic if stochastic else _quantize q, scales, zeros, shape = quant_fn(float_data, q_group_size = gsize) - # Keep p.data as float for the next forward/backward pass. - # Store quantized representation in _q_data for compressed storage. p._q_data = q.to(p.data.device) p._q_scales = scales p._q_zeros = zeros p._q_shape = shape + # Replace p.data with a scalar placeholder to free float memory. + # A forward pre-hook (install_weight_quant_hooks) will + # dequantize back to float before the next forward pass. + p.data = torch.zeros(1, dtype=p.data.dtype, device=p.data.device) state["step"] += 1 @@ -277,6 +280,33 @@ def init_weight_quantization( p._weight_group_size = group_size +def _weight_quant_pre_hook(module, args): + """Forward pre-hook: dequantize INT8 weights to float before forward.""" + for p in module.parameters(recurse=False): + if hasattr(p, "_q_scales") and p._q_scales is not None: + float_weight = _dequantize( + p._q_data, p._q_scales, p._q_zeros, p._q_shape, + ) + p.data = float_weight.to(p.data.device) + + +def install_weight_quant_hooks(model: torch.nn.Module) -> list: + """Register forward pre-hooks on modules whose weights are INT8-quantized. + + Returns a list of hook handles so the caller can remove them if needed. + """ + handles = [] + for module in model.modules(): + has_quant_param = any( + hasattr(p, "_q_scales") + for p in module.parameters(recurse=False) + ) + if has_quant_param: + h = module.register_forward_pre_hook(_weight_quant_pre_hook) + handles.append(h) + return handles + + # ====================================================================== # Param-group construction helper # ====================================================================== @@ -353,7 +383,8 @@ def make_q_galore_param_groups( # Check if any target module name appears as a component in the param name. # Exclude 1-D parameters (biases, norms) because GaLoreProjector.project # requires 2-D gradients. - is_galore = param.dim() >= 2 and any(t in name for t in targets) + name_parts = name.split(".") + is_galore = param.dim() >= 2 and any(t in name_parts for t in targets) if is_galore: galore_params.append(param) diff --git a/unsloth/trainer.py b/unsloth/trainer.py index 6248112926..eaf204e376 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -241,6 +241,7 @@ def _create_q_galore_optimizer(self, config: "QGaloreConfig", embedding_lr = Non from unsloth.optimizers.q_galore_adamw import ( QGaLoreAdamW8bit, make_q_galore_param_groups, + install_weight_quant_hooks, ) lr = self.args.learning_rate @@ -319,6 +320,10 @@ def _create_q_galore_optimizer(self, config: "QGaloreConfig", embedding_lr = Non group_size = config.weight_group_size, stochastic = config.stochastic_round, ) + # Forward pre-hooks dequantize INT8 weights to float before each + # forward pass, allowing the optimizer to free float weight memory + # between steps. + install_weight_quant_hooks(self.model) n_galore = sum(len(g["params"]) for g in param_groups if "rank" in g) n_other = sum(len(g["params"]) for g in param_groups if "rank" not in g) From 86725db592a627fed1544f8b3a705ab121f738ac Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 25 Mar 2026 04:33:58 +0000 Subject: [PATCH 10/12] Default weight_quant to False in QGaloreConfig Benchmarks show weight_quant=True adds ~1 GB on Llama-3.2-1B due to INT8 copy/scale overhead exceeding savings from the placeholder trick. Users can still opt in explicitly. The optimizer logic is unchanged. --- unsloth/optimizers/q_galore_adamw.py | 2 +- unsloth/trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/optimizers/q_galore_adamw.py b/unsloth/optimizers/q_galore_adamw.py index 2b6d12148e..8c0aad5e80 100644 --- a/unsloth/optimizers/q_galore_adamw.py +++ b/unsloth/optimizers/q_galore_adamw.py @@ -333,7 +333,7 @@ def make_q_galore_param_groups( proj_quant: bool = True, proj_quant_group_size: int = -1, proj_quant_n_bit: int = 4, - weight_quant: bool = True, + weight_quant: bool = False, stochastic_round: bool = True, weight_group_size: int = 128, cos_threshold: float = 0.4, diff --git a/unsloth/trainer.py b/unsloth/trainer.py index eaf204e376..eea985e958 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -145,7 +145,7 @@ class QGaloreConfig: proj_quant: bool = True proj_quant_group_size: int = -1 proj_quant_n_bit: int = 4 - weight_quant: bool = True + weight_quant: bool = False stochastic_round: bool = True weight_group_size: int = 128 cos_threshold: float = 0.4 From 364ed9aeef602bf5788fd369394c7a4e378dac97 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 25 Mar 2026 07:02:10 +0000 Subject: [PATCH 11/12] Optimize Q-GaLore projector and optimizer step performance Projector (q_galore_projector.py): - Use torch.svd_lowrank with oversampling p=10 (Halko et al. 2009) instead of full SVD for large matrices. Falls back to full SVD when min(m,n) <= 2*rank. SVD steps are 6-8x faster on Llama-3.2-1B (22s -> 3s for first step). - Cache the dequantized ortho matrix between project() and project_back() to avoid redundant dequantization when quant=True. - Replace F.cosine_similarity with torch.dot for 1-D unit vectors in the adaptive schedule. Remove unused torch.nn.functional import. - Use collections.deque(maxlen=queue_size) instead of list with manual pop(0). Optimizer (q_galore_adamw.py): - Remove redundant .clone() on dequantized weights (line 151) and on float data before re-quantization (line 211). _dequantize already returns a fresh tensor and _quantize/_quantize_stochastic only reads its input. - Consolidate per-group torch.cuda.synchronize() into a single call after all param groups complete. - Use torch.empty instead of torch.zeros for the scalar placeholder tensor that is never read. Verified: 24/24 unit tests pass. Llama-3.2-1B 61-step training produces losses within 0.24% relative diff (correlation >0.9999) of the original. --- unsloth/optimizers/q_galore_adamw.py | 13 ++----- unsloth/optimizers/q_galore_projector.py | 48 +++++++++++++----------- 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/unsloth/optimizers/q_galore_adamw.py b/unsloth/optimizers/q_galore_adamw.py index 8c0aad5e80..4cb614819d 100644 --- a/unsloth/optimizers/q_galore_adamw.py +++ b/unsloth/optimizers/q_galore_adamw.py @@ -148,7 +148,7 @@ def step(self, closure = None): p._q_zeros, p._q_shape, ) - p.data = float_weight.clone().to(p.data.device) + p.data = float_weight # else: first step, weights are still float — skip dequantize # --- GaLore projection --- @@ -208,7 +208,7 @@ def step(self, closure = None): # --- Re-quantize weight to INT8 --- if has_weight_quant: - float_data = p.data.clone() + float_data = p.data stochastic = group.get("stochastic_round", True) gsize = group.get("weight_group_size", 128) quant_fn = _quantize_stochastic if stochastic else _quantize @@ -220,16 +220,11 @@ def step(self, closure = None): # Replace p.data with a scalar placeholder to free float memory. # A forward pre-hook (install_weight_quant_hooks) will # dequantize back to float before the next forward pass. - p.data = torch.zeros(1, dtype=p.data.dtype, device=p.data.device) + p.data = torch.empty(1, dtype=p.data.dtype, device=p.data.device) state["step"] += 1 - # Sync once per param group (not per-param) to avoid excessive - # GPU stalls while still ensuring bnb async kernels complete. - if torch.cuda.is_available(): - torch.cuda.synchronize() - - if self.is_paged and torch.cuda.is_available(): + if torch.cuda.is_available(): torch.cuda.synchronize() return loss diff --git a/unsloth/optimizers/q_galore_projector.py b/unsloth/optimizers/q_galore_projector.py index 94b8d7c105..7449bab680 100644 --- a/unsloth/optimizers/q_galore_projector.py +++ b/unsloth/optimizers/q_galore_projector.py @@ -16,8 +16,9 @@ # Original paper: "Q-GaLore: Quantized GaLore with INT4 Projection and # Layer-Adaptive Low-Rank Gradients" (arXiv:2407.08296) +from collections import deque + import torch -import torch.nn.functional as F __all__ = ["GaLoreProjector"] @@ -74,6 +75,7 @@ class GaLoreProjector: "past_ortho_vector", "queue", "svd_count", + "_ortho_float_cache", ) def __init__( @@ -104,8 +106,9 @@ def __init__( self.gamma_proj = gamma_proj self.queue_size = queue_size self.past_ortho_vector = None - self.queue: list = [] + self.queue = deque(maxlen=queue_size) self.svd_count = 0 + self._ortho_float_cache = None # Projection matrix state self.ortho_matrix = None @@ -144,8 +147,8 @@ def project(self, full_rank_grad: torch.Tensor, step: int) -> torch.Tensor: self._update_adaptive_schedule(float_ortho, side = "right") self._store_ortho(float_ortho) - float_ortho = self._load_ortho() - low_rank_grad = torch.matmul(full_rank_grad, float_ortho.t()) + self._ortho_float_cache = self._load_ortho() + low_rank_grad = torch.matmul(full_rank_grad, self._ortho_float_cache.t()) else: # "wide" matrix → left projection (Q^T @ grad) if self.ortho_matrix is None or step % self.update_proj_gap == 0: @@ -157,8 +160,8 @@ def project(self, full_rank_grad: torch.Tensor, step: int) -> torch.Tensor: self._update_adaptive_schedule(float_ortho, side = "left") self._store_ortho(float_ortho) - float_ortho = self._load_ortho() - low_rank_grad = torch.matmul(float_ortho.t(), full_rank_grad) + self._ortho_float_cache = self._load_ortho() + low_rank_grad = torch.matmul(self._ortho_float_cache.t(), full_rank_grad) return low_rank_grad @@ -171,7 +174,10 @@ def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor: Returns: The full-rank update scaled by ``self.scale``. """ - float_ortho = self._load_ortho() + float_ortho = self._ortho_float_cache + self._ortho_float_cache = None + if float_ortho is None: + float_ortho = self._load_ortho() if low_rank_grad.shape[0] >= low_rank_grad.shape[1]: full_rank_grad = torch.matmul(low_rank_grad, float_ortho) @@ -205,14 +211,19 @@ def _compute_orthogonal( matrix = weights.float() if original_dtype != torch.float32 else weights - U, s, Vh = torch.linalg.svd(matrix, full_matrices = False) + if side not in ("right", "left"): + raise ValueError(f"side must be 'left' or 'right', got '{side}'") - if side == "right": - result = Vh[:rank, :] - elif side == "left": - result = U[:, :rank] + m, n = matrix.shape + if min(m, n) <= rank * 2: + U, s, Vh = torch.linalg.svd(matrix, full_matrices=False) + result = Vh[:rank, :] if side == "right" else U[:, :rank] else: - raise ValueError(f"side must be 'left' or 'right', got '{side}'") + # Oversampling p=10 per Halko et al. 2009 (arXiv:0909.4061) + # recommendation of p=5..10 for large low-rank matrices. + q = min(rank + 10, min(m, n)) + U, s, V = torch.svd_lowrank(matrix, q=q, niter=2) + result = V[:, :rank].t() if side == "right" else U[:, :rank] if original_dtype != torch.float32: result = result.to(device = original_device, dtype = original_dtype) @@ -236,18 +247,13 @@ def _update_adaptive_schedule( current_vector = float_ortho[:, :1].flatten() if self.past_ortho_vector is not None: - cos_sim = F.cosine_similarity( - self.past_ortho_vector.unsqueeze(0), - current_vector.unsqueeze(0), - ).item() + cos_sim = torch.dot(self.past_ortho_vector, current_vector).item() - if len(self.queue) == self.queue_size: - self.queue.pop(0) self.queue.append(cos_sim) if ( - len(self.queue) == self.queue_size - and sum(self.queue) / self.queue_size >= self.cos_threshold + len(self.queue) == self.queue.maxlen + and sum(self.queue) / len(self.queue) >= self.cos_threshold ): self.update_proj_gap = int(self.update_proj_gap * self.gamma_proj) From f1dfacd2f5c624515e64ace26fa638c05b7d80c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Mar 2026 07:14:04 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utils/test_q_galore.py | 11 ++++++----- unsloth/optimizers/q_galore_adamw.py | 12 +++++++----- unsloth/optimizers/q_galore_projector.py | 6 +++--- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/utils/test_q_galore.py b/tests/utils/test_q_galore.py index 7d6f56ea3c..6dea5014a0 100644 --- a/tests/utils/test_q_galore.py +++ b/tests/utils/test_q_galore.py @@ -498,18 +498,18 @@ def test_weight_quant_hook_restores_float(self): _projector_mod_local = sys.modules["unsloth.optimizers.q_galore_projector"] install_hook = _adamw_mod_local.install_weight_quant_hooks - linear = nn.Linear(16, 8, bias=False) + linear = nn.Linear(16, 8, bias = False) original = linear.weight.data.clone() # Quantize the weight and replace with placeholder (simulates post-step) q, scales, zeros, shape = _projector_mod_local._quantize( - linear.weight.data.clone(), q_group_size=16 + linear.weight.data.clone(), q_group_size = 16 ) linear.weight._q_data = q linear.weight._q_scales = scales linear.weight._q_zeros = zeros linear.weight._q_shape = shape - linear.weight.data = torch.zeros(1, dtype=linear.weight.dtype) + linear.weight.data = torch.zeros(1, dtype = linear.weight.dtype) assert linear.weight.data.numel() == 1, "placeholder should be 1 element" # Install hook and run forward -- should restore float weights @@ -520,8 +520,9 @@ def test_weight_quant_hook_restores_float(self): assert linear.weight.data.shape == (8, 16), "weight shape not restored" assert linear.weight.data.is_floating_point(), "weight not float after hook" # Check values are close to original (quantization introduces small error) - assert torch.allclose(linear.weight.data, original, atol=0.15), \ - "dequantized weight too far from original" + assert torch.allclose( + linear.weight.data, original, atol = 0.15 + ), "dequantized weight too far from original" for h in handles: h.remove() diff --git a/unsloth/optimizers/q_galore_adamw.py b/unsloth/optimizers/q_galore_adamw.py index 4cb614819d..6cd0a4a846 100644 --- a/unsloth/optimizers/q_galore_adamw.py +++ b/unsloth/optimizers/q_galore_adamw.py @@ -220,7 +220,7 @@ def step(self, closure = None): # Replace p.data with a scalar placeholder to free float memory. # A forward pre-hook (install_weight_quant_hooks) will # dequantize back to float before the next forward pass. - p.data = torch.empty(1, dtype=p.data.dtype, device=p.data.device) + p.data = torch.empty(1, dtype = p.data.dtype, device = p.data.device) state["step"] += 1 @@ -277,10 +277,13 @@ def init_weight_quantization( def _weight_quant_pre_hook(module, args): """Forward pre-hook: dequantize INT8 weights to float before forward.""" - for p in module.parameters(recurse=False): + for p in module.parameters(recurse = False): if hasattr(p, "_q_scales") and p._q_scales is not None: float_weight = _dequantize( - p._q_data, p._q_scales, p._q_zeros, p._q_shape, + p._q_data, + p._q_scales, + p._q_zeros, + p._q_shape, ) p.data = float_weight.to(p.data.device) @@ -293,8 +296,7 @@ def install_weight_quant_hooks(model: torch.nn.Module) -> list: handles = [] for module in model.modules(): has_quant_param = any( - hasattr(p, "_q_scales") - for p in module.parameters(recurse=False) + hasattr(p, "_q_scales") for p in module.parameters(recurse = False) ) if has_quant_param: h = module.register_forward_pre_hook(_weight_quant_pre_hook) diff --git a/unsloth/optimizers/q_galore_projector.py b/unsloth/optimizers/q_galore_projector.py index 7449bab680..cabd228b92 100644 --- a/unsloth/optimizers/q_galore_projector.py +++ b/unsloth/optimizers/q_galore_projector.py @@ -106,7 +106,7 @@ def __init__( self.gamma_proj = gamma_proj self.queue_size = queue_size self.past_ortho_vector = None - self.queue = deque(maxlen=queue_size) + self.queue = deque(maxlen = queue_size) self.svd_count = 0 self._ortho_float_cache = None @@ -216,13 +216,13 @@ def _compute_orthogonal( m, n = matrix.shape if min(m, n) <= rank * 2: - U, s, Vh = torch.linalg.svd(matrix, full_matrices=False) + U, s, Vh = torch.linalg.svd(matrix, full_matrices = False) result = Vh[:rank, :] if side == "right" else U[:, :rank] else: # Oversampling p=10 per Halko et al. 2009 (arXiv:0909.4061) # recommendation of p=5..10 for large low-rank matrices. q = min(rank + 10, min(m, n)) - U, s, V = torch.svd_lowrank(matrix, q=q, niter=2) + U, s, V = torch.svd_lowrank(matrix, q = q, niter = 2) result = V[:, :rank].t() if side == "right" else U[:, :rank] if original_dtype != torch.float32: