Skip to content

w4a16 nvfp14 quant support#25535

Open
JINO-ROHIT wants to merge 1 commit into
sgl-project:mainfrom
JINO-ROHIT:nvfp4w4a16
Open

w4a16 nvfp14 quant support#25535
JINO-ROHIT wants to merge 1 commit into
sgl-project:mainfrom
JINO-ROHIT:nvfp4w4a16

Conversation

@JINO-ROHIT
Copy link
Copy Markdown
Contributor

@JINO-ROHIT JINO-ROHIT commented May 17, 2026

add support for nvfp4 blackwell w4a16 support.

fixes #25501


CI States

Latest PR Test (Base): ❌ Missing run-ci label — add it to run CI tests.
Latest PR Test (Extra): ❌ Blockedrun-ci is required first.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the CompressedTensorsW4A16Fp4 quantization scheme, including weight initialization, post-loading processing, and the GEMM application logic. The review feedback identifies several critical issues: the derivation of activation scales from weight scales is mathematically incorrect for weight-only quantization, and the implementation lacks necessary padding logic for the FlashInfer/TRTLLM backend to meet kernel alignment requirements. Additionally, the GEMM operation requires activation padding and output slicing to handle these alignments correctly. A code suggestion was also provided to ensure the output shape calculation supports 3D input tensors.

Comment on lines +87 to +90
weight_gs = layer.weight_global_scale.max().to(torch.float32)
input_gs = (1.0 / weight_gs).to(torch.float32)
layer.input_global_scale = Parameter(input_gs, requires_grad=False)
layer.weight_global_scale = Parameter(weight_gs, requires_grad=False)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Synthesizing a static activation scale (input_gs) from the weight scale (weight_gs) is mathematically incorrect for w4a16 (weight-only) quantization. Since the checkpoint does not provide activation scales, the model should ideally use dynamic quantization for activations (calculating the scale from the input x at runtime) to maintain accuracy. Using a fixed scale derived from weights will likely result in poor model performance.

Comment on lines +92 to +107
if get_fp4_gemm_runner_backend().is_flashinfer_trtllm():
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

weight = layer.weight_packed.data
weight_scale = layer.weight_scale.data

epilogue_tile_m = 128
weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
weight_scale = (
shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
.reshape(weight_scale.shape)
.view(torch.float8_e4m3fn)
)

layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight_packed = Parameter(weight, requires_grad=False)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This section is missing the necessary padding logic for the flashinfer_trtllm backend. The FP4 kernels require specific alignments (e.g., N dimension must be a multiple of 128, and K dimension must be a multiple of 32). Without padding, layers with non-aligned dimensions will cause kernel failures or incorrect results. Please refer to the padding implementation in ModelOptFp4LinearMethod.process_weights_after_loading within modelopt_quant.py and apply similar logic here.

Comment on lines +146 to +154
out = fp4_gemm(
x_fp4,
w,
x_blockscale,
w_blockscale,
layer.alpha,
output_dtype,
w_n,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The fp4_gemm call is missing activation padding and output slicing. If the weights are padded to meet alignment requirements (as noted in the process_weights_after_loading feedback), the activations must be padded in the K-dimension to match, and the resulting output must be sliced to remove the N-dimension padding. See ModelOptFp4LinearMethod.apply for reference.

) -> torch.Tensor:
output_dtype = x.dtype
w_n, _ = layer.weight_packed.shape
output_shape = [x.shape[0], w_n]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The output_shape calculation assumes a 2D input tensor. If the input x is 3D (e.g., [batch, seq, hidden]), this will lead to an incorrect shape and a runtime error during the view operation. Using x.shape[:-1] ensures compatibility with both 2D and 3D inputs.

Suggested change
output_shape = [x.shape[0], w_n]
output_shape = list(x.shape[:-1]) + [w_n]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

blackwell SM100/SM120

Projects

None yet

Development

Successfully merging this pull request may close these issues.

AttributeError: 'NoneType' object has no attribute 'num_bits' when loading NVFP4 quantized model with compressed-tensors

1 participant