
Conversation

@LopezCastroRoberto (Contributor) commented Sep 8, 2025

Purpose

This pull request brings in the QuTLASS library: https://github.com/iST-DASLab/qutlass

QuTLASS is a high-performance library designed for low-precision kernel support in deep learning quantization, built on top of NVIDIA CUTLASS.

QuTLASS v0.1.0 introduces 4-bit microscaling routines tailored for Large Language Model (LLM) inference on NVIDIA Blackwell GPUs.

  • Online rotations:
    • Fused transform + quantization + scale computation.
      • Rotation matrices loaded at runtime, allowing any transformation to be applied.
    • Support for both NVFP4 and MXFP4 microscaling formats.
    • Multiple rotation sizes (16/32/64/128).
  • MXFP4 matmul kernel support powered by CUTLASS (a usage sketch follows this list).
    • QuTLASS is compatible with any matmul backend supporting microscaling formats (e.g., CUTLASS, FlashInfer).
  • Multiple quantization schemes:
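
Below is a minimal sketch of how the fused quantize + matmul path added here might be called from Python. The op names (`fusedQuantizeMx`, `matmul_mxf4_bf16_tn`) come from this PR's benchmarks and tests; the shapes, the identity stand-in for the rotation matrix, the `method` value, and the omission of the scale-layout rearrangement are illustrative assumptions, not the exact benchmark code.

```python
import torch
from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn

M, N, K, group = 16, 4096, 4096, 32  # assumed problem size and MXFP4 group size
a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")   # activations
b = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")   # weights (transposed-B matmul)
had = torch.eye(group, dtype=torch.bfloat16, device="cuda")  # identity stand-in for a Hadamard rotation

# Fused rotation + quantization + scale computation (E2M1 values, E8M0 scales)
a_e2m1, a_e8m0 = fusedQuantizeMx(a, had, method="quest")
b_e2m1, b_e8m0 = fusedQuantizeMx(b, had, method="quest")

# MXFP4 x MXFP4 -> BF16 matmul. The benchmarks additionally rearrange the
# scales with to_blocked(...) before this call; that step is omitted here.
alpha = torch.tensor([1.0], device="cuda")
out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_e8m0, b_e8m0, alpha)
```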

Microbenchmarking

  • benchmarks/kernels/bench_mxfp4_qutlass.py
  • benchmarks/kernels/bench_nvfp4_qutlass.py

(Plots: MXFP4:MXFP4 and NVFP4:NVFP4 microbenchmarks.) QuTLASS performance on a single Qwen3-32B layer with an NVIDIA RTX 5090 GPU.

(Plots: MXFP4:MXFP4 and NVFP4:NVFP4 microbenchmarks.) QuTLASS performance on a single Llama-3.1-70B layer with an NVIDIA B200 GPU.

[WIP] End-to-end

  • python benchmarks/benchmark_latency.py
    • daslab-testing/Llama-3.3-70B-Instruct-FPQuant-GPTQ-MXFP4-hadamard
    • meta-llama/Llama-3.3-70B-Instruct

(Plot: vLLM end-to-end results on an NVIDIA B200.)

FP16

| Quantization | MMLU-CoT | GSM8k | Hellaswag | Winogrande | Average | Recovery % |
|---|---|---|---|---|---|---|
| N/A | 0.866 | 0.951 | 0.862 | 0.849 | 0.882 | - |

MXFP4

| Quantization | MMLU-CoT | GSM8k | Hellaswag | Winogrande | Average | Recovery % |
|---|---|---|---|---|---|---|
| RTN | 0.834 | 0.927 | 0.839 | 0.815 | 0.854 | 96.8 |
| RTN + HAD (GS32) | 0.839 | 0.936 | 0.841 | 0.836 | 0.863 | 97.8 |
| GPTQ | 0.838 | 0.945 | 0.844 | 0.826 | 0.863 | 97.9 |
| GPTQ + HAD (GS32) | 0.848 | 0.945 | 0.847 | 0.831 | 0.868 | 98.4 |

Testing

  • tests/kernels/quantization/test_mxfp4_qutlass.py
  • tests/kernels/quantization/test_nvfp4_qutlass.py

@gemini-code-assist (bot) left a comment

Code Review

This pull request integrates the QuTLASS library to add support for 4-bit quantization kernels, including new custom ops, benchmarks, and tests. The changes are well-structured. I have two high-severity suggestions: one to improve build reproducibility by pinning the QuTLASS dependency to a specific version, and another to fix a bug in a new test file to prevent future issues.

FetchContent_Declare(
qutlass
GIT_REPOSITORY https://github.com/IST-DASLab/qutlass.git
GIT_TAG main
@gemini-code-assist (bot), severity: high

Using main as the GIT_TAG can lead to non-reproducible builds and may break the build if there are incompatible changes in the QuTLASS repository's main branch. It is highly recommended to pin this to a specific commit hash or a release tag (like v0.1.0 as mentioned in the PR description) to ensure build stability and reproducibility.

    GIT_TAG v0.1.0

b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.)
out_ref = a_dq @ b_dq.transpose(-2, -1)

out = qutlass.matmul_ada_mxf4_bf16_tn(a_e2m1, b_e2m1, a_e8m0, b_e8m0, alpha)
@gemini-code-assist (bot), severity: high

The run_problem_ada function attempts to call qutlass.matmul_ada_mxf4_bf16_tn, but qutlass is not defined or imported. This will result in a NameError. Although this function is not currently called, it's best to fix it to prevent future issues.

To fix this, you should add matmul_ada_mxf4_bf16_tn to your imports at the top of the file:

from vllm._custom_ops import matmul_mxf4_bf16_tn, fusedQuantizeMx, matmul_ada_mxf4_bf16_tn

And then update this line accordingly.

Suggested change
out = qutlass.matmul_ada_mxf4_bf16_tn(a_e2m1, b_e2m1, a_e8m0, b_e8m0, alpha)
out = matmul_ada_mxf4_bf16_tn(a_e2m1, b_e2m1, a_e8m0, b_e8m0, alpha)

@voipmonitor

@LopezCastroRoberto does this PR support gpt-oss on sm120? How exactly can we test some MXFP4 models with this PR? Would love to test an RTX 6000 Pro with this.

@jeejeelee jeejeelee requested a review from mgoin September 8, 2025 15:33
Comment on lines 37 to 45
return torch.tensor(
hadamard(group_size) * group_size**-0.5, dtype=dtype, device=device
)
Contributor:

Can you use our hadamard utility for consistency?

from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
Suggested change
return torch.tensor(
hadamard(group_size) * group_size**-0.5, dtype=dtype, device=device
)
return deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) * group_size**-0.5


def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(b, forward_hadamard_matrix, device)
alpha = torch.Tensor([1.]).to("cuda")
Contributor:
Suggested change
alpha = torch.Tensor([1.]).to("cuda")
alpha = torch.tensor([1.], device="cuda")


def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
return torch.tensor(
hadamard(group_size) * group_size**-0.5, dtype=dtype, device=device
@kylesayrs (Contributor) commented Sep 8, 2025:

Same here, use our util

'Llama-3.1-70B': [(8192, 8192), (8192, 57344), (28672, 8192)]
}

for model, layers in MODELS.items():
Contributor:

Please wrap this in `if __name__ == "__main__":`

Contributor:

Consider adding some user arguments

'Llama-3.1-70B': [(8192, 8192), (8192, 57344), (28672, 8192)]
}

for model, layers in MODELS.items():
Contributor:

Please wrap this in `if __name__ == "__main__":`

Contributor:

Consider allowing users to specify arguments; that way you don't have to keep commented-out code.
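
For reference, a minimal sketch of the structure being requested here (main guard plus user arguments); the argument names and defaults are hypothetical:

```python
import argparse

def run_benchmarks(models: list[str], batch_sizes: list[int]) -> None:
    for model in models:
        for bs in batch_sizes:
            ...  # run the MXFP4/NVFP4 microbenchmark for (model, bs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="QuTLASS microbenchmark")
    parser.add_argument("--models", nargs="+", default=["Llama-3.1-70B"])
    parser.add_argument("--batch-sizes", nargs="+", type=int, default=[1, 16, 64])
    args = parser.parse_args()
    run_benchmarks(args.models, args.batch_sizes)
```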


def fusedQuantizeMx(a: torch.Tensor,
b: torch.Tensor,
*,
@kylesayrs (Contributor) commented Sep 8, 2025:

What's the point of this *?

Contributor (Author):

It means all arguments that come after the * must be passed by keyword, not by position. The intent is to make the API clearer and less error-prone.
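
For reference, a minimal illustration of the keyword-only marker (the names here are illustrative, not the PR's API):

```python
def fused_quantize(a, b, *, method: str = "quest"):
    # Everything after * can only be passed by keyword.
    return a, b, method

fused_quantize(1, 2, method="quest")   # OK: keyword argument
# fused_quantize(1, 2, "quest")        # TypeError: method is keyword-only
```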

Contributor:

That's fair!

xh_e8m0 = torch.empty(padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=a.device)

if method=="quest":
return torch.ops._qutlass_C.fusedQuantizeMxQuest(a, b, xh_e2m1, xh_e8m0)
@kylesayrs (Contributor) commented Sep 8, 2025:

Because these functions have a return value, you'll want to register a fake function so torch.compile works correctly:

if hasattr(torch.ops._C, "_qutlass_C"):
    @register_fake("_C::_qutlass_C::fusedQuantizeMxQuest")
    def fake_qutlass_mx_quest(a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        return (torch.empty(...), torch.empty(...))

output_block_stride,
BLOCK_ROWS: tl.constexpr,
BLOCK_COLS: tl.constexpr,
):
Contributor:

Is this over-indented? I think we should standardize on a 4-space indent.

return (a + b - 1) // b


def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:
Contributor:

Just as a style thing, consider calling triton_mx_block_rearrange in cases where you want to use the triton kernel and to_blocked otherwise

Contributor (Author):

How about keeping one to_blocked but making the backend explicit (e.g. backend="torch" | "triton" | "auto")?

Contributor:

Both good!
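
A sketch of the explicit-backend variant proposed above; the helper names follow the thread (`to_blocked`, `triton_mx_block_rearrange`), but the helper bodies here are placeholders and the "auto" dispatch rule is an assumption:

```python
import torch

def _to_blocked_torch(x: torch.Tensor) -> torch.Tensor:
    return x  # placeholder for the existing pure-torch rearrangement

def triton_mx_block_rearrange(x: torch.Tensor) -> torch.Tensor:
    return x  # placeholder for the Triton kernel

def to_blocked(input_matrix: torch.Tensor, backend: str = "auto") -> torch.Tensor:
    """Single entry point with an explicit backend selector."""
    if backend == "auto":
        backend = "triton" if input_matrix.is_cuda else "torch"
    if backend == "triton":
        return triton_mx_block_rearrange(input_matrix)
    if backend == "torch":
        return _to_blocked_torch(input_matrix)
    raise ValueError(f"unknown backend: {backend!r}")
```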

# Quantize activation on-the-fly
def run():
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(a, forward_hadamard_matrix, global_scale)
input_hf_scale_block = to_blocked(input_hf_e8m0, True).view(-1,K//16)
Contributor:

Will the Triton JIT affect the benchmarked runtime? I.e., does the first-time compile cause the first call to take longer than normal?

Contributor (Author):

Yes, the very first call is slower, but after that the compiled kernel is cached.
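
A generic sketch of how the one-time Triton compile can be kept out of the measurement (this is not the benchmark script's actual timing harness):

```python
import torch

def bench_ms(fn, warmup: int = 10, iters: int = 100) -> float:
    """Mean milliseconds per call, excluding first-call JIT/compile cost."""
    for _ in range(warmup):      # triggers and caches the Triton JIT compile
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```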

@LopezCastroRoberto LopezCastroRoberto marked this pull request as draft September 9, 2025 10:24
@LopezCastroRoberto (Contributor, Author) commented:

@voipmonitor This PR supports dense models only, and it's perfectly fine to use an RTX 6000 Pro. We will add usage examples to this PR soon.

We’re actively working on MoE support in QuTLASS—stay tuned :)

@mergify (bot) added labels: documentation, deepseek, frontend, llama, multi-modality, new-model, qwen, rocm, structured-output (Sep 11, 2025)
@mgoin added the quantization and kernel labels and removed the tool-calling, llama, qwen, and deepseek labels (Sep 22, 2025)
@mgoin removed this from the Tool Calling project (Sep 22, 2025)
@mergify (bot) added the performance label (Sep 22, 2025)
Member:

Please convert these to use pytest like the other tests and add a skipif based on compute capability. You can add these tests to the Blackwell test runner:

- label: Blackwell Test # 38 min
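
A sketch of the requested pytest structure with a capability-based skip; the capability threshold (SM100+) and the test body are assumptions:

```python
import pytest
import torch

requires_blackwell = pytest.mark.skipif(
    not (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 10),
    reason="QuTLASS MXFP4/NVFP4 kernels need Blackwell-class (SM100+) GPUs",
)

@requires_blackwell
@pytest.mark.parametrize("shape", [(128, 256, 512)])
def test_mxfp4_qutlass_matches_reference(shape):
    ...  # quantize, run the kernel, compare against a dequantized reference
```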

Member:

Does this require some minimum CUDA version?

@BlackSamorez commented Sep 30, 2025:

Fixed the register_fake registrations. They had the wrong namespace (_C::_qutlass_C instead of just _qutlass_C), and quite a few kernels (matmul_mxf4_bf16_tn, matmul_ada_mxf4_bf16_tn, fused_quantize_nv) had no fake impls at all.
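
For context, a sketch of a fake (meta) registration with the corrected namespace, using torch.library.register_fake; the op name comes from the earlier review comment, while the guard and output construction here are placeholders:

```python
import torch
from torch.library import register_fake

if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxQuest"):
    @register_fake("_qutlass_C::fusedQuantizeMxQuest")
    def _fused_quantize_mx_quest_fake(a, b, xh_e2m1, xh_e8m0):
        # Meta implementation: only shapes/dtypes matter, no computation runs.
        return torch.empty_like(xh_e2m1), torch.empty_like(xh_e8m0)
```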

BlackSamorez and others added 2 commits October 1, 2025 20:53
Signed-off-by: Andrei Panferov <[email protected]>
Signed-off-by: LopezCastroRoberto <[email protected]>
@mgoin (Member) commented Oct 2, 2025:

@LopezCastroRoberto it looks like the Blackwell tests are broken at the moment: `ERROR: Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.`

LopezCastroRoberto and others added 8 commits October 2, 2025 09:10
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: Andrei Panferov <[email protected]>

minor fixes

eager works

eager tests

custom ops fake fix

eager works

eager tests

removed extra op

style

Signed-off-by: Andrei Panferov <[email protected]>
mergify bot commented Oct 6, 2025:

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LopezCastroRoberto.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 6, 2025
LopezCastroRoberto and others added 2 commits October 6, 2025 07:21
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: Roberto L. Castro <[email protected]>
@mergify mergify bot removed the needs-rebase label Oct 7, 2025
Labels: ci/build, kernel, performance, quantization, ready
Projects: None yet
5 participants