diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 8504ba73defb..de63b9817612 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -542,11 +542,12 @@ def __init__(
         self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
 
         # cutlass path
+        self.is_fp8_w8a8_sm90 = quant_config._is_fp8_w8a8_sm90(
+            self.weight_quant, self.input_quant)
         self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
             self.weight_quant, self.input_quant)
-        self.use_cutlass = not self.block_quant and (
-            quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
-            or self.is_fp8_w8a8_sm100)
+        self.use_cutlass = not self.block_quant and (self.is_fp8_w8a8_sm90
+                                                     or self.is_fp8_w8a8_sm100)
         self.disable_expert_map = False
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -1013,8 +1014,9 @@ def apply(
         elif self.use_cutlass:
             assert self.moe_quant_config is not None
 
-            # small-batch fallback on SM100
-            if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
+            # SM90 or small-batch fallback on SM100
+            if self.is_fp8_w8a8_sm90 or (self.is_fp8_w8a8_sm100
+                                         and topk_ids.shape[0] <= 8):
                 from vllm.model_executor.layers.fused_moe import fused_experts
                 assert per_act_token == per_channel_quant
                 return fused_experts(
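
Not part of the patch: a minimal, standalone sketch of the fallback decision the second hunk introduces, assuming only the two flags and the topk_ids.shape[0] <= 8 threshold visible in the diff. The helper name should_fall_back_to_fused_experts is hypothetical and not a vLLM API; it only restates the branch condition so the new behavior is easy to check in isolation.

def should_fall_back_to_fused_experts(is_fp8_w8a8_sm90: bool,
                                      is_fp8_w8a8_sm100: bool,
                                      num_tokens: int) -> bool:
    """Mirror of the condition in apply(): SM90 always takes the
    fused_experts fallback; SM100 takes it only for small batches."""
    return is_fp8_w8a8_sm90 or (is_fp8_w8a8_sm100 and num_tokens <= 8)


if __name__ == "__main__":
    assert should_fall_back_to_fused_experts(True, False, 256)    # SM90: always falls back
    assert should_fall_back_to_fused_experts(False, True, 8)      # SM100: small batch falls back
    assert not should_fall_back_to_fused_experts(False, True, 9)  # SM100: larger batch stays on CUTLASS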