Skip to content

feat: add torchao int4, nf4, int8 LoRA QLoRA support#3417

Open
NanoCode012 wants to merge 12 commits into
mainfrom
feat/torchao-qlora
Open

feat: add torchao int4, nf4, int8 LoRA QLoRA support#3417
NanoCode012 wants to merge 12 commits into
mainfrom
feat/torchao-qlora

Conversation

@NanoCode012

@NanoCode012 NanoCode012 commented Feb 16, 2026

Copy link
Copy Markdown
Collaborator

Description

This PR adds support for torchao's dtype for LoRA training to provide alternative from bitsandbytes which isn't too friendly with FSDP2. Second, it also creates new paradigm for how to config LoRA / QLoRA via backends and weight dtype. Previous methods load_in_4bit etc still exist for BC.

This also provides an alternative to the LoRA kernels by being compile friendly. INT4 simplifies dequant and matmul in LoRA kernel or users can just use torch compile, without graph breaks (hopefully)

     # New torchao QLoRA
     adapter: lora
     peft:
       backend: torchao
       weight_dtype: int4 # or nf4 to be equivalent to bnb

     # New torchao LoRA
     adapter: lora
     peft:
       backend: torchao
       weight_dtype: int8

     # New bnb QLoRA
     adapter: lora
     peft:
       backend: bnb
       weight_dtype: nf4

     # New bnb 8-bit LoRA
     adapter: lora
     peft:
       backend: bnb
       weight_dtype: int8

Motivation and Context

  • Torch compile alternative to LoRA kernels
  • Refactor how to set LoRA and QLoRA

How has this been tested?

Still untested

AI Usage Disclaimer

Claude

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features

    • Added support for QLoRA using TorchAO as the base quantization backend, including NF4 weight dtype support.
    • Introduced structured quantization configuration system supporting multiple backends (bnb, torchao, mxfp4, fp8).
    • Improved FSDP2 compatibility with TorchAO-quantized models.
  • Documentation

    • Added comprehensive guide for configuring QLoRA with TorchAO.
    • Added example training configuration for QLoRA with TorchAO.
  • Bug Fixes

    • Improved weight dequantization handling across quantization backends.

Review Change Stack

@winglian winglian force-pushed the feat/torchao-qlora branch from 970b2a6 to 33c495d Compare May 29, 2026 14:36
@coderabbitai

coderabbitai Bot commented May 29, 2026

Copy link
Copy Markdown
Contributor

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 1c0876f4-59d1-4993-a984-a5987c3feeb5

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR introduces TorchAO as a structured quantization backend alternative to bitsandbytes for QLoRA training, unifies dequantization logic across both backends, migrates from adapter: qlora to load_in_4bit-based detection, and gates bitsandbytes-specific FSDP2 patches when torchao is active.

Changes

TorchAO Quantization Backend Integration

Layer / File(s) Summary
Structured Quantization Config Schema
src/axolotl/utils/schemas/model.py, src/axolotl/utils/schemas/peft.py, src/axolotl/utils/schemas/enums.py, src/axolotl/utils/schemas/quantization.py, src/axolotl/utils/schemas/validation.py
ModelQuantizationConfig enforces exactly one backend (bnb, torchao, mxfp4, fp8) via discriminator model_validator. PeftConfig.normalize_base_quant_inputs normalizes legacy adapter:qlora and load_in_*bit flags into canonical structured form with deprecation warnings. TorchAOQuantDType.nf4 enum member and validation support nf4 dtype. Config validators detect conflicting backends and incompatible combinations.
Unified Dequantization Kernel
src/axolotl/kernels/quantize.py, src/axolotl/kernels/lora.py, tests/e2e/kernels/test_quantize.py
dequantize_weight helper centralizes dequantization for torchao tensor subclasses (via get_original_weight or fallback dequantize) and bitsandbytes weights, with optional transpose. matmul_lora accepts transpose parameter. DoRA and LoRA backward paths updated to use unified function, removing explicit transpose logic around dequantize calls. Tests cover plain tensors, affine subclasses, and NF4 subclass dispatch.
Model Loader TorchAO Integration
src/axolotl/loaders/model.py
Added _torchao_subconfig and _normalize_torchao_config helpers. is_qlora_and_fsdp_enabled now detects adapter:lora + load_in_4bit under FSDP. New is_torchao_qlora property. _set_quantization_config maps torchao weight_dtype to TorchAoConfig, detects base-model re-quantization conflicts, normalizes mxfp4/fp8 structured forms. Device-map override skipped for torchao+FSDP+QLoRA. prepare_model_for_kbit_training skipped for torchao paths.
LoRA Adapter Loading with TorchAO
src/axolotl/loaders/adapter.py, src/axolotl/monkeypatch/peft/utils.py
load_lora detects torchao usage, skips bitsandbytes setup_quantized_meta_for_peft when active. Patches PEFT TorchAO dispatch for non-int8 weights via patch_peft_torchao_dispatch to prevent INT4/NF4 failures. FSDP CPU-efficient training metadata gated to skip torchao paths.
FSDP2 Patch Gating
src/axolotl/loaders/patch_manager.py
_apply_fsdp2_bnb_patches detects torchao via _torchao_subconfig, skips bitsandbytes initialization patches when torchao is active, preserving patch sequence for bitsandbytes 4-bit/8-bit under FSDP2.
Configuration Normalization
src/axolotl/cli/merge_lora.py, src/axolotl/core/builders/causal.py
Updated is_qlora_and_fsdp_enabled detection. simulate_nf4 now keys off load_in_4bit only (removing adapter == "qlora" checks). training_arguments_kwargs["qlora"] set when load_in_4bit is true. Inline comments document upstream qlora-to-lora demotion behavior.
Comprehensive Test Coverage
tests/test_loaders.py, tests/utils/lora/test_config_validation_lora.py
Removed adapter="qlora" from loader parametrization. Added TestStructuredQuantizationConfig covering bnb/torchao shorthand, backend discrimination, legacy string compatibility, conflict validation, merge_lora constraints, and DoRA compatibility. Added torchao-specific tests for weight-dtype mapping, NVFP4/FP8 special cases, already-quantized checkpoint rejection, and mxfp4 error paths.
Documentation and Examples
docs/qlora_torchao.qmd, examples/llama-3/qlora-torchao.yaml, _quarto.yml
New comprehensive QLoRA+TorchAO guide covering structured discriminator, adapter auto-promotion, legacy deprecation, MXFP4 MoE limitations, mixed-quant handling, DoRA/merge constraints, PEFT LoRA limitations, FSDP2 native support. Working llama-3 example config with torchao int4 quantization. Navigation updated.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes


