diff --git a/python/mlc_llm/quantization/fp8_quantization.py b/python/mlc_llm/quantization/fp8_quantization.py
index b879951df4..21a1f108eb 100644
--- a/python/mlc_llm/quantization/fp8_quantization.py
+++ b/python/mlc_llm/quantization/fp8_quantization.py
@@ -381,51 +381,68 @@ def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor:  # pylint: disa
             raise NotImplementedError(
                 f"Only max and cast runtimes are supported for FP8 activations, {self.runtime} is unsupported."
             )
-
-        workspace = nn.op.wrap_nested(
-            relax.op.builtin.alloc_tensor(
-                relax.ShapeExpr((4096 * 1024,)),
-                dtype="uint8",
-                runtime_device_index=0,
-            ),
-            "relax.alloc_tensor",
-        )
-
-        batch_size, in_features = x.shape
-        num_local_experts, out_features, _ = self.q_weight.shape
-
-        if self.runtime == "max-calibration":
-            func = "cutlass.group_gemm_scale_fp16_sm90"
-        else:
-            a_format = self.activation_dtype.split("_")[0]
-            w_format = self.weight_dtype.split("_")[0]
-            func = f"cutlass.group_gemm_{a_format}_{w_format}_fp16"
-
-        if self.runtime == "cast":
-            func = func + "_host_scale"
-            total_scale = 1.0
-        else:
-            if self.runtime != "max-calibration" and self.weight_dtype == "e4m3_float8":
+
+        def get_total_scale(runtime, weight_dtype, q_scale=None):
+            if runtime != "max-calibration" and weight_dtype == "e4m3_float8":
+                assert q_scale is not None
                 # for calibration, q_scale is already used to dequantize the weights
-                total_scale = local_scale * self.q_scale
+                total_scale = local_scale * q_scale
             else:
                 total_scale = local_scale
+            return total_scale
+
+        if indptr.ndim == 2:
+            # Single batch specialization. Use gemv kernels instead
+            assert indptr.shape[0] == 1
+            from mlc_llm.op import moe_matmul
+
+            out = moe_matmul.gemv(x, w, indptr)
+            fp32_out = nn.op.astype(out, dtype="float32")
+            total_scale = get_total_scale(self.runtime, self.weight_dtype, self.q_scale)
             total_scale = nn.op.astype(total_scale, dtype="float32")
+            scaled_out = fp32_out * total_scale
+            return nn.op.astype(scaled_out, dtype="float16")
+        else:
+            workspace = nn.op.wrap_nested(
+                relax.op.builtin.alloc_tensor(
+                    relax.ShapeExpr((4096 * 1024,)),
+                    dtype="uint8",
+                    runtime_device_index=0,
+                ),
+                "relax.alloc_tensor",
+            )
 
-        return nn.op.extern(
-            func,
-            [
-                x,
-                w,
-                indptr,
-                workspace,
-                total_scale,
-            ],
-            out=nn.Tensor.placeholder(
-                (batch_size, out_features),
-                dtype=self.weight_config.model_dtype,
-            ),
-        )
+            batch_size, in_features = x.shape
+            num_local_experts, out_features, _ = self.q_weight.shape
+
+            if self.runtime == "max-calibration":
+                func = "cutlass.group_gemm_scale_fp16_sm90"
+            else:
+                a_format = self.activation_dtype.split("_")[0]
+                w_format = self.weight_dtype.split("_")[0]
+                func = f"cutlass.group_gemm_{a_format}_{w_format}_fp16"
+
+            if self.runtime == "cast":
+                func = func + "_host_scale"
+                total_scale = 1.0
+            else:
+                total_scale = get_total_scale(self.runtime, self.weight_dtype, self.q_scale)
+                total_scale = nn.op.astype(total_scale, dtype="float32")
+
+            return nn.op.extern(
+                func,
+                [
+                    x,
+                    w,
+                    indptr,
+                    workspace,
+                    total_scale,
+                ],
+                out=nn.Tensor.placeholder(
+                    (batch_size, out_features),
+                    dtype=self.weight_config.model_dtype,
+                ),
+            )
     # TODO(csullivan): Refactor Linear and MixtralExperts to shared base with common code
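
Note on the scale bookkeeping this diff factors into get_total_scale: for e4m3 weights
outside calibration, the output scale composes the per-tensor activation scale
(local_scale) with the per-tensor weight scale (self.q_scale), applied to the fp32
accumulator before the cast down to fp16. A minimal standalone numpy sketch of that
composition follows; it is illustrative only, does not simulate actual fp8 rounding,
and quantize_e4m3 / simulate_fp8_gemv are hypothetical stand-ins, not MLC or
moe_matmul APIs.

    import numpy as np

    E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3

    def quantize_e4m3(t):
        # Hypothetical per-tensor symmetric quantization into the e4m3 range.
        # Returns the scaled tensor (float32 here, standing in for fp8 storage)
        # and the scale needed to recover the original magnitudes.
        scale = np.abs(t).max() / E4M3_MAX
        return t / scale, scale

    def simulate_fp8_gemv(x_q, w_q):
        # Stand-in for an fp8 GEMV kernel: accumulate in float32.
        return (w_q @ x_q).astype(np.float32)

    rng = np.random.default_rng(0)
    w = rng.standard_normal((16, 64)).astype(np.float32)  # one expert's weight
    x = rng.standard_normal(64).astype(np.float32)        # activations

    w_q, q_scale = quantize_e4m3(w)      # weight quantization scale
    x_q, local_scale = quantize_e4m3(x)  # activation quantization scale

    # Mirrors the non-calibration path in the diff:
    # total_scale = local_scale * q_scale, applied in fp32, then cast to fp16.
    out = simulate_fp8_gemv(x_q, w_q) * (local_scale * q_scale)
    out_fp16 = out.astype(np.float16)

    assert np.allclose(out_fp16, (w @ x).astype(np.float16), rtol=1e-2, atol=1e-2)

In the "max-calibration" runtime the sketch's q_scale factor drops out, matching the
in-diff comment that q_scale has already been used to dequantize the weights.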