feat: add torchao int4, nf4, int8 LoRA QLoRA support#3417
Conversation
970b2a6 to
33c495d
Compare
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughThis PR introduces TorchAO as a structured quantization backend alternative to bitsandbytes for QLoRA training, unifies dequantization logic across both backends, migrates from ChangesTorchAO Quantization Backend Integration
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Land the contributor PR after rebasing onto main and closing the gaps found during audit. Behavior changes - peft.backend: torchao + weight_dtype: int8 now stays as adapter=lora (matching bnb int8 semantics) instead of being auto-promoted to qlora. - Unsupported torchao weight_dtypes (fp8, nvfp4, mxfp4) are rejected at validation with a clear pointer to the QAT/PTQ flow. - Merging a torchao adapter requires merge_method=legacy; the memory-efficient merger simulates bnb NF4 and would silently mis-merge torchao tensor subclasses. - DoRA paths in kernels/lora.py route through dequantize_weight so DoRA + torchao works end-to-end (the previous bare dequantize calls would have failed on AffineQuantizedTensor / NF4Tensor). Bug fixes uncovered while landing - model.py: switch from the deprecated string quant_type API (TorchAoConfig(quant_type="int4_weight_only")) to the object-based Int4WeightOnlyConfig / Int8WeightOnlyConfig API required by modern transformers. - model.py: import NF4WeightOnlyConfig from torchao.prototype._nf4tensor_api (with a fallback to the old torchao.dtypes path) — the original location no longer exists in torchao >= 0.13. - model.py: NF4WeightOnlyConfig now takes no constructor arguments; set block_size / scaler_block_size as attributes. Coverage - ModelLoader.is_torchao_qlora now matches both adapter=lora and adapter=qlora to keep the bnb-skipping branches consistent for the int8 case. - model.py's _set_quantization_config branch now triggers for adapter in (lora, qlora) so int8 torchao gets its TorchAoConfig. Docs + examples - docs/qlora_torchao.qmd: new page covering backends, weight_dtype table, constraints, FSDP2. - examples/llama-3/qlora-torchao.yaml: minimal config using the new peft block. Tests - tests/utils/lora/test_config_validation_lora.py: torchao+int8 stays lora; fp8/nvfp4/mxfp4 rejected; merge_lora requires legacy; DoRA + torchao allowed. - tests/test_loaders.py: TorchAoConfig is wired with Int4WeightOnlyConfig / Int8WeightOnlyConfig / NF4WeightOnlyConfig. - tests/e2e/kernels/test_quantize.py: dequantize_weight against fake AffineQuantizedTensor / NF4Tensor subclasses (no CUDA needed). Validated locally with a CUDA smoke test on SmolLM2-135M: torchao int8 LoRA loads with AffineQuantizedTensor base weights, forward + backward produce gradients on all 420 trainable params.
Earlier pass rejected fp8/nvfp4/mxfp4 at the schema layer, telling users to use QAT/PTQ instead. That was wrong: - NVFP4 has a real weight-only torchao config (NVFP4WeightOnlyConfig in torchao.prototype.mx_formats) — it's a 4-bit quant, perfectly suited to QLoRA. Now auto-promotes adapter lora -> qlora and builds NVFP4WeightOnlyConfig at load. - FP8 (float8_e4m3fn) has Float8WeightOnlyConfig in torchao.quantization — a one-byte-per-weight quant that mirrors INT8's role. Keeps adapter as lora. - MXFP4 is the genuine 'no weight-only flavor' case. The schema now passes it through; the loader raises with a pointer to quantize_moe_experts: true for MoE models (which is where MXFP4 LoRA actually lives, via the ScatterMoE-LoRA path landed in #3663) and to qat/ptq for inference-time MXFP4. CUDA smoke-tested on SmolLM2-135M: - weight_dtype: fp8 -> Float8WeightOnlyConfig, forward+backward OK - weight_dtype: nvfp4 (group_size=16) -> NVFP4WeightOnlyConfig, OK - weight_dtype: mxfp4 -> loader error pointing to quantize_moe_experts Docs and the dtype table updated; schema/loader tests extended.
…overrides peft.backend: torchao installs a single TorchAoConfig covering every linear layer. It composes badly with axolotl's other quant mechanisms, but the prior code silently picked a winner: - model_quantization_config (Mxfp4Config / FineGrainedFP8Config) gets overwritten by our TorchAoConfig later in the same function. - A checkpoint with embedded quantization_config (gpt-oss MXFP4, pre-quantized AWQ / GPTQ / BNB) wins via the earlier if-branch in _set_quantization_config; peft.backend is silently ignored. - quantize_moe_experts: true would race with TorchAoConfig over the same expert tensors. - gptq: true is a separate path entirely. Now: - SystemValidationMixin.check_torchao_backend_exclusivity rejects peft.backend: torchao + (model_quantization_config | quantize_moe_experts | gptq) at validation with a pointer to docs/qlora_torchao.qmd. - ModelLoader._set_quantization_config raises when the base model's checkpoint already advertises a quant_method and peft.backend is also set (the conflict only resolvable post-load_model_config). Documents the boundary: mixed-quant flows (experts MXFP4 + attention bf16, gpt-oss-style) drop peft.backend and use the per-mechanism config (quantize_moe_experts or the checkpoint's quant_method) directly. peft.backend is for uniform base-quant only.
…he awq/gptq/bnb branch The prior loader-time check sat inside the peft.backend elif branch, so checkpoints with quant_method in (awq, gptq, bitsandbytes) hit the earlier if-branch first and silently overwrote model_kwargs's quantization_config — peft.backend got dropped on the floor. Move the check to the top of _set_quantization_config so it fires for any non-empty model_config.quantization_config, including the realistic ones that motivated this audit: - gpt-oss native MXFP4 (quant_method: mxfp4) - AMD Quark MXFP4 with a per-module exclude list, e.g. amd/Kimi-K2.6-MXFP4: experts in MXFP4, ~305 modules excluded (lm_head, every attention projection, vision tower, mm_projector) - AWQ / GPTQ / bitsandbytes pre-quantized checkpoints Tests parametrize across all five quant_methods. Docs gain a section naming AMD Quark MXFP4 as the canonical mixed-quant example and restate the recommendation: drop peft.backend so the checkpoint's own quantization_config flows through unchanged.
…ation_config
Replace the peft.backend / peft.weight_dtype / peft.group_size shape
with a structured discriminator on the existing
model_quantization_config field. One namespace for all base-model
quant; peft.backend drops out entirely.
User-facing surface:
# bnb 4-bit QLoRA (replaces adapter: qlora + load_in_4bit: true)
adapter: lora
model_quantization_config:
bnb:
weight_dtype: nf4
# torchao QLoRA
adapter: lora
model_quantization_config:
torchao:
weight_dtype: int4
# group_size: 128
# Legacy string form (Mxfp4Config / FineGrainedFP8Config) keeps working
# via the same field. Equivalent structured form:
model_quantization_config:
mxfp4:
config_kwargs: {}
Schema:
ModelQuantizationConfig(BaseModel) is a discriminated union with
exactly one of bnb / torchao / mxfp4 / fp8 set. The top-level field
accepts Literal["Mxfp4Config", "FineGrainedFP8Config"] |
ModelQuantizationConfig | None.
Auto-promotion:
Moved out of the peft block into LoraConfig.auto_detect_qlora, which
reads the structured form. bnb.nf4 sets load_in_4bit and promotes
lora -> qlora; bnb.int8 sets load_in_8bit; torchao 4-bit dtypes
(int4/nf4/nvfp4) promote lora -> qlora; torchao int8/fp8 stay as
weight-only LoRA.
Conflict surfaces (validation.py + model.py) updated to gate on
model_quantization_config.torchao instead of peft.backend:
- + quantize_moe_experts: true -> rejected at validation
- + gptq: true -> rejected at validation
- + load_in_4bit / load_in_8bit -> rejected at validation
- + checkpoint with embedded quant_method -> rejected at load time
(covers Quark, mxfp4, awq, gptq, bnb — the AMD Kimi-K2.6-MXFP4 case
Wing called out).
Internals:
axolotl.utils.config.validate_config returns nested fields as dicts
via model_dump. Two helpers in loaders/model.py (_mqc_branch and
_torchao_subconfig) accept either form so direct-Pydantic test
construction and post-validate dict access both work.
Docs (docs/qlora_torchao.qmd) and the example
(examples/llama-3/qlora-torchao.yaml) rewritten around the new shape.
Schema tests (24) and loader tests (43) rewritten to exercise it.
CUDA-validated on SmolLM2-135M for torchao int8 and fp8 paths: model
loads with the structured config, LoRA injects, forward+backward
produces gradients on all 420 trainable params.
Per maintainer feedback: "qlora is just lora with an nf4 base weight" —
don't carry it as a distinct adapter name. Demote `adapter: qlora` to
`adapter: lora` in the validator and key all internal "is this
QLoRA?" decisions off the actual base-weight quant state instead of
the adapter name.
User surface
- The recommended shape is now uniformly `adapter: lora` plus one of:
`model_quantization_config: {bnb: {weight_dtype: nf4}}` (terse)
`load_in_4bit: true` (legacy bnb)
`model_quantization_config: {torchao: {weight_dtype: int4}}`
- Legacy `adapter: qlora` configs keep working unchanged: a new
`normalize_adapter_qlora` validator demotes them and, if no
base-quant choice was spelled out, auto-sets `load_in_4bit: true`
(the legacy shorthand's implicit meaning). Emits a DEPRECATED log
with the migration path.
- `adapter: qlora` + `load_in_8bit: true` is now rejected as
ambiguous (QLoRA is a 4-bit thing).
Codebase
- `is_qlora_and_fsdp_enabled` now keys off `load_in_4bit`, not the
adapter name.
- The bnb 4-bit branch in `_set_quantization_config` fires on
`adapter == lora and load_in_4bit` (the validator guarantees
qlora is always normalized to lora upstream).
- `AxolotlTrainingArguments.qlora` (dead read; nothing in axolotl
consumes it) is set from `load_in_4bit` instead of the adapter name.
- `validate_qlora` (mode=after) no longer gates the merge-bans on
`adapter == qlora`; merge into an 8-bit/4-bit/GPTQ base is rejected
regardless of how the user spelled the quant.
- `merge_lora` CLI's `simulate_nf4` drops the dead
`_original_adapter == "qlora"` check; `_original_load_in_4bit`
covers all bnb-4-bit cases (and the validator sets it for
legacy-qlora configs upstream).
Deprecation warnings (per request)
- `adapter: qlora` → DEPRECATED log pointing at the new shape.
- `model_quantization_config: "Mxfp4Config" | "FineGrainedFP8Config"`
(string form) → DEPRECATED log naming the equivalent structured
form (with `config_kwargs` carried over).
- Both fire at config-load time, so they show up on every `axolotl
train` / `axolotl merge-lora` / `axolotl preprocess` invocation
for the affected configs.
Tests
- 69 passing. Legacy `adapter: qlora` test cases updated to assert
the demoted shape (`adapter == "lora"`, `load_in_4bit == True`).
- The `adapter: qlora` parametrize axis in
`test_set_quantization_config` removed; legacy paths exercised
via the validator's normalization.
…ig first, load_in_4bit/8bit deprecated
Previous commit kept load_in_4bit / load_in_8bit as the user-facing
knobs and set them internally to drive the bnb loader branch. Per
Wing: "i thought we were trying to get rid of load_in_4bit: true?"
— right. The structured form is the source of truth; the legacy
flags are deprecated user inputs.
One `normalize_base_quant_inputs` validator now does both halves
in lockstep (separate validators ran out of order in Pydantic v2,
which is what was leaving load_in_4bit unset for bare `adapter:
qlora` configs):
1. Translate every legacy spelling into the canonical structured
form:
- `adapter: qlora` → adapter: lora + bnb nf4
- `adapter: qlora` + load_in_4bit → adapter: lora + bnb nf4
- `load_in_4bit: true` (alone) → bnb nf4
- `load_in_8bit: true` (alone) → bnb int8
Emits a DEPRECATED warning on each path.
2. Mirror the structured form back into load_in_4bit/8bit so the
downstream loader code that still reads them sees a consistent
state.
`load_in_4bit` / `load_in_8bit` field descriptions now begin with
`DEPRECATED:` so the JSON schema and autogen docs flag them.
Docs (docs/qlora_torchao.qmd) and the headline example
(examples/llama-3/qlora-torchao.yaml) reworked so the canonical form
is the only shape shown; legacy forms appear only in the Deprecations
section with their migration target.
Tests: 59 passing — including bare `adapter: qlora` (no load_in_4bit
written) which validates through to load_in_4bit=True via the
combined validator. CUDA smoke (torchao int8 on SmolLM2-135M) still
loads, forward+backward, grads on 420 params.
1dc14fa to
4b23b3e
Compare
State of this branchRebased onto current Canonical config shapeadapter: lora
model_quantization_config:
bnb: # or torchao / mxfp4 / fp8
weight_dtype: nf4
Deprecations (kept working, warn at config-load)
A single Conflict surfaces
The load-time check covers every Kernels
Other rough edges closed during the audit
Verified
Known follow-up (not in this PR)Migrate the loader to read |
|
📖 Documentation Preview: https://6a1bb5ce49528c402bf6562e--resonant-treacle-0fd729.netlify.app Deployed on Netlify from commit f9f280b |
There was a problem hiding this comment.
Actionable comments posted: 4
🧹 Nitpick comments (2)
tests/test_loaders.py (1)
110-143: ⚡ Quick winUse
adapter="lora"in the torchao loader tests.These cases still build loader state with
adapter="qlora", but the schema normalizer is supposed to demote that upstream. Keeping the tests onqlorameans they don't cover the runtime contract that the torchao helpers actually key on.Also applies to: 158-168, 229-237, 277-287
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_loaders.py` around lines 110 - 143, The tests (e.g., test_set_quantization_config_torchao_qlora) are using adapter="qlora" but the torchao loader helpers expect the normalized upstream value "lora"; update the parametrized adapter values and any hardcoded self.cfg.adapter assignments in the torchao-related tests (including the other occurrences flagged at the other ranges) from "qlora" to "lora" so the tests exercise the runtime contract the loader keys on (adjust the parametrization tuple entries and the self.cfg.adapter assignments in those test methods).src/axolotl/utils/schemas/validation.py (1)
1232-1235: ⚡ Quick winTrim these new comments to one-line WHY notes.
These blocks are mostly describing WHAT the validators do and exceed the repo’s comment style for
src/axolotl/**.As per coding guidelines, "Only add comments when explaining the WHY behind non-obvious logic, hidden constraints, or workarounds for specific bugs. Do not comment on WHAT code does ... Comments should be a maximum of one short line."
Also applies to: 1248-1250, 1274-1282
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/utils/schemas/validation.py` around lines 1232 - 1235, Trim the multi-line explanatory comments around the structured "bnb" form and legacy checks to a single short WHY note: explain why the structured bnb case is exempt from the legacy Mxfp4Config/FineGrainedFP8Config string-form check (i.e., because auto_detect_qlora sets load_in_4bit/load_in_8bit), and replace similar WHAT-style blocks at the subsequent comment sites (the blocks that mention auto_detect_qlora, Mxfp4Config, FineGrainedFP8Config, load_in_4bit/load_in_8bit and the mxfp4/fp8 branches) with one-line WHY comments so the validation logic remains documented but conforms to the repo style.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/axolotl/monkeypatch/peft/utils.py`:
- Around line 97-107: Save the original
peft.tuners.lora.torchao.dispatch_torchao before replacing it (e.g., store it on
peft_torchao._axolotl_orig_dispatch) and change patch_peft_torchao_dispatch to
install a patched_dispatch that only returns None for non-INT8 adapters but
delegates to the saved original for INT8/torchao dequantization cases (inspect
lora_config or adapter_name to detect INT8); keep the _axolotl_patched flag but
do not permanently drop the original dispatcher, or alternatively restore
peft_torchao.dispatch_torchao back to the saved original after the adapter load
completes, ensuring no leak across later adapter loads and preserving the
original behavior for INT8 TorchaoLoraLinear.
In `@src/axolotl/utils/schemas/model.py`:
- Around line 38-46: The schema currently lists "mxfp4" as an allowed value for
the weight_dtype Field but the loader rejects it later; update the Pydantic
model in model.py to enforce this at validation by removing "mxfp4" from the
Literal type for weight_dtype (i.e., change
Literal["int4","nf4","nvfp4","int8","fp8","mxfp4"] to exclude "mxfp4") or
alternatively add a Pydantic validator on weight_dtype that raises a ValueError
when the value == "mxfp4"; reference the weight_dtype Field in model.py to
implement the change so invalid YAML is rejected at schema validation time.
In `@src/axolotl/utils/schemas/peft.py`:
- Around line 249-259: The mirror step in src/axolotl/utils/schemas/peft.py
currently uses data.setdefault(...) so existing conflicting legacy flags can
remain; update the block that reads mqc = data.get("model_quantization_config")
and inspects bnb/weight_dtype to explicitly set data["load_in_4bit"] =
True/False and data["load_in_8bit"] = True/False according to weight_dtype
(e.g., if weight_dtype == "nf4" set load_in_4bit True and load_in_8bit False; if
"int8" set load_in_8bit True and load_in_4bit False; otherwise ensure both are
False or unset), so the canonical bnb config always wins over legacy flags used
by downstream loaders.
In `@src/axolotl/utils/schemas/quantization.py`:
- Around line 19-20: The validator validate_ao_dtype currently maps "nf4" →
TorchAOQuantDType.nf4 and is reused for both activation_dtype and weight_dtype,
so activation_dtype: "nf4" incorrectly passes; fix this by splitting the logic
into two validators (e.g., validate_weight_ao_dtype and
validate_activation_ao_dtype) or by adding a field-specific check: keep "nf4"
allowed for weight_dtype but explicitly reject it for activation_dtype by
raising a ValueError when the input is "nf4"; update the QATConfig/PTQConfig
model to use the new activation-specific validator for activation_dtype and the
weight-specific validator for weight_dtype instead of the single shared
validate_ao_dtype.
---
Nitpick comments:
In `@src/axolotl/utils/schemas/validation.py`:
- Around line 1232-1235: Trim the multi-line explanatory comments around the
structured "bnb" form and legacy checks to a single short WHY note: explain why
the structured bnb case is exempt from the legacy
Mxfp4Config/FineGrainedFP8Config string-form check (i.e., because
auto_detect_qlora sets load_in_4bit/load_in_8bit), and replace similar
WHAT-style blocks at the subsequent comment sites (the blocks that mention
auto_detect_qlora, Mxfp4Config, FineGrainedFP8Config, load_in_4bit/load_in_8bit
and the mxfp4/fp8 branches) with one-line WHY comments so the validation logic
remains documented but conforms to the repo style.
In `@tests/test_loaders.py`:
- Around line 110-143: The tests (e.g.,
test_set_quantization_config_torchao_qlora) are using adapter="qlora" but the
torchao loader helpers expect the normalized upstream value "lora"; update the
parametrized adapter values and any hardcoded self.cfg.adapter assignments in
the torchao-related tests (including the other occurrences flagged at the other
ranges) from "qlora" to "lora" so the tests exercise the runtime contract the
loader keys on (adjust the parametrization tuple entries and the
self.cfg.adapter assignments in those test methods).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 875be089-d749-4775-a411-98d1b242264f
📒 Files selected for processing (19)
_quarto.ymldocs/qlora_torchao.qmdexamples/llama-3/qlora-torchao.yamlsrc/axolotl/cli/merge_lora.pysrc/axolotl/core/builders/causal.pysrc/axolotl/kernels/lora.pysrc/axolotl/kernels/quantize.pysrc/axolotl/loaders/adapter.pysrc/axolotl/loaders/model.pysrc/axolotl/loaders/patch_manager.pysrc/axolotl/monkeypatch/peft/utils.pysrc/axolotl/utils/schemas/enums.pysrc/axolotl/utils/schemas/model.pysrc/axolotl/utils/schemas/peft.pysrc/axolotl/utils/schemas/quantization.pysrc/axolotl/utils/schemas/validation.pytests/e2e/kernels/test_quantize.pytests/test_loaders.pytests/utils/lora/test_config_validation_lora.py
| if getattr(peft_torchao, "_axolotl_patched", False): | ||
| return | ||
|
|
||
| def patched_dispatch(target, adapter_name, lora_config, **kwargs): | ||
| # Return None so PEFT falls back to standard Linear LoRA layers. | ||
| # Our LoRA kernels handle torchao dequantization explicitly. | ||
| return None | ||
|
|
||
| peft_torchao.dispatch_torchao = patched_dispatch | ||
| peft_torchao._axolotl_patched = True | ||
| LOG.info("Patched PEFT dispatch_torchao to skip TorchaoLoraLinear") |
There was a problem hiding this comment.
Restore/scope the PEFT dispatch_torchao monkeypatch to prevent leaking across later adapter loads.
patch_peft_torchao_dispatch() replaces peft.tuners.lora.torchao.dispatch_torchao with a stub that always returns None and never restores the original dispatcher; if a non-INT8 torchao adapter is loaded first, later INT8 torchao loads in the same worker will still observe the stub and skip PEFT’s TorchaoLoraLinear dispatch even though src/axolotl/loaders/adapter.py avoids calling the patch for INT8. Preserve the original dispatcher and restore it after the relevant load, or have the stub delegate back to the original for INT8.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/axolotl/monkeypatch/peft/utils.py` around lines 97 - 107, Save the
original peft.tuners.lora.torchao.dispatch_torchao before replacing it (e.g.,
store it on peft_torchao._axolotl_orig_dispatch) and change
patch_peft_torchao_dispatch to install a patched_dispatch that only returns None
for non-INT8 adapters but delegates to the saved original for INT8/torchao
dequantization cases (inspect lora_config or adapter_name to detect INT8); keep
the _axolotl_patched flag but do not permanently drop the original dispatcher,
or alternatively restore peft_torchao.dispatch_torchao back to the saved
original after the adapter load completes, ensuring no leak across later adapter
loads and preserving the original behavior for INT8 TorchaoLoraLinear.
| weight_dtype: Literal["int4", "nf4", "nvfp4", "int8", "fp8", "mxfp4"] = Field( | ||
| json_schema_extra={ | ||
| "description": ( | ||
| "torchao base-weight dtype. int4/nf4/nvfp4 → QLoRA; int8/fp8 " | ||
| "→ weight-only LoRA; mxfp4 is unsupported as a base-quant " | ||
| "shorthand (use quantize_moe_experts for MoE MXFP4)." | ||
| ) | ||
| } | ||
| ) |
There was a problem hiding this comment.
Reject torchao.mxfp4 at the schema layer.
"mxfp4" is advertised as a valid torchao.weight_dtype here, but the loader always rejects it later. That lets an invalid YAML pass validation and only fail during model load, which is the wrong layer for this check.
♻️ Proposed fix
- weight_dtype: Literal["int4", "nf4", "nvfp4", "int8", "fp8", "mxfp4"] = Field(
+ weight_dtype: Literal["int4", "nf4", "nvfp4", "int8", "fp8"] = Field(
json_schema_extra={
"description": (
"torchao base-weight dtype. int4/nf4/nvfp4 → QLoRA; int8/fp8 "
- "→ weight-only LoRA; mxfp4 is unsupported as a base-quant "
- "shorthand (use quantize_moe_experts for MoE MXFP4)."
+ "→ weight-only LoRA."
)
}
)As per coding guidelines, "Config validation must use Pydantic schemas defined in src/axolotl/utils/schemas/."
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| weight_dtype: Literal["int4", "nf4", "nvfp4", "int8", "fp8", "mxfp4"] = Field( | |
| json_schema_extra={ | |
| "description": ( | |
| "torchao base-weight dtype. int4/nf4/nvfp4 → QLoRA; int8/fp8 " | |
| "→ weight-only LoRA; mxfp4 is unsupported as a base-quant " | |
| "shorthand (use quantize_moe_experts for MoE MXFP4)." | |
| ) | |
| } | |
| ) | |
| weight_dtype: Literal["int4", "nf4", "nvfp4", "int8", "fp8"] = Field( | |
| json_schema_extra={ | |
| "description": ( | |
| "torchao base-weight dtype. int4/nf4/nvfp4 → QLoRA; int8/fp8 " | |
| "→ weight-only LoRA." | |
| ) | |
| } | |
| ) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/axolotl/utils/schemas/model.py` around lines 38 - 46, The schema
currently lists "mxfp4" as an allowed value for the weight_dtype Field but the
loader rejects it later; update the Pydantic model in model.py to enforce this
at validation by removing "mxfp4" from the Literal type for weight_dtype (i.e.,
change Literal["int4","nf4","nvfp4","int8","fp8","mxfp4"] to exclude "mxfp4") or
alternatively add a Pydantic validator on weight_dtype that raises a ValueError
when the value == "mxfp4"; reference the weight_dtype Field in model.py to
implement the change so invalid YAML is rejected at schema validation time.
| # Step 2: mirror the structured form back into load_in_4bit / | ||
| # load_in_8bit for downstream loader compat. | ||
| mqc = data.get("model_quantization_config") | ||
| if isinstance(mqc, dict): | ||
| bnb = mqc.get("bnb") | ||
| if isinstance(bnb, dict): | ||
| weight_dtype = bnb.get("weight_dtype") | ||
| if weight_dtype == "nf4": | ||
| data.setdefault("load_in_4bit", True) | ||
| elif weight_dtype == "int8": | ||
| data.setdefault("load_in_8bit", True) |
There was a problem hiding this comment.
Force the mirrored load_in_*bit flags to match the canonical bnb config.
Using setdefault() here leaves conflicting legacy flags untouched. A config like model_quantization_config: {bnb: {weight_dtype: nf4}} plus load_in_4bit: false or load_in_8bit: true will pass validation, but the loader still keys off load_in_*bit and can take the wrong branch.
♻️ Proposed fix
if isinstance(mqc, dict):
bnb = mqc.get("bnb")
if isinstance(bnb, dict):
weight_dtype = bnb.get("weight_dtype")
if weight_dtype == "nf4":
- data.setdefault("load_in_4bit", True)
+ data["load_in_4bit"] = True
+ data["load_in_8bit"] = False
elif weight_dtype == "int8":
- data.setdefault("load_in_8bit", True)
+ data["load_in_8bit"] = True
+ data["load_in_4bit"] = FalseAs per coding guidelines, "Config validation must use Pydantic schemas defined in src/axolotl/utils/schemas/."
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| # Step 2: mirror the structured form back into load_in_4bit / | |
| # load_in_8bit for downstream loader compat. | |
| mqc = data.get("model_quantization_config") | |
| if isinstance(mqc, dict): | |
| bnb = mqc.get("bnb") | |
| if isinstance(bnb, dict): | |
| weight_dtype = bnb.get("weight_dtype") | |
| if weight_dtype == "nf4": | |
| data.setdefault("load_in_4bit", True) | |
| elif weight_dtype == "int8": | |
| data.setdefault("load_in_8bit", True) | |
| # Step 2: mirror the structured form back into load_in_4bit / | |
| # load_in_8bit for downstream loader compat. | |
| mqc = data.get("model_quantization_config") | |
| if isinstance(mqc, dict): | |
| bnb = mqc.get("bnb") | |
| if isinstance(bnb, dict): | |
| weight_dtype = bnb.get("weight_dtype") | |
| if weight_dtype == "nf4": | |
| data["load_in_4bit"] = True | |
| data["load_in_8bit"] = False | |
| elif weight_dtype == "int8": | |
| data["load_in_8bit"] = True | |
| data["load_in_4bit"] = False |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/axolotl/utils/schemas/peft.py` around lines 249 - 259, The mirror step in
src/axolotl/utils/schemas/peft.py currently uses data.setdefault(...) so
existing conflicting legacy flags can remain; update the block that reads mqc =
data.get("model_quantization_config") and inspects bnb/weight_dtype to
explicitly set data["load_in_4bit"] = True/False and data["load_in_8bit"] =
True/False according to weight_dtype (e.g., if weight_dtype == "nf4" set
load_in_4bit True and load_in_8bit False; if "int8" set load_in_8bit True and
load_in_4bit False; otherwise ensure both are False or unset), so the canonical
bnb config always wins over legacy flags used by downstream loaders.
| if v == "nf4": | ||
| return TorchAOQuantDType.nf4 |
There was a problem hiding this comment.
🧩 Analysis chain
🌐 Web query:
Does torchao 0.17.0 support NF4 as an activation quantization dtype, or is NF4 weight-only?
💡 Result:
In torchao 0.17.0, NF4 is exclusively a weight-only quantization data type [1][2][3]. It is not supported for activation quantization [4]. Torchao provides NF4 support primarily through the nf4_weight_only() configuration and NF4Tensor subclass, which are designed to reduce memory usage by storing weights in the 4-bit Normal Float (NF4) format [1][2]. During computation, these weights are typically dequantized to a higher precision (such as bfloat16) [5]. Conversely, torchao's activation quantization schemes—such as those using Float8, NVFP4, or MXFP4—are explicitly categorized as dynamic or static activation quantization configurations, none of which utilize NF4 for the activations themselves [4][6][7].
Citations:
- 1: [NF4] Add
quantize_()API support for NF4 pytorch/ao#1216 - 2: https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py
- 3: https://github.com/pytorch-labs/ao/blob/b9beaf351e27133d189b57d6fa725b1a7824a457/torchao/dtypes/nf4tensor.py
- 4: https://docs.pytorch.org/ao/stable/workflows/inference.html
- 5: https://dev.to/sujankoirala021/i-fine-tuned-gemma-4-on-an-emotion-dataset-using-a-single-gpu-e72
- 6: https://docs.pytorch.org/ao/stable/contributing/quantization_overview.html
- 7: https://pytorch.org/blog/quantization-aware-training-in-torchao-ii/
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Inspect the target file around the mentioned lines and the validator definition.
FILE="src/axolotl/utils/schemas/quantization.py"
echo "== $FILE =="
nl -ba "$FILE" | sed -n '1,120p'
# Find where validate_ao_dtype is referenced within this schemas folder.
echo
echo "== References to validate_ao_dtype =="
rg -n "validate_ao_dtype" src/axolotl/utils/schemas | head -n 50
# Show the schema classes / fields that call into validate_ao_dtype (activation_dtype / weight_dtype).
echo
echo "== Where activation_dtype / weight_dtype are validated =="
rg -n "activation_dtype|weight_dtype" src/axolotl/utils/schemas | head -n 80Repository: axolotl-ai-cloud/axolotl
Length of output: 177
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="src/axolotl/utils/schemas/quantization.py"
echo "== git ls-files =="
git ls-files "$FILE" || true
echo
echo "== $FILE (first 200 lines) =="
if [ -f "$FILE" ]; then
sed -n '1,200p' "$FILE" | sed -n '1,200p'
else
echo "File not found: $FILE"
exit 0
fi
echo
echo "== validate_ao_dtype references =="
rg -n "validate_ao_dtype" src/axolotl/utils/schemas | head -n 50
echo
echo "== activation_dtype / weight_dtype validations =="
rg -n "(activation_dtype|weight_dtype)" src/axolotl/utils/schemas | head -n 80Repository: axolotl-ai-cloud/axolotl
Length of output: 5228
Disallow nf4 for activation_dtype (schema currently permits it)
src/axolotl/utils/schemas/quantization.py maps "nf4" → TorchAOQuantDType.nf4 in validate_ao_dtype(), and that same validator is attached to both activation_dtype and weight_dtype in QATConfig/PTQConfig, so activation_dtype: nf4 passes Pydantic validation. Since torchao 0.17.0 treats NF4 as weight-only (not an activation quantization dtype), reject "nf4" for activation_dtype (split validators or add field-specific constraints).
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/axolotl/utils/schemas/quantization.py` around lines 19 - 20, The
validator validate_ao_dtype currently maps "nf4" → TorchAOQuantDType.nf4 and is
reused for both activation_dtype and weight_dtype, so activation_dtype: "nf4"
incorrectly passes; fix this by splitting the logic into two validators (e.g.,
validate_weight_ao_dtype and validate_activation_ao_dtype) or by adding a
field-specific check: keep "nf4" allowed for weight_dtype but explicitly reject
it for activation_dtype by raising a ValueError when the input is "nf4"; update
the QATConfig/PTQConfig model to use the new activation-specific validator for
activation_dtype and the weight-specific validator for weight_dtype instead of
the single shared validate_ao_dtype.
…pt legacy tests The qlora→lora normalization silently absorbed two combos that should error: `adapter: qlora` with `gptq: True`, and `adapter: qlora` with `load_in_4bit: False` explicitly set. Reject both up front in `normalize_base_quant_inputs` with clear messages. The merge-on-quantized-base errors used `4-bit` / `8-bit` / `GPTQ` — hyphens broke the regex tests that match `.*4bit.*`, `.*8bit.*`, `.*gptq.*`. Restore the hyphenless phrasing. `warn_qlora_zero3_w_use_reentrant` gated on `adapter == "qlora"`, but the PEFT validator demotes that to `lora` before this mode=before validator runs. Broaden the gate to also match the canonical shape (`adapter: lora` + bnb 4-bit / `load_in_4bit`). `test_zero3_qlora_use_reentrant_false` indexed `records[0]`; the new DEPRECATED warning now occupies that slot. Search all records instead.
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
…minator ``type(W) is not torch.Tensor`` is True for ``torch.nn.Parameter`` — Parameter is a subclass of Tensor, not the same type. That made every unquantized PEFT base weight (a plain Parameter) take the torchao ``dequantize_weight()`` path, which upcast it to fp32 and broke ``matmul_lora`` when X was fp16 (e.g. the geglu Gemma test). Add ``is_quant_tensor_subclass`` and use it everywhere the kernels decide between bnb / torchao / unquantized.
Description
This PR adds support for torchao's dtype for LoRA training to provide alternative from bitsandbytes which isn't too friendly with FSDP2. Second, it also creates new paradigm for how to config LoRA / QLoRA via backends and weight dtype. Previous methods
load_in_4bitetc still exist for BC.This also provides an alternative to the LoRA kernels by being compile friendly. INT4 simplifies dequant and matmul in LoRA kernel or users can just use torch compile, without graph breaks (hopefully)
Motivation and Context
How has this been tested?
Still untested
AI Usage Disclaimer
Claude
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit
New Features
bnb,torchao,mxfp4,fp8).Documentation
Bug Fixes