diff --git a/aiter/ops/triton/fused_mxfp4_quant.py b/aiter/ops/triton/fused_mxfp4_quant.py index 141bf6d2fe..f6b956738b 100644 --- a/aiter/ops/triton/fused_mxfp4_quant.py +++ b/aiter/ops/triton/fused_mxfp4_quant.py @@ -40,10 +40,7 @@ def fused_rms_mxfp4_quant( - out2: The output matrix with shape (M, N2). - out_res1: The output matrix with shape (M, N1). - if both x2 and res1 provided, return (out1_fp4, out1_bs), out2, out_res1 - if x2 provided, return (out1_fp4, out1_bs), out2 - if res1 provided, return (out1_fp4, out1_bs), out_res1 - if both x2 and res1 not provided, return (out1_fp4, out1_bs) + always returns (out1_fp4, out1_bs), out2, out_res1 """ _LOGGER.info(f"FUSED_RMS_MXFP4_QUANT: inp1={tuple(x1.shape)}")