Possibly related PRs


Suggested labels

scheduled_release


Suggested reviewers

  • winglian
  • djsaunde
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 54.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main feature addition: torchao support for int4, nf4, and int8 quantization methods with LoRA/QLoRA training, which aligns with the comprehensive changeset introducing this capability.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/torchao-qlora

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

winglian added 7 commits May 29, 2026 16:47
Land the contributor PR after rebasing onto main and closing the gaps
found during audit.

Behavior changes
- peft.backend: torchao + weight_dtype: int8 now stays as adapter=lora
  (matching bnb int8 semantics) instead of being auto-promoted to qlora.
- Unsupported torchao weight_dtypes (fp8, nvfp4, mxfp4) are rejected at
  validation with a clear pointer to the QAT/PTQ flow.
- Merging a torchao adapter requires merge_method=legacy; the
  memory-efficient merger simulates bnb NF4 and would silently
  mis-merge torchao tensor subclasses.
- DoRA paths in kernels/lora.py route through dequantize_weight so
  DoRA + torchao works end-to-end (the previous bare dequantize calls
  would have failed on AffineQuantizedTensor / NF4Tensor).

Bug fixes uncovered while landing
- model.py: switch from the deprecated string quant_type API
  (TorchAoConfig(quant_type="int4_weight_only")) to the object-based
  Int4WeightOnlyConfig / Int8WeightOnlyConfig API required by modern
  transformers.
- model.py: import NF4WeightOnlyConfig from torchao.prototype._nf4tensor_api
  (with a fallback to the old torchao.dtypes path) — the original
  location no longer exists in torchao >= 0.13.
- model.py: NF4WeightOnlyConfig now takes no constructor arguments;
  set block_size / scaler_block_size as attributes.

Coverage
- ModelLoader.is_torchao_qlora now matches both adapter=lora and
  adapter=qlora to keep the bnb-skipping branches consistent for the
  int8 case.
- model.py's _set_quantization_config branch now triggers for adapter
  in (lora, qlora) so int8 torchao gets its TorchAoConfig.

Docs + examples
- docs/qlora_torchao.qmd: new page covering backends, weight_dtype
  table, constraints, FSDP2.
- examples/llama-3/qlora-torchao.yaml: minimal config using the new
  peft block.

Tests
- tests/utils/lora/test_config_validation_lora.py: torchao+int8 stays
  lora; fp8/nvfp4/mxfp4 rejected; merge_lora requires legacy;
  DoRA + torchao allowed.
- tests/test_loaders.py: TorchAoConfig is wired with
  Int4WeightOnlyConfig / Int8WeightOnlyConfig / NF4WeightOnlyConfig.
- tests/e2e/kernels/test_quantize.py: dequantize_weight against fake
  AffineQuantizedTensor / NF4Tensor subclasses (no CUDA needed).

Validated locally with a CUDA smoke test on SmolLM2-135M: torchao
int8 LoRA loads with AffineQuantizedTensor base weights, forward +
backward produce gradients on all 420 trainable params.
Earlier pass rejected fp8/nvfp4/mxfp4 at the schema layer, telling
users to use QAT/PTQ instead. That was wrong:

- NVFP4 has a real weight-only torchao config (NVFP4WeightOnlyConfig
  in torchao.prototype.mx_formats) — it's a 4-bit quant, perfectly
  suited to QLoRA. Now auto-promotes adapter lora -> qlora and
  builds NVFP4WeightOnlyConfig at load.
- FP8 (float8_e4m3fn) has Float8WeightOnlyConfig in torchao.quantization
  — a one-byte-per-weight quant that mirrors INT8's role. Keeps
  adapter as lora.
