[Transform] [Quantization] Add QuTLASS support to vLLM #24440
Conversation
Code Review
This pull request integrates the QuTLASS library to add support for 4-bit quantization kernels, including new custom ops, benchmarks, and tests. The changes are well-structured. I have two high-severity suggestions: one to improve build reproducibility by pinning the QuTLASS dependency to a specific version, and another to fix a bug in a new test file to prevent future issues.
FetchContent_Declare(
  qutlass
  GIT_REPOSITORY https://github.com/IST-DASLab/qutlass.git
  GIT_TAG main
Using `main` as the `GIT_TAG` can lead to non-reproducible builds and may break the build if there are incompatible changes in the QuTLASS repository's main branch. It is highly recommended to pin this to a specific commit hash or a release tag (like `v0.1.0`, as mentioned in the PR description) to ensure build stability and reproducibility.

Suggested change:
-  GIT_TAG main
+  GIT_TAG v0.1.0
b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.)
out_ref = a_dq @ b_dq.transpose(-2, -1)

out = qutlass.matmul_ada_mxf4_bf16_tn(a_e2m1, b_e2m1, a_e8m0, b_e8m0, alpha)
The `run_problem_ada` function attempts to call `qutlass.matmul_ada_mxf4_bf16_tn`, but `qutlass` is not defined or imported. This will result in a `NameError`. Although this function is not currently called, it's best to fix it to prevent future issues.

To fix this, add `matmul_ada_mxf4_bf16_tn` to the imports at the top of the file:

from vllm._custom_ops import matmul_mxf4_bf16_tn, fusedQuantizeMx, matmul_ada_mxf4_bf16_tn

And then update this line accordingly.
Suggested change:
-    out = qutlass.matmul_ada_mxf4_bf16_tn(a_e2m1, b_e2m1, a_e8m0, b_e8m0, alpha)
+    out = matmul_ada_mxf4_bf16_tn(a_e2m1, b_e2m1, a_e8m0, b_e8m0, alpha)
@LopezCastroRoberto does this PR support gpt-oss on sm120? How exactly can I test some mxfp4 models with this PR? I would love to test an RTX 6000 Pro with this.
return torch.tensor(
    hadamard(group_size) * group_size**-0.5, dtype=dtype, device=device
)
Can you use our hadamard utility for consistency?
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
Suggested change:
-    return torch.tensor(
-        hadamard(group_size) * group_size**-0.5, dtype=dtype, device=device
-    )
+    return deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) * group_size**-0.5
def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
    weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(b, forward_hadamard_matrix, device)
    alpha = torch.Tensor([1.]).to("cuda")
Suggested change:
-    alpha = torch.Tensor([1.]).to("cuda")
+    alpha = torch.tensor([1.], device="cuda")
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    return torch.tensor(
        hadamard(group_size) * group_size**-0.5, dtype=dtype, device=device
Same here, use our util
    'Llama-3.1-70B': [(8192, 8192), (8192, 57344), (28672, 8192)]
}

for model, layers in MODELS.items():
Please wrap in `if __name__ == "__main__":`
Consider adding some user arguments
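A minimal sketch of what that could look like for this benchmark; the flag names and defaults are illustrative assumptions, and the `MODELS` stub below stands in for the dict already defined in the file:

```python
import argparse

# Stand-in for the MODELS dict already defined in this benchmark file.
MODELS = {"Llama-3.1-70B": [(8192, 8192), (8192, 57344), (28672, 8192)]}


def main(models: list[str]) -> None:
    for model in models:
        for shape in MODELS[model]:
            ...  # run the benchmark for this layer shape


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="QuTLASS MXFP4 GEMM microbenchmark")
    parser.add_argument(
        "--models",
        nargs="+",
        choices=list(MODELS.keys()),
        default=list(MODELS.keys()),
        help="Subset of models whose layer shapes should be benchmarked",
    )
    args = parser.parse_args()
    main(args.models)
```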
    'Llama-3.1-70B': [(8192, 8192), (8192, 57344), (28672, 8192)]
}

for model, layers in MODELS.items():
Please wrap in `if __name__ == "__main__":`
Consider allowing users to specify arguments; that way you don't have to keep commented-out code.
vllm/_custom_ops.py (outdated)
def fusedQuantizeMx(a: torch.Tensor,
                    b: torch.Tensor,
                    *,
What's the point of this `*`?
It means all arguments that come after the `*` must be passed by keyword, not by position. The point was to make the API clearer and less error-prone.
That's fair!
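For context, a tiny self-contained example of the keyword-only boundary (the function here is hypothetical, not the vLLM op):

```python
def scale(x: float, *, factor: float = 1.0) -> float:
    # Everything after the bare `*` is keyword-only.
    return x * factor


scale(2.0, factor=3.0)  # OK
# scale(2.0, 3.0)       # TypeError: takes 1 positional argument but 2 were given
```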
vllm/_custom_ops.py (outdated)

xh_e8m0 = torch.empty(padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=a.device)

if method == "quest":
    return torch.ops._qutlass_C.fusedQuantizeMxQuest(a, b, xh_e2m1, xh_e8m0)
Because these functions have a return value, you'll want to register a fake function so torch.compile works right:

if hasattr(torch.ops._C, "_qutlass_C"):

    @register_fake("_C::_qutlass_C::fusedQuantizeMxQuest")
    def fake_qutlass_mx_quest(
        a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return (torch.empty(...), torch.empty(...))
vllm/qutlass_utils/utils.py (outdated)

        output_block_stride,
        BLOCK_ROWS: tl.constexpr,
        BLOCK_COLS: tl.constexpr,
):
Is this over-indented? I think we should standardize on 4 space indent
return (a + b - 1) // b


def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:
Just as a style thing, consider calling `triton_mx_block_rearrange` in cases where you want to use the Triton kernel and `to_blocked` otherwise.
How about keeping one `to_blocked` but making the backend explicit (e.g. `backend="torch" | "triton" | "auto"`)?
Both good!
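A rough sketch of the backend-explicit variant, under the assumption that a pure-PyTorch path and a Triton path already exist; the two stub helpers below only stand in for them and the "auto" heuristic is illustrative:

```python
from typing import Literal

from torch import Tensor


def _to_blocked_torch(scales: Tensor) -> Tensor:
    """Stand-in for the existing pure-PyTorch rearrangement."""
    raise NotImplementedError


def triton_mx_block_rearrange(scales: Tensor) -> Tensor:
    """Stand-in for the existing Triton kernel wrapper."""
    raise NotImplementedError


def to_blocked(scales: Tensor, backend: Literal["torch", "triton", "auto"] = "auto") -> Tensor:
    # Resolve "auto" to a concrete backend; this heuristic is only illustrative.
    if backend == "auto":
        backend = "triton" if scales.is_cuda else "torch"
    if backend == "triton":
        return triton_mx_block_rearrange(scales)
    if backend == "torch":
        return _to_blocked_torch(scales)
    raise ValueError(f"unknown backend: {backend!r}")
```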
# Quantize activation on-the-fly
def run():
    input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(a, forward_hadamard_matrix, global_scale)
    input_hf_scale_block = to_blocked(input_hf_e8m0, True).view(-1, K // 16)
Will the Triton JIT affect the benchmarked runtime? I.e., will first-time compilation cause the first run to take longer than normal?
Yes, the very first call is slower, but after that it's cached.
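One common way to keep the one-time Triton compilation out of the measured numbers is a warmup loop before timing; a minimal sketch (the timing harness here is illustrative, not the benchmark's actual one, and `run` is the closure defined above):

```python
import torch


def time_kernel_ms(run, warmup: int = 10, iters: int = 100) -> float:
    # Warmup triggers (and caches) the Triton JIT compile before timing starts.
    for _ in range(warmup):
        run()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        run()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call
```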
@voipmonitor This PR supports dense models only, and it's perfectly fine to use an RTX 6000 Pro. We will add usage examples to this PR soon. We're actively working on MoE support in QuTLASS, so stay tuned :)
Force-pushed from f9ca647 to dce5334.
Please convert these to use pytest like the other tests and add a skipif based on compute capability. You can add these tests to the Blackwell test runner:
vllm/.buildkite/test-pipeline.yaml, line 772 in 8db2939: `- label: Blackwell Test # 38 min`
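A sketch of the requested shape, assuming the kernels need a Blackwell-class GPU (compute capability 10.0 or newer); the marker name, the capability threshold, and the parametrized shapes are assumptions to adjust as needed:

```python
import pytest
import torch

requires_blackwell = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (10, 0),
    reason="QuTLASS MXFP4/NVFP4 kernels require compute capability >= 10.0",
)


@requires_blackwell
@pytest.mark.parametrize("shape", [(8192, 8192), (8192, 57344), (28672, 8192)])
def test_mxfp4_matmul_matches_reference(shape):
    ...  # quantize, run the QuTLASS kernel, and compare against the dequantized reference
```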
Does this require some minimum CUDA version?
Fixed
Signed-off-by: Andrei Panferov <[email protected]>
Signed-off-by: LopezCastroRoberto <[email protected]>
Force-pushed from d0805ae to f4a6f15.
@LopezCastroRoberto it looks like the Blackwell tests are broken at the moment.
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: Andrei Panferov <[email protected]> minor fixes eager works eager tests custom ops fake fix eager works eager tests removed extra op style Signed-off-by: Andrei Panferov <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: Roberto L. Castro <[email protected]>
Signed-off-by: LopezCastroRoberto <[email protected]>
Purpose
This pull request brings in the QuTLASS library: https://github.com/iST-DASLab/qutlass
QuTLASS is a high-performance library designed for low-precision kernel support in deep learning quantization, built on top of NVIDIA CUTLASS.
QuTLASS v0.1.0 introduces 4-bit microscaling routines tailored for Large Language Model (LLM) inference on NVIDIA Blackwell GPUs.
Microbenchmarking
benchmarks/kernels/bench_mxfp4_qutlass.py
benchmarks/kernels/bench_nvfp4_qutlass.py
QuTLASS performance on a single Qwen3-32B layer with NVIDIA RTX5090 GPU
QuTLASS performance on a single Llama-3.1-70B layer with NVIDIA B200 GPU
[WIP] End-to-end
Latency measured with python benchmarks/benchmark_latency.py, comparing daslab-testing/Llama-3.3-70B-Instruct-FPQuant-GPTQ-MXFP4-hadamard (MXFP4) against meta-llama/Llama-3.3-70B-Instruct (FP16).
Testing
tests/kernels/quantization/test_mxfp4_qutlass.py
tests/kernels/quantization/test_nvfp4_qutlass.py