Skip to content

Commit f8ca1fb

Browse files
committed
minor
Signed-off-by: Frida Hou <[email protected]>
1 parent 1648939 commit f8ca1fb

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,14 +223,14 @@ def custom_quant_linear(
223223
"NVFP4 needs weight_scale[0] (per-block vector) and weight_scale[1] (alpha)."
224224
)
225225
cutlass_qscale = weight_scale[0]
226-
alpha_inv = weight_scale[1]
226+
alpha = weight_scale[1]
227227

228228
if cutlass_qscale.dtype != torch.uint8:
229229
raise TypeError(
230230
"NVFP4 expects CUTLASS per-block scale vector in uint8 (same as fused op)."
231231
)
232232

233-
inv_w = alpha_inv / inv_x
233+
inv_w = 1 / (inv_x * alpha)
234234
s2_x = 1.0 / inv_x
235235
s2_w = 1.0 / inv_w
236236

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,17 +184,14 @@ def test_quant_linear_nvfp4_matches_fused_op(bias):
184184
alpha=alpha_fused,
185185
)
186186

187-
# Unified op (expects modelopt-style per-block scale vector + combined alpha = s_in2*s_w2)
188-
alpha_unified = (s_in2 * s_w2).to(torch.float32)
189-
190187
out_unified = torch.ops.auto_deploy.custom_quant_linear(
191188
x,
192189
weight_fp4,
193190
bias,
194191
[s_in2], # input_scale list
195192
[
196193
weight_scale_cutlass,
197-
alpha_unified,
194+
alpha_fused,
198195
], # weight_scale list: [per-block vector, alpha]
199196
[], # input_zp
200197
[], # weight_zp

0 commit comments

Comments
 (0)