diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 5c66a472fa1c..3df46a7aeb4a 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -51,10 +51,10 @@
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
 
-    from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1
-
 if _is_cuda:
     from sgl_kernel import fused_marlin_moe
 
@@ -292,7 +292,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> No
                 max_w13_scales, requires_grad=False
             )
 
-        if _use_aiter:
+        if self.weight_quant.strategy == QuantizationStrategy.CHANNEL and _use_aiter:
             with torch.no_grad():
                 # Pre-shuffle weights
                 layer.w13_weight = torch.nn.Parameter(
@@ -325,23 +325,33 @@ def apply(
         moe_runner_config = self.moe_runner_config
 
-        if (
-            _use_aiter
-            and self.weight_quant.strategy == QuantizationStrategy.CHANNEL
-            and moe_runner_config.apply_router_weight_on_input
-        ):
+        if _use_aiter and self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            assert not moe_runner_config.no_combine, "unsupported"
             topk_weights, topk_ids, _ = topk_output
-            output = rocm_fused_experts_tkw1(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                activation=moe_runner_config.activation,
-                apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
-                use_fp8_w8a8=True,
-                per_channel_quant=self.weight_quant.strategy
-                == QuantizationStrategy.CHANNEL,
+            if moe_runner_config.apply_router_weight_on_input:
+                assert (
+                    topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+                _, topk = topk_weights.shape
+                assert (
+                    topk == 1
+                ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+                x = x * topk_weights.to(x.dtype)
+                topk_weights = torch.ones_like(
+                    topk_weights, dtype=torch.float32
+                )  # topk_weights must be FP32 (float32)
+            output = fused_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                activation=(
+                    ActivationType.Silu
+                    if moe_runner_config.activation == "silu"
+                    else ActivationType.Gelu
+                ),
+                quant_type=QuantType.per_Token,
                 w1_scale=layer.w13_weight_scale,
                 w2_scale=layer.w2_weight_scale,
                 a1_scale=layer.w13_input_scale,