- MXFP4 is the genuine 'no weight-only flavor' case. The schema now
  passes it through; the loader raises with a pointer to
  quantize_moe_experts: true for MoE models (which is where MXFP4
  LoRA actually lives, via the ScatterMoE-LoRA path landed in #3663)
  and to qat/ptq for inference-time MXFP4.

CUDA smoke-tested on SmolLM2-135M:
- weight_dtype: fp8  -> Float8WeightOnlyConfig, forward+backward OK
- weight_dtype: nvfp4 (group_size=16) -> NVFP4WeightOnlyConfig, OK
- weight_dtype: mxfp4 -> loader error pointing to quantize_moe_experts

Docs and the dtype table updated; schema/loader tests extended.
…overrides

peft.backend: torchao installs a single TorchAoConfig covering every
linear layer. It composes badly with axolotl's other quant mechanisms,
but the prior code silently picked a winner:

- model_quantization_config (Mxfp4Config / FineGrainedFP8Config) gets
  overwritten by our TorchAoConfig later in the same function.
- A checkpoint with embedded quantization_config (gpt-oss MXFP4,
  pre-quantized AWQ / GPTQ / BNB) wins via the earlier if-branch in
  _set_quantization_config; peft.backend is silently ignored.
- quantize_moe_experts: true would race with TorchAoConfig over the
  same expert tensors.
- gptq: true is a separate path entirely.

Now:

- SystemValidationMixin.check_torchao_backend_exclusivity rejects
  peft.backend: torchao + (model_quantization_config |
  quantize_moe_experts | gptq) at validation with a pointer to
  docs/qlora_torchao.qmd.
- ModelLoader._set_quantization_config raises when the base model's
  checkpoint already advertises a quant_method and peft.backend is
  also set (the conflict only resolvable post-load_model_config).

Documents the boundary: mixed-quant flows (experts MXFP4 + attention
bf16, gpt-oss-style) drop peft.backend and use the per-mechanism
config (quantize_moe_experts or the checkpoint's quant_method)
directly. peft.backend is for uniform base-quant only.
…he awq/gptq/bnb branch

The prior loader-time check sat inside the peft.backend elif branch,
so checkpoints with quant_method in (awq, gptq, bitsandbytes) hit the
earlier if-branch first and silently overwrote model_kwargs's
quantization_config — peft.backend got dropped on the floor.

Move the check to the top of _set_quantization_config so it fires
for any non-empty model_config.quantization_config, including the
realistic ones that motivated this audit:

- gpt-oss native MXFP4 (quant_method: mxfp4)
- AMD Quark MXFP4 with a per-module exclude list, e.g.
  amd/Kimi-K2.6-MXFP4: experts in MXFP4, ~305 modules excluded
  (lm_head, every attention projection, vision tower, mm_projector)
- AWQ / GPTQ / bitsandbytes pre-quantized checkpoints

Tests parametrize across all five quant_methods. Docs gain a section
naming AMD Quark MXFP4 as the canonical mixed-quant example and
restate the recommendation: drop peft.backend so the checkpoint's
own quantization_config flows through unchanged.
…ation_config

Replace the peft.backend / peft.weight_dtype / peft.group_size shape
with a structured discriminator on the existing
model_quantization_config field. One namespace for all base-model
quant; peft.backend drops out entirely.

User-facing surface:

  # bnb 4-bit QLoRA (replaces adapter: qlora + load_in_4bit: true)
  adapter: lora
  model_quantization_config:
    bnb:
      weight_dtype: nf4

  # torchao QLoRA
  adapter: lora
  model_quantization_config:
    torchao:
      weight_dtype: int4
      # group_size: 128

  # Legacy string form (Mxfp4Config / FineGrainedFP8Config) keeps working
  # via the same field. Equivalent structured form:
  model_quantization_config:
    mxfp4:
      config_kwargs: {}

Schema:
  ModelQuantizationConfig(BaseModel) is a discriminated union with
  exactly one of bnb / torchao / mxfp4 / fp8 set. The top-level field
  accepts Literal["Mxfp4Config", "FineGrainedFP8Config"] |
  ModelQuantizationConfig | None.

Auto-promotion:
  Moved out of the peft block into LoraConfig.auto_detect_qlora, which
  reads the structured form. bnb.nf4 sets load_in_4bit and promotes
  lora -> qlora; bnb.int8 sets load_in_8bit; torchao 4-bit dtypes
  (int4/nf4/nvfp4) promote lora -> qlora; torchao int8/fp8 stay as
  weight-only LoRA.

Conflict surfaces (validation.py + model.py) updated to gate on
model_quantization_config.torchao instead of peft.backend:
  - + quantize_moe_experts: true   -> rejected at validation
  - + gptq: true                   -> rejected at validation
  - + load_in_4bit / load_in_8bit  -> rejected at validation
  - + checkpoint with embedded quant_method -> rejected at load time
    (covers Quark, mxfp4, awq, gptq, bnb — the AMD Kimi-K2.6-MXFP4 case
    Wing called out).

Internals:
  axolotl.utils.config.validate_config returns nested fields as dicts
  via model_dump. Two helpers in loaders/model.py (_mqc_branch and
  _torchao_subconfig) accept either form so direct-Pydantic test
  construction and post-validate dict access both work.

Docs (docs/qlora_torchao.qmd) and the example
(examples/llama-3/qlora-torchao.yaml) rewritten around the new shape.
Schema tests (24) and loader tests (43) rewritten to exercise it.

CUDA-validated on SmolLM2-135M for torchao int8 and fp8 paths: model
loads with the structured config, LoRA injects, forward+backward
produces gradients on all 420 trainable params.
Per maintainer feedback: "qlora is just lora with an nf4 base weight" —
don't carry it as a distinct adapter name. Demote `adapter: qlora` to
`adapter: lora` in the validator and key all internal "is this
QLoRA?" decisions off the actual base-weight quant state instead of
the adapter name.

User surface
- The recommended shape is now uniformly `adapter: lora` plus one of:
    `model_quantization_config: {bnb: {weight_dtype: nf4}}` (terse)
    `load_in_4bit: true`                                     (legacy bnb)
    `model_quantization_config: {torchao: {weight_dtype: int4}}`
- Legacy `adapter: qlora` configs keep working unchanged: a new
  `normalize_adapter_qlora` validator demotes them and, if no
  base-quant choice was spelled out, auto-sets `load_in_4bit: true`
  (the legacy shorthand's implicit meaning). Emits a DEPRECATED log
  with the migration path.
- `adapter: qlora` + `load_in_8bit: true` is now rejected as
  ambiguous (QLoRA is a 4-bit thing).

Codebase
- `is_qlora_and_fsdp_enabled` now keys off `load_in_4bit`, not the
  adapter name.
- The bnb 4-bit branch in `_set_quantization_config` fires on
  `adapter == lora and load_in_4bit` (the validator guarantees
  qlora is always normalized to lora upstream).
- `AxolotlTrainingArguments.qlora` (dead read; nothing in axolotl
  consumes it) is set from `load_in_4bit` instead of the adapter name.
- `validate_qlora` (mode=after) no longer gates the merge-bans on
  `adapter == qlora`; merge into an 8-bit/4-bit/GPTQ base is rejected
  regardless of how the user spelled the quant.
- `merge_lora` CLI's `simulate_nf4` drops the dead
  `_original_adapter == "qlora"` check; `_original_load_in_4bit`
  covers all bnb-4-bit cases (and the validator sets it for
  legacy-qlora configs upstream).

Deprecation warnings (per request)
- `adapter: qlora` → DEPRECATED log pointing at the new shape.
- `model_quantization_config: "Mxfp4Config" | "FineGrainedFP8Config"`
  (string form) → DEPRECATED log naming the equivalent structured
  form (with `config_kwargs` carried over).
- Both fire at config-load time, so they show up on every `axolotl
  train` / `axolotl merge-lora` / `axolotl preprocess` invocation
  for the affected configs.

Tests
- 69 passing. Legacy `adapter: qlora` test cases updated to assert
  the demoted shape (`adapter == "lora"`, `load_in_4bit == True`).
- The `adapter: qlora` parametrize axis in
  `test_set_quantization_config` removed; legacy paths exercised
  via the validator's normalization.
…ig first, load_in_4bit/8bit deprecated

Previous commit kept load_in_4bit / load_in_8bit as the user-facing
knobs and set them internally to drive the bnb loader branch. Per
Wing: "i thought we were trying to get rid of load_in_4bit: true?"
— right. The structured form is the source of truth; the legacy
flags are deprecated user inputs.

One `normalize_base_quant_inputs` validator now does both halves
in lockstep (separate validators ran out of order in Pydantic v2,
which is what was leaving load_in_4bit unset for bare `adapter:
qlora` configs):

1. Translate every legacy spelling into the canonical structured
   form:
     - `adapter: qlora`               → adapter: lora + bnb nf4
     - `adapter: qlora` + load_in_4bit → adapter: lora + bnb nf4
     - `load_in_4bit: true` (alone)   → bnb nf4
     - `load_in_8bit: true` (alone)   → bnb int8
   Emits a DEPRECATED warning on each path.
2. Mirror the structured form back into load_in_4bit/8bit so the
   downstream loader code that still reads them sees a consistent
   state.

`load_in_4bit` / `load_in_8bit` field descriptions now begin with
`DEPRECATED:` so the JSON schema and autogen docs flag them.

Docs (docs/qlora_torchao.qmd) and the headline example
(examples/llama-3/qlora-torchao.yaml) reworked so the canonical form
is the only shape shown; legacy forms appear only in the Deprecations
section with their migration target.

Tests: 59 passing — including bare `adapter: qlora` (no load_in_4bit
written) which validates through to load_in_4bit=True via the
combined validator. CUDA smoke (torchao int8 on SmolLM2-135M) still
loads, forward+backward, grads on 420 params.
@winglian winglian force-pushed the feat/torchao-qlora branch from 1dc14fa to 4b23b3e Compare May 29, 2026 16:48
@winglian

Copy link
Copy Markdown
Collaborator

State of this branch

Rebased onto current main and reworked end-to-end. Reads as one
cohesive change rather than seven incremental tweaks now that the
shape has settled.

Canonical config shape

adapter: lora
model_quantization_config:
  bnb:                          # or torchao / mxfp4 / fp8
    weight_dtype: nf4

model_quantization_config is the only base-quant knob users should
write. QLoRA is just adapter: lora with a 4-bit base; there is no
separate adapter type.

Backend weight_dtype Quant config installed
bnb nf4 BitsAndBytesConfig(load_in_4bit=True)
bnb int8 BitsAndBytesConfig(load_in_8bit=True)
torchao int4 Int4WeightOnlyConfig
torchao nf4 NF4WeightOnlyConfig
torchao nvfp4 NVFP4WeightOnlyConfig (group_size=16)
torchao int8 Int8WeightOnlyConfig
torchao fp8 Float8WeightOnlyConfig
mxfp4 transformers.Mxfp4Config(**config_kwargs)
fp8 transformers.FineGrainedFP8Config(**config_kwargs)

torchao.weight_dtype: mxfp4 is rejected with a pointer to
quantize_moe_experts: true for MoE models — MXFP4 has no
weight-only torchao config for arbitrary linears.

Deprecations (kept working, warn at config-load)

Legacy input Translated to
adapter: qlora (with or without load_in_4bit: true) adapter: lora + {bnb: {weight_dtype: nf4}}
load_in_4bit: true (alone) {bnb: {weight_dtype: nf4}}
load_in_8bit: true (alone) {bnb: {weight_dtype: int8}}
model_quantization_config: "Mxfp4Config" + model_quantization_config_kwargs {mxfp4: {config_kwargs: …}}
model_quantization_config: "FineGrainedFP8Config" (string) {fp8: {config_kwargs: …}}
adapter: qlora + load_in_8bit: true hard error (ambiguous)

A single normalize_base_quant_inputs validator does both the
translate and the mirror-back-into-load_in_4bit/8bit in lockstep
so downstream loader code that still keys off the flags sees a
consistent state. The flag fields' description: strings begin
with DEPRECATED: so the JSON schema reflects it.

Conflict surfaces

model_quantization_config.torchao is a uniform-base-quant
shorthand and does not compose with other quant mechanisms.
Conflicts are caught explicitly rather than producing silent
overrides:

Combination Caught at
…torchao + quantize_moe_experts: true validation
…torchao + gptq: true validation
…torchao + load_in_4bit/load_in_8bit validation
…torchao + checkpoint with embedded quant_method load time
Discriminator with zero or multiple of bnb/torchao/mxfp4/fp8 set validation

The load-time check covers every quant_method, including modern
exclusion-list shapes (e.g. amd/Kimi-K2.6-MXFP4 uses
quant_method: quark with ~305 excluded modules — attention,
lm_head, vision tower, mm_projector all stay in their native
dtype). For mixed-quant flows the recommended path is to drop
model_quantization_config entirely and let the checkpoint's own
quantization_config flow through transformers unchanged, or use
quantize_moe_experts: true for bf16 MoE checkpoints that want
MXFP4 experts at load time (the path landed in #3663).

Kernels

  • dequantize_weight wrapper in kernels/quantize.py covers both
    bnb QuantState and torchao tensor subclasses
    (AffineQuantizedTensor.dequantize(),
    NF4Tensor.get_original_weight()).
  • matmul_lora gained a transpose= parameter so kernel paths can
    dispatch to either bnb (transpose-before-dequant) or torchao
    (transpose-after-dequant) layouts.
  • Every dequantize() callsite in LoRA_MLP/LoRA_QKV/LoRA_O
    backward paths and the DoRA scale path now routes through
    dequantize_weight, so DoRA + torchao works.
  • PEFT's TorchaoLoraLinear only supports INT8; the dispatch_torchao
    patch makes INT4/NF4/NVFP4 fall back to standard Linear LoRA so
    the kernels dequantize the base weight on access.

Other rough edges closed during the audit

  • TorchAoConfig(quant_type="int4_weight_only", …) (string API)
    was rejected by current transformers — switched to the object API
    with Int4WeightOnlyConfig / Int8WeightOnlyConfig /
    NF4WeightOnlyConfig.
  • NF4WeightOnlyConfig moved from torchao.dtypes._nf4tensor_api
    to torchao.prototype._nf4tensor_api around torchao 0.13 and
    takes no constructor args — import with a fallback, set fields
    as attributes.
  • FSDP2 bnb shims are skipped automatically when the torchao
    backend is active.
  • axolotl merge-lora requires merge_method: legacy with a
    torchao adapter; the memory-efficient merger simulates bnb NF4
    and would silently mis-merge torchao tensor subclasses.

Verified

  • pre-commit clean (ruff, mypy, bandit) on changed files.
  • pytest tests/utils/lora/test_config_validation_lora.py tests/test_loaders.py tests/e2e/kernels/test_quantize.py77 passing, covering the structured discriminator, every deprecation translation, the conflict-surface matrix above, the checkpoint-quant-method catch (across mxfp4, quark with exclusion list, awq, gptq, bitsandbytes), and dequantize_weight against fake torchao tensor subclasses.
  • CUDA smoke on SmolLM2-135M for the torchao paths that don't depend on external prerequisites: int8 + fp8 + nvfp4 all load with the right *WeightOnlyConfig, PEFT injects LoRA, forward + backward produces gradients on all 420 trainable params. int4 needs the mslk ≥ 1.0.0 package upstream of axolotl; nf4 needs model dims satisfying torchao's n_scalers % scaler_block_size == 0 (135M's k_proj is too small).
  • Added: docs/qlora_torchao.qmd (wired into the Advanced Features sidebar), examples/llama-3/qlora-torchao.yaml, 11 new tests covering the structured shape and deprecations.

Known follow-up (not in this PR)

Migrate the loader to read cfg.model_quantization_config
directly and drop load_in_4bit / load_in_8bit as fields
entirely. Touches roughly 10+ sites that currently read
cfg.load_in_4bit; cleaner as a focused removal once this lands.

@winglian winglian marked this pull request as ready for review May 29, 2026 23:53
@github-actions

github-actions Bot commented May 30, 2026

Copy link
Copy Markdown
Contributor

📖 Documentation Preview: https://6a1bb5ce49528c402bf6562e--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit f9f280b

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🧹 Nitpick comments (2)
tests/test_loaders.py (1)

110-143: ⚡ Quick win

Use adapter="lora" in the torchao loader tests.

These cases still build loader state with adapter="qlora", but the schema normalizer is supposed to demote that upstream. Keeping the tests on qlora means they don't cover the runtime contract that the torchao helpers actually key on.

Also applies to: 158-168, 229-237, 277-287

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_loaders.py` around lines 110 - 143, The tests (e.g.,
test_set_quantization_config_torchao_qlora) are using adapter="qlora" but the
torchao loader helpers expect the normalized upstream value "lora"; update the
parametrized adapter values and any hardcoded self.cfg.adapter assignments in
the torchao-related tests (including the other occurrences flagged at the other
ranges) from "qlora" to "lora" so the tests exercise the runtime contract the
loader keys on (adjust the parametrization tuple entries and the
self.cfg.adapter assignments in those test methods).
src/axolotl/utils/schemas/validation.py (1)

1232-1235: ⚡ Quick win

Trim these new comments to one-line WHY notes.

These blocks are mostly describing WHAT the validators do and exceed the repo’s comment style for src/axolotl/**.

As per coding guidelines, "Only add comments when explaining the WHY behind non-obvious logic, hidden constraints, or workarounds for specific bugs. Do not comment on WHAT code does ... Comments should be a maximum of one short line."

Also applies to: 1248-1250, 1274-1282

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/utils/schemas/validation.py` around lines 1232 - 1235, Trim the
multi-line explanatory comments around the structured "bnb" form and legacy
checks to a single short WHY note: explain why the structured bnb case is exempt
from the legacy Mxfp4Config/FineGrainedFP8Config string-form check (i.e.,
because auto_detect_qlora sets load_in_4bit/load_in_8bit), and replace similar
WHAT-style blocks at the subsequent comment sites (the blocks that mention
auto_detect_qlora, Mxfp4Config, FineGrainedFP8Config, load_in_4bit/load_in_8bit
and the mxfp4/fp8 branches) with one-line WHY comments so the validation logic
remains documented but conforms to the repo style.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/axolotl/monkeypatch/peft/utils.py`:
- Around line 97-107: Save the original
peft.tuners.lora.torchao.dispatch_torchao before replacing it (e.g., store it on
peft_torchao._axolotl_orig_dispatch) and change patch_peft_torchao_dispatch to
install a patched_dispatch that only returns None for non-INT8 adapters but
delegates to the saved original for INT8/torchao dequantization cases (inspect
lora_config or adapter_name to detect INT8); keep the _axolotl_patched flag but
do not permanently drop the original dispatcher, or alternatively restore
peft_torchao.dispatch_torchao back to the saved original after the adapter load
completes, ensuring no leak across later adapter loads and preserving the
original behavior for INT8 TorchaoLoraLinear.

In `@src/axolotl/utils/schemas/model.py`:
- Around line 38-46: The schema currently lists "mxfp4" as an allowed value for
the weight_dtype Field but the loader rejects it later; update the Pydantic
model in model.py to enforce this at validation by removing "mxfp4" from the
Literal type for weight_dtype (i.e., change
Literal["int4","nf4","nvfp4","int8","fp8","mxfp4"] to exclude "mxfp4") or
alternatively add a Pydantic validator on weight_dtype that raises a ValueError
when the value == "mxfp4"; reference the weight_dtype Field in model.py to
implement the change so invalid YAML is rejected at schema validation time.

In `@src/axolotl/utils/schemas/peft.py`:
- Around line 249-259: The mirror step in src/axolotl/utils/schemas/peft.py
currently uses data.setdefault(...) so existing conflicting legacy flags can
remain; update the block that reads mqc = data.get("model_quantization_config")
and inspects bnb/weight_dtype to explicitly set data["load_in_4bit"] =
True/False and data["load_in_8bit"] = True/False according to weight_dtype
(e.g., if weight_dtype == "nf4" set load_in_4bit True and load_in_8bit False; if
"int8" set load_in_8bit True and load_in_4bit False; otherwise ensure both are
False or unset), so the canonical bnb config always wins over legacy flags used
by downstream loaders.

In `@src/axolotl/utils/schemas/quantization.py`:
- Around line 19-20: The validator validate_ao_dtype currently maps "nf4" →
TorchAOQuantDType.nf4 and is reused for both activation_dtype and weight_dtype,
so activation_dtype: "nf4" incorrectly passes; fix this by splitting the logic
into two validators (e.g., validate_weight_ao_dtype and
validate_activation_ao_dtype) or by adding a field-specific check: keep "nf4"
allowed for weight_dtype but explicitly reject it for activation_dtype by
raising a ValueError when the input is "nf4"; update the QATConfig/PTQConfig
model to use the new activation-specific validator for activation_dtype and the
weight-specific validator for weight_dtype instead of the single shared
validate_ao_dtype.

---

Nitpick comments:
In `@src/axolotl/utils/schemas/validation.py`:
- Around line 1232-1235: Trim the multi-line explanatory comments around the
structured "bnb" form and legacy checks to a single short WHY note: explain why
the structured bnb case is exempt from the legacy
Mxfp4Config/FineGrainedFP8Config string-form check (i.e., because
auto_detect_qlora sets load_in_4bit/load_in_8bit), and replace similar
WHAT-style blocks at the subsequent comment sites (the blocks that mention
auto_detect_qlora, Mxfp4Config, FineGrainedFP8Config, load_in_4bit/load_in_8bit
and the mxfp4/fp8 branches) with one-line WHY comments so the validation logic
remains documented but conforms to the repo style.

In `@tests/test_loaders.py`:
- Around line 110-143: The tests (e.g.,
test_set_quantization_config_torchao_qlora) are using adapter="qlora" but the
torchao loader helpers expect the normalized upstream value "lora"; update the
parametrized adapter values and any hardcoded self.cfg.adapter assignments in
the torchao-related tests (including the other occurrences flagged at the other
ranges) from "qlora" to "lora" so the tests exercise the runtime contract the
loader keys on (adjust the parametrization tuple entries and the
self.cfg.adapter assignments in those test methods).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 875be089-d749-4775-a411-98d1b242264f

📥 Commits

Reviewing files that changed from the base of the PR and between bf19bff and 4b23b3e.

📒 Files selected for processing (19)
  • _quarto.yml
  • docs/qlora_torchao.qmd
  • examples/llama-3/qlora-torchao.yaml
  • src/axolotl/cli/merge_lora.py
  • src/axolotl/core/builders/causal.py
  • src/axolotl/kernels/lora.py
  • src/axolotl/kernels/quantize.py
  • src/axolotl/loaders/adapter.py
  • src/axolotl/loaders/model.py
  • src/axolotl/loaders/patch_manager.py
  • src/axolotl/monkeypatch/peft/utils.py
  • src/axolotl/utils/schemas/enums.py
  • src/axolotl/utils/schemas/model.py
  • src/axolotl/utils/schemas/peft.py
  • src/axolotl/utils/schemas/quantization.py
  • src/axolotl/utils/schemas/validation.py
  • tests/e2e/kernels/test_quantize.py
  • tests/test_loaders.py
  • tests/utils/lora/test_config_validation_lora.py

Comment on lines +97 to +107
if getattr(peft_torchao, "_axolotl_patched", False):
return

def patched_dispatch(target, adapter_name, lora_config, **kwargs):
# Return None so PEFT falls back to standard Linear LoRA layers.
# Our LoRA kernels handle torchao dequantization explicitly.
return None

peft_torchao.dispatch_torchao = patched_dispatch
peft_torchao._axolotl_patched = True
LOG.info("Patched PEFT dispatch_torchao to skip TorchaoLoraLinear")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Restore/scope the PEFT dispatch_torchao monkeypatch to prevent leaking across later adapter loads.

patch_peft_torchao_dispatch() replaces peft.tuners.lora.torchao.dispatch_torchao with a stub that always returns None and never restores the original dispatcher; if a non-INT8 torchao adapter is loaded first, later INT8 torchao loads in the same worker will still observe the stub and skip PEFT’s TorchaoLoraLinear dispatch even though src/axolotl/loaders/adapter.py avoids calling the patch for INT8. Preserve the original dispatcher and restore it after the relevant load, or have the stub delegate back to the original for INT8.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/monkeypatch/peft/utils.py` around lines 97 - 107, Save the
original peft.tuners.lora.torchao.dispatch_torchao before replacing it (e.g.,
store it on peft_torchao._axolotl_orig_dispatch) and change
patch_peft_torchao_dispatch to install a patched_dispatch that only returns None
for non-INT8 adapters but delegates to the saved original for INT8/torchao
dequantization cases (inspect lora_config or adapter_name to detect INT8); keep
the _axolotl_patched flag but do not permanently drop the original dispatcher,
or alternatively restore peft_torchao.dispatch_torchao back to the saved
original after the adapter load completes, ensuring no leak across later adapter
loads and preserving the original behavior for INT8 TorchaoLoraLinear.

Comment on lines +38 to +46
weight_dtype: Literal["int4", "nf4", "nvfp4", "int8", "fp8", "mxfp4"] = Field(
json_schema_extra={
"description": (
"torchao base-weight dtype. int4/nf4/nvfp4 → QLoRA; int8/fp8 "
"→ weight-only LoRA; mxfp4 is unsupported as a base-quant "
"shorthand (use quantize_moe_experts for MoE MXFP4)."
)
}
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject torchao.mxfp4 at the schema layer.

"mxfp4" is advertised as a valid torchao.weight_dtype here, but the loader always rejects it later. That lets an invalid YAML pass validation and only fail during model load, which is the wrong layer for this check.

♻️ Proposed fix
-    weight_dtype: Literal["int4", "nf4", "nvfp4", "int8", "fp8", "mxfp4"] = Field(
+    weight_dtype: Literal["int4", "nf4", "nvfp4", "int8", "fp8"] = Field(
         json_schema_extra={
             "description": (
                 "torchao base-weight dtype. int4/nf4/nvfp4 → QLoRA; int8/fp8 "
-                "→ weight-only LoRA; mxfp4 is unsupported as a base-quant "
-                "shorthand (use quantize_moe_experts for MoE MXFP4)."
+                "→ weight-only LoRA."
             )
         }
     )

As per coding guidelines, "Config validation must use Pydantic schemas defined in src/axolotl/utils/schemas/."

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
weight_dtype: Literal["int4", "nf4", "nvfp4", "int8", "fp8", "mxfp4"] = Field(
json_schema_extra={
"description": (
"torchao base-weight dtype. int4/nf4/nvfp4 → QLoRA; int8/fp8 "
"→ weight-only LoRA; mxfp4 is unsupported as a base-quant "
"shorthand (use quantize_moe_experts for MoE MXFP4)."
)
}
)
weight_dtype: Literal["int4", "nf4", "nvfp4", "int8", "fp8"] = Field(
json_schema_extra={
"description": (
"torchao base-weight dtype. int4/nf4/nvfp4 → QLoRA; int8/fp8 "
"→ weight-only LoRA."
)
}
)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/utils/schemas/model.py` around lines 38 - 46, The schema
currently lists "mxfp4" as an allowed value for the weight_dtype Field but the
loader rejects it later; update the Pydantic model in model.py to enforce this
at validation by removing "mxfp4" from the Literal type for weight_dtype (i.e.,
change Literal["int4","nf4","nvfp4","int8","fp8","mxfp4"] to exclude "mxfp4") or
alternatively add a Pydantic validator on weight_dtype that raises a ValueError
when the value == "mxfp4"; reference the weight_dtype Field in model.py to
implement the change so invalid YAML is rejected at schema validation time.

Comment on lines +249 to +259
# Step 2: mirror the structured form back into load_in_4bit /
# load_in_8bit for downstream loader compat.
mqc = data.get("model_quantization_config")
if isinstance(mqc, dict):
bnb = mqc.get("bnb")
if isinstance(bnb, dict):
weight_dtype = bnb.get("weight_dtype")
if weight_dtype == "nf4":
data.setdefault("load_in_4bit", True)
elif weight_dtype == "int8":
data.setdefault("load_in_8bit", True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Force the mirrored load_in_*bit flags to match the canonical bnb config.

Using setdefault() here leaves conflicting legacy flags untouched. A config like model_quantization_config: {bnb: {weight_dtype: nf4}} plus load_in_4bit: false or load_in_8bit: true will pass validation, but the loader still keys off load_in_*bit and can take the wrong branch.

♻️ Proposed fix
         if isinstance(mqc, dict):
             bnb = mqc.get("bnb")
             if isinstance(bnb, dict):
                 weight_dtype = bnb.get("weight_dtype")
                 if weight_dtype == "nf4":
-                    data.setdefault("load_in_4bit", True)
+                    data["load_in_4bit"] = True
+                    data["load_in_8bit"] = False
                 elif weight_dtype == "int8":
-                    data.setdefault("load_in_8bit", True)
+                    data["load_in_8bit"] = True
+                    data["load_in_4bit"] = False

As per coding guidelines, "Config validation must use Pydantic schemas defined in src/axolotl/utils/schemas/."

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Step 2: mirror the structured form back into load_in_4bit /
# load_in_8bit for downstream loader compat.
mqc = data.get("model_quantization_config")
if isinstance(mqc, dict):
bnb = mqc.get("bnb")
if isinstance(bnb, dict):
weight_dtype = bnb.get("weight_dtype")
if weight_dtype == "nf4":
data.setdefault("load_in_4bit", True)
elif weight_dtype == "int8":
data.setdefault("load_in_8bit", True)
# Step 2: mirror the structured form back into load_in_4bit /
# load_in_8bit for downstream loader compat.
mqc = data.get("model_quantization_config")
if isinstance(mqc, dict):
bnb = mqc.get("bnb")
if isinstance(bnb, dict):
weight_dtype = bnb.get("weight_dtype")
if weight_dtype == "nf4":
data["load_in_4bit"] = True
data["load_in_8bit"] = False
elif weight_dtype == "int8":
data["load_in_8bit"] = True
data["load_in_4bit"] = False
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/utils/schemas/peft.py` around lines 249 - 259, The mirror step in
src/axolotl/utils/schemas/peft.py currently uses data.setdefault(...) so
existing conflicting legacy flags can remain; update the block that reads mqc =
data.get("model_quantization_config") and inspects bnb/weight_dtype to
explicitly set data["load_in_4bit"] = True/False and data["load_in_8bit"] =
True/False according to weight_dtype (e.g., if weight_dtype == "nf4" set
load_in_4bit True and load_in_8bit False; if "int8" set load_in_8bit True and
load_in_4bit False; otherwise ensure both are False or unset), so the canonical
bnb config always wins over legacy flags used by downstream loaders.

Comment on lines +19 to +20
if v == "nf4":
return TorchAOQuantDType.nf4

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🌐 Web query:

Does torchao 0.17.0 support NF4 as an activation quantization dtype, or is NF4 weight-only?

💡 Result:

In torchao 0.17.0, NF4 is exclusively a weight-only quantization data type [1][2][3]. It is not supported for activation quantization [4]. Torchao provides NF4 support primarily through the nf4_weight_only() configuration and NF4Tensor subclass, which are designed to reduce memory usage by storing weights in the 4-bit Normal Float (NF4) format [1][2]. During computation, these weights are typically dequantized to a higher precision (such as bfloat16) [5]. Conversely, torchao's activation quantization schemes—such as those using Float8, NVFP4, or MXFP4—are explicitly categorized as dynamic or static activation quantization configurations, none of which utilize NF4 for the activations themselves [4][6][7].

Citations:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect the target file around the mentioned lines and the validator definition.
FILE="src/axolotl/utils/schemas/quantization.py"
echo "== $FILE =="
nl -ba "$FILE" | sed -n '1,120p'

# Find where validate_ao_dtype is referenced within this schemas folder.
echo
echo "== References to validate_ao_dtype =="
rg -n "validate_ao_dtype" src/axolotl/utils/schemas | head -n 50

# Show the schema classes / fields that call into validate_ao_dtype (activation_dtype / weight_dtype).
echo
echo "== Where activation_dtype / weight_dtype are validated =="
rg -n "activation_dtype|weight_dtype" src/axolotl/utils/schemas | head -n 80

Repository: axolotl-ai-cloud/axolotl

Length of output: 177


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="src/axolotl/utils/schemas/quantization.py"

echo "== git ls-files =="
git ls-files "$FILE" || true

echo
echo "== $FILE (first 200 lines) =="
if [ -f "$FILE" ]; then
  sed -n '1,200p' "$FILE" | sed -n '1,200p'
else
  echo "File not found: $FILE"
  exit 0
fi

echo
echo "== validate_ao_dtype references =="
rg -n "validate_ao_dtype" src/axolotl/utils/schemas | head -n 50

echo
echo "== activation_dtype / weight_dtype validations =="
rg -n "(activation_dtype|weight_dtype)" src/axolotl/utils/schemas | head -n 80

Repository: axolotl-ai-cloud/axolotl

Length of output: 5228


Disallow nf4 for activation_dtype (schema currently permits it)
src/axolotl/utils/schemas/quantization.py maps "nf4"TorchAOQuantDType.nf4 in validate_ao_dtype(), and that same validator is attached to both activation_dtype and weight_dtype in QATConfig/PTQConfig, so activation_dtype: nf4 passes Pydantic validation. Since torchao 0.17.0 treats NF4 as weight-only (not an activation quantization dtype), reject "nf4" for activation_dtype (split validators or add field-specific constraints).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/utils/schemas/quantization.py` around lines 19 - 20, The
validator validate_ao_dtype currently maps "nf4" → TorchAOQuantDType.nf4 and is
reused for both activation_dtype and weight_dtype, so activation_dtype: "nf4"
incorrectly passes; fix this by splitting the logic into two validators (e.g.,
validate_weight_ao_dtype and validate_activation_ao_dtype) or by adding a
field-specific check: keep "nf4" allowed for weight_dtype but explicitly reject
it for activation_dtype by raising a ValueError when the input is "nf4"; update
the QATConfig/PTQConfig model to use the new activation-specific validator for
activation_dtype and the weight-specific validator for weight_dtype instead of
the single shared validate_ao_dtype.

…pt legacy tests

The qlora→lora normalization silently absorbed two combos that should
error: `adapter: qlora` with `gptq: True`, and `adapter: qlora` with
`load_in_4bit: False` explicitly set. Reject both up front in
`normalize_base_quant_inputs` with clear messages.

The merge-on-quantized-base errors used `4-bit` / `8-bit` / `GPTQ` —
hyphens broke the regex tests that match `.*4bit.*`, `.*8bit.*`,
`.*gptq.*`. Restore the hyphenless phrasing.

`warn_qlora_zero3_w_use_reentrant` gated on `adapter == "qlora"`, but
the PEFT validator demotes that to `lora` before this mode=before
validator runs. Broaden the gate to also match the canonical shape
(`adapter: lora` + bnb 4-bit / `load_in_4bit`).

`test_zero3_qlora_use_reentrant_false` indexed `records[0]`; the new
DEPRECATED warning now occupies that slot. Search all records instead.
@codecov

codecov Bot commented May 30, 2026

Copy link
Copy Markdown

…minator

``type(W) is not torch.Tensor`` is True for ``torch.nn.Parameter`` —
Parameter is a subclass of Tensor, not the same type. That made every
unquantized PEFT base weight (a plain Parameter) take the torchao
``dequantize_weight()`` path, which upcast it to fp32 and broke
``matmul_lora`` when X was fp16 (e.g. the geglu Gemma test).

Add ``is_quant_tensor_subclass`` and use it everywhere the kernels
decide between bnb / torchao / unquantized.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants