File tree Expand file tree Collapse file tree 2 files changed +3
-6
lines changed
tensorrt_llm/_torch/auto_deploy/custom_ops
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops Expand file tree Collapse file tree 2 files changed +3
-6
@@ -223,14 +223,14 @@ def custom_quant_linear(
             "NVFP4 needs weight_scale[0] (per-block vector) and weight_scale[1] (alpha)."
         )
         cutlass_qscale = weight_scale[0]
-        alpha_inv = weight_scale[1]
+        alpha = weight_scale[1]

         if cutlass_qscale.dtype != torch.uint8:
             raise TypeError(
                 "NVFP4 expects CUTLASS per-block scale vector in uint8 (same as fused op)."
             )

-        inv_w = alpha_inv / inv_x
+        inv_w = 1 / (inv_x * alpha)
         s2_x = 1.0 / inv_x
         s2_w = 1.0 / inv_w
@@ -184,17 +184,14 @@ def test_quant_linear_nvfp4_matches_fused_op(bias):
         alpha=alpha_fused,
     )

-    # Unified op (expects modelopt-style per-block scale vector + combined alpha = s_in2*s_w2)
-    alpha_unified = (s_in2 * s_w2).to(torch.float32)
-
     out_unified = torch.ops.auto_deploy.custom_quant_linear(
         x,
         weight_fp4,
         bias,
         [s_in2],  # input_scale list
         [
             weight_scale_cutlass,
-            alpha_unified,
+            alpha_fused,
         ],  # weight_scale list: [per-block vector, combined alpha]
         [],  # input_zp
         [],  # weight_zp
You can’t perform that action at this time.
0 commit comments