python/mlc_llm/quantization/fp8_quantization.py (97 changes: 57 additions & 40 deletions)
@@ -381,51 +381,68 @@ def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disa
raise NotImplementedError(
f"Only max and cast runtimes are supported for FP8 activations, {self.runtime} is unsupported."
)

workspace = nn.op.wrap_nested(
relax.op.builtin.alloc_tensor(
relax.ShapeExpr((4096 * 1024,)),
dtype="uint8",
runtime_device_index=0,
),
"relax.alloc_tensor",
)

batch_size, in_features = x.shape
num_local_experts, out_features, _ = self.q_weight.shape

if self.runtime == "max-calibration":
func = "cutlass.group_gemm_scale_fp16_sm90"
else:
a_format = self.activation_dtype.split("_")[0]
w_format = self.weight_dtype.split("_")[0]
func = f"cutlass.group_gemm_{a_format}_{w_format}_fp16"

if self.runtime == "cast":
func = func + "_host_scale"
total_scale = 1.0
else:
if self.runtime != "max-calibration" and self.weight_dtype == "e4m3_float8":

def get_total_scale(runtime, weight_dtype, q_scale=None):
if runtime != "max-calibration" and weight_dtype == "e4m3_float8":
assert q_scale is not None
# for calibration, q_scale is already used to dequantize the weights
total_scale = local_scale * self.q_scale
total_scale = local_scale * q_scale
else:
total_scale = local_scale
return total_scale

if indptr.ndim == 2:
# Single batch specialization. Use gemv kernels instead
assert indptr.shape[0] == 1
from mlc_llm.op import moe_matmul

out = moe_matmul.gemv(x, w, indptr)
fp32_out = nn.op.astype(out, dtype="float32")
total_scale = get_total_scale(self.runtime, self.weight_dtype, self.q_scale)
total_scale = nn.op.astype(total_scale, dtype="float32")
scaled_out = fp32_out * total_scale
return nn.op.astype(scaled_out, dtype="float16")
else:
workspace = nn.op.wrap_nested(
relax.op.builtin.alloc_tensor(
relax.ShapeExpr((4096 * 1024,)),
dtype="uint8",
runtime_device_index=0,
),
"relax.alloc_tensor",
)

return nn.op.extern(
func,
[
x,
w,
indptr,
workspace,
total_scale,
],
out=nn.Tensor.placeholder(
(batch_size, out_features),
dtype=self.weight_config.model_dtype,
),
)
batch_size, in_features = x.shape
num_local_experts, out_features, _ = self.q_weight.shape

if self.runtime == "max-calibration":
func = "cutlass.group_gemm_scale_fp16_sm90"
else:
a_format = self.activation_dtype.split("_")[0]
w_format = self.weight_dtype.split("_")[0]
func = f"cutlass.group_gemm_{a_format}_{w_format}_fp16"

if self.runtime == "cast":
func = func + "_host_scale"
total_scale = 1.0
else:
total_scale = get_total_scale(self.runtime, self.weight_dtype, self.q_scale)
total_scale = nn.op.astype(total_scale, dtype="float32")

return nn.op.extern(
func,
[
x,
w,
indptr,
workspace,
total_scale,
],
out=nn.Tensor.placeholder(
(batch_size, out_features),
dtype=self.weight_config.model_dtype,
),
)


# TODO(csullivan): Refactor Linear and MixtralExperts to shared base with common code
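For context, the following is a minimal NumPy sketch (not the PR's code) of the dispatch this hunk introduces: a shared get_total_scale helper folds the stored FP8 weight scale into the runtime activation scale, the single-sequence decode case (indptr.ndim == 2) takes a per-expert GEMV path, and the batched case runs a grouped GEMM over contiguous expert segments. The array layouts, the plain-NumPy matmuls, and the scalar local_scale argument are assumptions for illustration only; the real kernels are mlc_llm's moe_matmul.gemv op and the CUTLASS group GEMM externs named in the diff.

import numpy as np


def get_total_scale(runtime, weight_dtype, local_scale, q_scale=None):
    # Combined dequantization scale (sketch). During "max-calibration" the
    # weights were already dequantized with q_scale, so only the activation
    # (local) scale applies; otherwise fold the weight scale in as well.
    if runtime != "max-calibration" and weight_dtype == "e4m3_float8":
        assert q_scale is not None
        return local_scale * q_scale
    return local_scale


def moe_forward(x, w, indptr, runtime, weight_dtype, local_scale, q_scale=None):
    # Assumed layouts (the real op works on quantized tensors in-kernel):
    #   x: (num_tokens, in_features) activations, already dequantized to float16.
    #   w: (num_experts, out_features, in_features) dequantized expert weights.
    scale = get_total_scale(runtime, weight_dtype, local_scale, q_scale)
    if indptr.ndim == 2:
        # Single-token decode: indptr holds the expert indices chosen for the
        # one token, so a per-expert GEMV suffices (no grouped GEMM needed).
        assert indptr.shape[0] == 1
        out = np.stack([w[e] @ x[0] for e in indptr[0]])
        return (out.astype(np.float32) * scale).astype(np.float16)
    # Batched / prefill: indptr holds row offsets, so tokens indptr[e]:indptr[e + 1]
    # are routed to expert e; each segment is one GEMM of the grouped GEMM.
    outs = [x[indptr[e]:indptr[e + 1]] @ w[e].T for e in range(len(indptr) - 1)]
    return (np.concatenate(outs).astype(np.float32) * scale).astype(np.float16)

The "cast" runtime is the exception in the diff's grouped-GEMM branch: it appends "_host_scale" to the kernel name and passes total_scale = 1.0, so no additional scaling is applied in this wrapper.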