w4a16 nvfp4 quant support #25535
Code Review
This pull request introduces the CompressedTensorsW4A16Fp4 quantization scheme, including weight initialization, post-loading processing, and the GEMM application logic. The review feedback identifies several critical issues: the derivation of activation scales from weight scales is mathematically incorrect for weight-only quantization, and the implementation lacks necessary padding logic for the FlashInfer/TRTLLM backend to meet kernel alignment requirements. Additionally, the GEMM operation requires activation padding and output slicing to handle these alignments correctly. A code suggestion was also provided to ensure the output shape calculation supports 3D input tensors.
weight_gs = layer.weight_global_scale.max().to(torch.float32)
input_gs = (1.0 / weight_gs).to(torch.float32)
layer.input_global_scale = Parameter(input_gs, requires_grad=False)
layer.weight_global_scale = Parameter(weight_gs, requires_grad=False)
Synthesizing a static activation scale (input_gs) from the weight scale (weight_gs) is mathematically incorrect for w4a16 (weight-only) quantization. Since the checkpoint does not provide activation scales, the model should ideally use dynamic quantization for activations (calculating the scale from the input x at runtime) to maintain accuracy. Using a fixed scale derived from weights will likely result in poor model performance.
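A minimal sketch of the dynamic alternative the comment describes. It assumes the common NVFP4 convention in which the global scale is (448 × 6) / amax, that is, the FP8 E4M3 maximum times the FP4 E2M1 maximum divided by the tensor's absolute maximum. The helper name `dynamic_input_global_scale` is hypothetical and does not appear in this PR.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
FP4_E2M1_MAX = 6.0    # largest FP4 (E2M1) magnitude


def dynamic_input_global_scale(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: derive the activation global scale from x itself."""
    amax = x.abs().max().to(torch.float32)
    # Guard against an all-zero input to avoid dividing by zero.
    amax = torch.clamp(amax, min=1e-12)
    # Assumed convention: scale = (FP8 max * FP4 max) / amax.
    return (FP8_E4M3_MAX * FP4_E2M1_MAX) / amax
```

Because the scale is recomputed from each input, it tracks the actual activation range instead of a value unrelated to the activations.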
if get_fp4_gemm_runner_backend().is_flashinfer_trtllm():
    from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

    weight = layer.weight_packed.data
    weight_scale = layer.weight_scale.data

    epilogue_tile_m = 128
    weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
    weight_scale = (
        shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
        .reshape(weight_scale.shape)
        .view(torch.float8_e4m3fn)
    )

    layer.weight_scale = Parameter(weight_scale, requires_grad=False)
    layer.weight_packed = Parameter(weight, requires_grad=False)
This section is missing the necessary padding logic for the flashinfer_trtllm backend. The FP4 kernels require specific alignments (e.g., N dimension must be a multiple of 128, and K dimension must be a multiple of 32). Without padding, layers with non-aligned dimensions will cause kernel failures or incorrect results. Please refer to the padding implementation in ModelOptFp4LinearMethod.process_weights_after_loading within modelopt_quant.py and apply similar logic here.
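The alignment arithmetic behind that padding can be sketched as follows. The multiples (128 for N, 32 for K) are the example values quoted in the comment, not verified kernel requirements, and `round_up` and `padded_weight_dims` are hypothetical helpers rather than functions from `modelopt_quant.py`.

```python
from typing import Tuple


def round_up(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple (hypothetical helper)."""
    return (value + multiple - 1) // multiple * multiple


def padded_weight_dims(
    n: int, k: int, n_align: int = 128, k_align: int = 32
) -> Tuple[int, int]:
    """Return the (N, K) a weight would be padded to under the assumed alignments."""
    return round_up(n, n_align), round_up(k, k_align)


# Already aligned dimensions are left unchanged; others grow to the next multiple.
aligned = padded_weight_dims(1024, 4096)
unaligned = padded_weight_dims(1000, 4100)
```

The difference between the padded and original sizes is the number of zero rows or columns to append before shuffling the weight and its scales.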
out = fp4_gemm(
    x_fp4,
    w,
    x_blockscale,
    w_blockscale,
    layer.alpha,
    output_dtype,
    w_n,
)
The fp4_gemm call is missing activation padding and output slicing. If the weights are padded to meet alignment requirements (as noted in the process_weights_after_loading feedback), the activations must be padded in the K-dimension to match, and the resulting output must be sliced to remove the N-dimension padding. See ModelOptFp4LinearMethod.apply for reference.
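The pad-then-slice pattern can be illustrated with an ordinary floating-point matmul standing in for `fp4_gemm`. This is only a sketch of the shape handling: the helper names are hypothetical, and the real path would operate on packed FP4 tensors and block scales.

```python
import torch
import torch.nn.functional as F


def pad_activation_k(x: torch.Tensor, padded_k: int) -> torch.Tensor:
    """Zero-pad the last (K) dimension of x up to padded_k."""
    pad = padded_k - x.shape[-1]
    return F.pad(x, (0, pad)) if pad > 0 else x


def slice_output_n(out: torch.Tensor, original_n: int) -> torch.Tensor:
    """Drop the output columns that correspond to N-dimension padding."""
    return out[..., :original_n].contiguous()


# Stand-in example: N=100, K=30 padded to N=128, K=32 (assumed alignments).
torch.manual_seed(0)
x = torch.randn(2, 30)
w = torch.randn(100, 30)
w_padded = F.pad(w, (0, 2, 0, 28))  # pad K by 2 columns, N by 28 rows
out_padded = pad_activation_k(x, 32) @ w_padded.t()  # shape [2, 128]
out = slice_output_n(out_padded, 100)                # shape [2, 100]
```

Zero padding in K contributes nothing to each dot product, and the extra N rows only produce columns that the slice discards, so `out` matches the unpadded result.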
) -> torch.Tensor:
    output_dtype = x.dtype
    w_n, _ = layer.weight_packed.shape
    output_shape = [x.shape[0], w_n]
The output_shape calculation assumes a 2D input tensor. If the input x is 3D (e.g., [batch, seq, hidden]), this will lead to an incorrect shape and a runtime error during the view operation. Using x.shape[:-1] ensures compatibility with both 2D and 3D inputs.
- output_shape = [x.shape[0], w_n]
+ output_shape = list(x.shape[:-1]) + [w_n]
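A small sketch of why the suggested expression handles both cases. `linear_output_shape` is a hypothetical helper that takes a plain shape tuple, so it can be checked without constructing tensors.

```python
from typing import List, Sequence


def linear_output_shape(x_shape: Sequence[int], w_n: int) -> List[int]:
    """Keep every leading dimension of the input and replace the last with w_n."""
    return list(x_shape[:-1]) + [w_n]


shape_2d = linear_output_shape((8, 4096), 1024)     # [tokens, hidden] input
shape_3d = linear_output_shape((2, 8, 4096), 1024)  # [batch, seq, hidden] input
```

With the original `[x.shape[0], w_n]`, the 3D case would yield `[2, 1024]`, which drops the sequence dimension and no longer matches the number of output elements.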
Add NVFP4 W4A16 quantization support for Blackwell.
fixes #25501
CI States
Latest PR Test (Base): ❌ Missing run-ci label; add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked; run-ci is required first.