diff --git a/tensorrt_llm/models/convert_utils.py b/tensorrt_llm/models/convert_utils.py index b3d6680cd..545910a1f 100644 --- a/tensorrt_llm/models/convert_utils.py +++ b/tensorrt_llm/models/convert_utils.py @@ -347,7 +347,7 @@ def smooth_gemm(gemm_weights, [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], dim=0) weight_scales = weight_scales.max(dim=0)[0] - weight_scales.to(float).clamp(min=1e-5) + weight_scales = weight_scales.to(float).clamp(min=1e-5) scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5) @@ -386,7 +386,7 @@ def smooth_gemm_fc1_gate(fc1_weights, [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], dim=0) weight_scales = weight_scales.max(dim=0)[0] - weight_scales.to(float).clamp(min=1e-5) + weight_scales = weight_scales.to(float).clamp(min=1e-5) scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5)