diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index c779f1f1d39..80a5971a0e4 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -82,6 +82,7 @@ def dummy_func(*args, **kwargs):
 
 if _is_hip:
     from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
@@ -1062,19 +1063,20 @@ def maybe_apply_hip_fused_experts(
         if _use_aiter:
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
-                # TODO(_use_aiter): FP8 block_quant only supports 'silu' for the time-being.
-                assert (
-                    activation == "silu"
-                ), f"_use_aiter: FP8 bloack_quant {activation=} will be supported later, unset _use_aiter"
-                return asm_moe(
+                return fused_moe(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
-                    layer.w13_weight_scale_inv,
-                    layer.w2_weight_scale_inv,
-                    block_shape=tuple(self.quant_config.weight_block_size),
+                    w1_scale=layer.w13_weight_scale_inv,
+                    w2_scale=layer.w2_weight_scale_inv,
+                    quant_type=QuantType.per_128x128,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                     expert_mask=None,
                 )
             else:
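
For context, here is a minimal standalone sketch of the new aiter `fused_moe` call path that the diff switches to for FP8 block-quantized MoE on ROCm. The keyword arguments, enums, and import mirror the call in the diff; the wrapper function `run_block_quant_moe` and its tensor arguments are hypothetical illustrations, not part of the change itself.

```python
# Sketch only: mirrors the fused_moe invocation added in the diff above.
# The wrapper and argument names below are illustrative; only the keyword
# arguments and enums correspond to what the diff actually passes.
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe


def run_block_quant_moe(x, w13, w2, topk_weights, topk_ids,
                        w13_scale_inv, w2_scale_inv, activation="silu"):
    # Map sglang's string activation onto aiter's enum, matching the
    # conditional expression introduced in the diff.
    act = ActivationType.Silu if activation == "silu" else ActivationType.Gelu
    return fused_moe(
        x,                       # input hidden states
        w13,                     # fused gate/up projection weights (FP8 block-quant)
        w2,                      # down projection weights (FP8 block-quant)
        topk_weights,            # routing weights from top-k gating
        topk_ids,                # expert indices from top-k gating
        w1_scale=w13_scale_inv,  # dequant scales for w13
        w2_scale=w2_scale_inv,   # dequant scales for w2
        quant_type=QuantType.per_128x128,
        activation=act,
        expert_mask=None,
    )
```

Compared with the previous `asm_moe` call, this path passes the scales as keyword arguments and selects the 128x128 block quantization via `QuantType.per_128x128` instead of a `block_shape` tuple, which is what lets the gelu activation be supported without the earlier silu-only assertion.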