From ba42e631e5f80f2ee8f5b377441a0402260db2d9 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Thu, 25 Apr 2024 22:44:49 +0000 Subject: [PATCH 1/4] wip --- .../mlc_llm/quantization/fp8_quantization.py | 86 ++++++++++--------- 1 file changed, 46 insertions(+), 40 deletions(-) diff --git a/python/mlc_llm/quantization/fp8_quantization.py b/python/mlc_llm/quantization/fp8_quantization.py index b879951df4..65e678459b 100644 --- a/python/mlc_llm/quantization/fp8_quantization.py +++ b/python/mlc_llm/quantization/fp8_quantization.py @@ -381,51 +381,57 @@ def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disa raise NotImplementedError( f"Only max and cast runtimes are supported for FP8 activations, {self.runtime} is unsupported." ) + if indptr.ndim == 2: + assert indptr.shape[0] == 1 + from mlc_llm.op import moe_matmul - workspace = nn.op.wrap_nested( - relax.op.builtin.alloc_tensor( - relax.ShapeExpr((4096 * 1024,)), - dtype="uint8", - runtime_device_index=0, - ), - "relax.alloc_tensor", - ) - - batch_size, in_features = x.shape - num_local_experts, out_features, _ = self.q_weight.shape - - if self.runtime == "max-calibration": - func = "cutlass.group_gemm_scale_fp16_sm90" + print(x.dtype, self.q_weight.dtype, indptr.dtype) + return moe_matmul.gemv(x, self.q_weight, indptr) else: - a_format = self.activation_dtype.split("_")[0] - w_format = self.weight_dtype.split("_")[0] - func = f"cutlass.group_gemm_{a_format}_{w_format}_fp16" + workspace = nn.op.wrap_nested( + relax.op.builtin.alloc_tensor( + relax.ShapeExpr((4096 * 1024,)), + dtype="uint8", + runtime_device_index=0, + ), + "relax.alloc_tensor", + ) - if self.runtime == "cast": - func = func + "_host_scale" - total_scale = 1.0 - else: - if self.runtime != "max-calibration" and self.weight_dtype == "e4m3_float8": - # for calibration, q_scale is already used to dequantize the weights - total_scale = local_scale * self.q_scale + batch_size, in_features = x.shape + num_local_experts, out_features, _ = self.q_weight.shape + + if self.runtime == "max-calibration": + func = "cutlass.group_gemm_scale_fp16_sm90" else: - total_scale = local_scale - total_scale = nn.op.astype(total_scale, dtype="float32") + a_format = self.activation_dtype.split("_")[0] + w_format = self.weight_dtype.split("_")[0] + func = f"cutlass.group_gemm_{a_format}_{w_format}_fp16" - return nn.op.extern( - func, - [ - x, - w, - indptr, - workspace, - total_scale, - ], - out=nn.Tensor.placeholder( - (batch_size, out_features), - dtype=self.weight_config.model_dtype, - ), - ) + if self.runtime == "cast": + func = func + "_host_scale" + total_scale = 1.0 + else: + if self.runtime != "max-calibration" and self.weight_dtype == "e4m3_float8": + # for calibration, q_scale is already used to dequantize the weights + total_scale = local_scale * self.q_scale + else: + total_scale = local_scale + total_scale = nn.op.astype(total_scale, dtype="float32") + + return nn.op.extern( + func, + [ + x, + w, + indptr, + workspace, + total_scale, + ], + out=nn.Tensor.placeholder( + (batch_size, out_features), + dtype=self.weight_config.model_dtype, + ), + ) # TODO(csullivan): Refactor Linear and MixtralExperts to shared base with common code From cb178ba3b45f720f160c886eb453763ef9ecc26d Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Fri, 26 Apr 2024 00:04:51 +0000 Subject: [PATCH 2/4] done --- python/mlc_llm/quantization/fp8_quantization.py | 12 ++++++++++-- .../mlc_llm/quantization/per_tensor_quantization.py | 1 - 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/mlc_llm/quantization/fp8_quantization.py b/python/mlc_llm/quantization/fp8_quantization.py index 65e678459b..cdab85bede 100644 --- a/python/mlc_llm/quantization/fp8_quantization.py +++ b/python/mlc_llm/quantization/fp8_quantization.py @@ -385,8 +385,16 @@ def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disa assert indptr.shape[0] == 1 from mlc_llm.op import moe_matmul - print(x.dtype, self.q_weight.dtype, indptr.dtype) - return moe_matmul.gemv(x, self.q_weight, indptr) + out = moe_matmul.gemv(x, w, indptr) + fp32_out = nn.op.astype(out, dtype="float32") + if self.runtime != "max-calibration" and self.weight_dtype == "e4m3_float8": + # for calibration, q_scale is already used to dequantize the weights + total_scale = local_scale * self.q_scale + else: + total_scale = local_scale + total_scale = nn.op.astype(total_scale, dtype="float32") + scaled_out = fp32_out * total_scale + return nn.op.astype(scaled_out, dtype="float16") else: workspace = nn.op.wrap_nested( relax.op.builtin.alloc_tensor( diff --git a/python/mlc_llm/quantization/per_tensor_quantization.py b/python/mlc_llm/quantization/per_tensor_quantization.py index d7ea2c9f17..cb9a72ecaa 100644 --- a/python/mlc_llm/quantization/per_tensor_quantization.py +++ b/python/mlc_llm/quantization/per_tensor_quantization.py @@ -357,7 +357,6 @@ def __init__( # pylint: disable=too-many-arguments @classmethod def from_linear(cls, src: nn.Linear, config: PerTensorQuantize) -> "PerTensorQuantizeLinear": - if ( DataType(config.weight_dtype).type_code in [ From 64a8a4a94c4ae4c9d63f0e59c3dde6643fa1f630 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Mon, 29 Apr 2024 16:58:20 +0000 Subject: [PATCH 3/4] dedup --- .../mlc_llm/quantization/fp8_quantization.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/python/mlc_llm/quantization/fp8_quantization.py b/python/mlc_llm/quantization/fp8_quantization.py index cdab85bede..21a1f108eb 100644 --- a/python/mlc_llm/quantization/fp8_quantization.py +++ b/python/mlc_llm/quantization/fp8_quantization.py @@ -381,17 +381,24 @@ def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disa raise NotImplementedError( f"Only max and cast runtimes are supported for FP8 activations, {self.runtime} is unsupported." ) + + def get_total_scale(runtime, weight_dtype, q_scale=None): + if runtime != "max-calibration" and weight_dtype == "e4m3_float8": + assert q_scale is not None + # for calibration, q_scale is already used to dequantize the weights + total_scale = local_scale * q_scale + else: + total_scale = local_scale + return total_scale + if indptr.ndim == 2: + # Single batch specialization. Use gemv kernels instead assert indptr.shape[0] == 1 from mlc_llm.op import moe_matmul out = moe_matmul.gemv(x, w, indptr) fp32_out = nn.op.astype(out, dtype="float32") - if self.runtime != "max-calibration" and self.weight_dtype == "e4m3_float8": - # for calibration, q_scale is already used to dequantize the weights - total_scale = local_scale * self.q_scale - else: - total_scale = local_scale + total_scale = get_total_scale(self.runtime, self.weight_dtype, self.q_scale) total_scale = nn.op.astype(total_scale, dtype="float32") scaled_out = fp32_out * total_scale return nn.op.astype(scaled_out, dtype="float16") @@ -419,11 +426,7 @@ def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disa func = func + "_host_scale" total_scale = 1.0 else: - if self.runtime != "max-calibration" and self.weight_dtype == "e4m3_float8": - # for calibration, q_scale is already used to dequantize the weights - total_scale = local_scale * self.q_scale - else: - total_scale = local_scale + total_scale = get_total_scale(self.runtime, self.weight_dtype, self.q_scale) total_scale = nn.op.astype(total_scale, dtype="float32") return nn.op.extern( From c582c5dfb4d61096584c395dd1d09b7ee7fe5c73 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Mon, 29 Apr 2024 17:02:18 +0000 Subject: [PATCH 4/4] remove space --- python/mlc_llm/quantization/per_tensor_quantization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/mlc_llm/quantization/per_tensor_quantization.py b/python/mlc_llm/quantization/per_tensor_quantization.py index cb9a72ecaa..d7ea2c9f17 100644 --- a/python/mlc_llm/quantization/per_tensor_quantization.py +++ b/python/mlc_llm/quantization/per_tensor_quantization.py @@ -357,6 +357,7 @@ def __init__( # pylint: disable=too-many-arguments @classmethod def from_linear(cls, src: nn.Linear, config: PerTensorQuantize) -> "PerTensorQuantizeLinear": + if ( DataType(config.weight_dtype).type_code in [