Depend on standalone sageattention.nvfp4 fork + NVFP4 installer #43
thad0ctor wants to merge 3 commits
…ology The NVFP4 feature set diverges from a vanilla Axolotl install by four extras: cu130 torch (Triton tl.dot_scaled), transformers>=4.57, mslk (pytorch/MSLK wheel index), and the SageAttention-NVFP4 fork (git clone + editable, no PyPI package). scripts/install_nvfp4.sh wires all four (uv-default, --create-venv/--tool pip options) and validates the toolchain.
📝 Walkthrough
This PR migrates NVFP4 flash-attention from a native Axolotl implementation to an external SageAttention-NVFP4 Triton fork. The changes add a new optional dependency, an automated installation script with validation, kernel code re-exports, and user documentation on setup and configuration.
Changes
NVFP4 Kernel and Dependency Migration
🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
📖 Documentation Preview: Deployed on Netlify from commit 00b2d79
@coderabbitai review
✅ Action performed
Review finished.
🧹 Nitpick comments (2)
src/axolotl/kernels/attn_nvfp4_flash.py (1)
1-7: 💤 Low value
Simplify module docstring per coding guidelines.
As per coding guidelines, comments should be kept to one short line maximum and should not reference callers. The current docstring is multi-line and explicitly lists callers ("custom_op, monkeypatch, tests, scripts").
Suggested simplification
-"""Re-export shim for native-NVFP4 flash attention.
-
-The kernel implementation now lives in the standalone ``sageattention.nvfp4``
-fork (a SageAttention fork). This module preserves the historical
-``axolotl.kernels.attn_nvfp4_flash`` import path so every existing importer
-(custom_op, monkeypatch, tests, scripts) keeps working unchanged.
-"""
+"""Re-export shim: kernel implementation moved to sageattention.nvfp4 fork, preserving historical import path."""
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/kernels/attn_nvfp4_flash.py` around lines 1 - 7, The module docstring is too verbose and references callers; replace the multi-line docstring at the top of src/axolotl/kernels/attn_nvfp4_flash.py with a single short line describing the module purpose (e.g., "Re-export shim for native NVFP4 flash attention.") and remove any mention of callers or implementation details; update the top-level triple-quoted string in the attn_nvfp4_flash module accordingly.
Source: Coding guidelines
scripts/install_nvfp4.sh (1)
65-65: 💤 Low value
Consider dynamic help extraction.
The hardcoded line range `'2,40p'` is fragile if the usage comment block grows or shifts. A more maintainable pattern would extract until a marker (e.g., `sed -n '2,/^set -/p'`, which prints through the first line starting with `set -`), though the current buffer is adequate.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@scripts/install_nvfp4.sh` at line 65, The help block extraction is brittle: update the -h|--help) case to dynamically print the script's leading comment block instead of using the fixed sed range '2,40p'; change the sed invocation in the '-h|--help)' branch to print from the top comment through a marker (e.g., use sed -n '1,/^set -/p' or similar) so it stops at the first non-comment/marker line, ensuring the usage stays accurate as the comment block grows or moves.
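The marker-based idea can be sketched on a throwaway script. This is a minimal illustration, not the code in `scripts/install_nvfp4.sh`: the file `demo`, its header text, and the `print_help` function are hypothetical, and the sketch assumes the usage text is a block of `#` comments directly after the shebang. Because a range such as `2,/^set -/p` also prints the `set -` line itself, this variant quits at the first non-comment line instead.

```shell
#!/usr/bin/env bash
# Sketch only: "demo" is a hypothetical stand-in for the real installer.
set -euo pipefail

demo="$(mktemp)"
trap 'rm -f "$demo"' EXIT

cat > "$demo" <<'EOF'
#!/usr/bin/env bash
# Usage: demo.sh [options]
#
#   -h, --help   show this help
set -euo pipefail
echo "installing"
EOF

# Print the leading comment block of a script:
#   1d            skip the shebang line
#   /^#/!q        quit at the first line that is not a comment
#   s/^# \{0,1\}//p   strip the "# " prefix and print what remains
print_help() {
  sed -n -e '1d' -e '/^#/!q' -e 's/^# \{0,1\}//p' "$1"
}

print_help "$demo"
```

Inside the installer, the same function could be called with `"$0"` from the `-h|--help)` branch, so the help output follows the comment block however long it grows.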
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: b916f8c3-ce72-497e-8029-f5d1e6dba3a6
📒 Files selected for processing (5)
- docs/nvfp4_training.qmd
- pyproject.toml
- scripts/install_nvfp4.sh
- src/axolotl/kernels/attn_nvfp4_flash.py
- src/axolotl/kernels/nvfp4_fused_producers.py
Depend on the standalone `sageattention.nvfp4` fork for NVFP4 flash attention
Extracts the in-tree native-NVFP4 flash-attention kernel into a standalone SageAttention fork (`sageattention.nvfp4`) and depends on it, plus an installer for the NVFP4 feature set. Stacks on top of #42 (base: `feat/nvfp4-attn-perf-and-config`).
Changes
- `src/axolotl/kernels/attn_nvfp4_flash.py` → thin re-export shim from `sageattention.nvfp4` (every importer, including custom_op, monkeypatch, tests, and scripts, keeps working unchanged).
- `nvfp4_fused_producers.py`: FP4-pack import repointed off `mslk` to the fork. The attention path no longer imports `mslk` (the linear/MLP FP4 path still does, by design).
- `pyproject.toml`: new `nvfp4-attn` extra; `docs/nvfp4_training.qmd`: Installation section.
- `scripts/install_nvfp4.sh`: wires the four NVFP4-specific extras a vanilla install lacks: cu130 torch (Triton `tl.dot_scaled`), `transformers>=4.57`, `mslk` (pytorch/MSLK wheel index), and the SageAttention-NVFP4 fork (git clone + editable, no PyPI package). uv-default, with `--create-venv` / `--tool pip` options.
Validation (RTX PRO 6000, sm_120)
(`qwen35-9b-lora-fastest.yaml`, `torch_compile`) via `scripts/bench_nvfp4.sh`: 1.107 s/step, 60.64 GiB, loss 1.027, which matches the documented NVFP4 numbers (1.106 / 60.6).
so ~1.27× faster, −5.1 GiB, and converges where bf16 does not.
Summary by CodeRabbit
Documentation
New Features
Refactor