Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -542,11 +542,12 @@ def __init__(
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

# cutlass path
self.is_fp8_w8a8_sm90 = quant_config._is_fp8_w8a8_sm90(
self.weight_quant, self.input_quant)
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
self.weight_quant, self.input_quant)
self.use_cutlass = not self.block_quant and (
quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
or self.is_fp8_w8a8_sm100)
self.use_cutlass = not self.block_quant and (self.is_fp8_w8a8_sm90
or self.is_fp8_w8a8_sm100)
self.disable_expert_map = False

def create_weights(self, layer: torch.nn.Module, num_experts: int,
Expand Down Expand Up @@ -1013,8 +1014,9 @@ def apply(
elif self.use_cutlass:
assert self.moe_quant_config is not None

# small-batch fallback on SM100
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
# SM90 or small-batch fallback on SM100
if self.is_fp8_w8a8_sm90 or (self.is_fp8_w8a8_sm100
and topk_ids.shape[0] <= 8):
from vllm.model_executor.layers.fused_moe import fused_experts
assert per_act_token == per_channel_quant
return fused_experts(
Expand Down
Loading