Skip to content
341 changes: 238 additions & 103 deletions benchmarks/routines/moe.py

Large diffs are not rendered by default.

14 changes: 10 additions & 4 deletions benchmarks/routines/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,13 +466,15 @@ def calculate_moe_tflops(
num_experts: int,
top_k: int,
time_ms: float,
is_gated: bool = True,
) -> float:
"""
Calculate TFLOPS for MOE operation.

MOE computation involves:
1. First GEMM: [num_tokens, hidden_size] x [num_experts, hidden_size, 2*intermediate_size]
2. Activation function (SwiGLU gate)
1. First GEMM: [num_tokens, hidden_size] x [num_experts, hidden_size, w1_cols]
where w1_cols = 2*intermediate_size (gated) or intermediate_size (non-gated)
2. Activation function (SwiGLU gate or ReLU2)
3. Second GEMM: [num_tokens, intermediate_size] x [num_experts, intermediate_size, hidden_size]

For each token, we only compute for top_k experts.
Expand All @@ -484,15 +486,17 @@ def calculate_moe_tflops(
num_experts: Total number of experts
top_k: Number of experts per token
time_ms: Execution time in milliseconds
is_gated: Whether activation is gated (SwiGLU/GeGLU) or non-gated (ReLU2)

Returns:
TFLOPS value
"""
_ = num_experts # kept for backward compatibility

# FLOPS per token per expert
w1_cols = (2 if is_gated else 1) * intermediate_size
flops_per_token_per_expert = (
2 * hidden_size * 2 * intermediate_size # First GEMM
2 * hidden_size * w1_cols # First GEMM
+ 2 * intermediate_size * hidden_size # Second GEMM
)

Expand All @@ -515,6 +519,7 @@ def calculate_moe_kernel_bandwidth(
routing_logits_dtype: Optional[torch.dtype] = torch.float32,
active_experts: Optional[int] = None,
verbose: int = 0,
is_gated: bool = True,
) -> float:
"""
Calculate memory bandwidth for MOE kernel operation in TB/sec.
Expand Down Expand Up @@ -573,8 +578,9 @@ def get_effective_bytes(
)

# Weight memory
w1_cols = (2 if is_gated else 1) * intermediate_size
weight_bytes_per_expert = (
2 * intermediate_size * hidden_size * weight_bytes_per_element # gemm1
w1_cols * hidden_size * weight_bytes_per_element # gemm1
+ hidden_size * intermediate_size * weight_bytes_per_element # gemm2
)
if active_experts is not None:
Expand Down
2 changes: 2 additions & 0 deletions flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@
from .fused_moe import (
cute_dsl_fused_moe_nvfp4 as cute_dsl_fused_moe_nvfp4,
CuteDslMoEWrapper as CuteDslMoEWrapper,
b12x_fused_moe as b12x_fused_moe,
B12xMoEWrapper as B12xMoEWrapper,
)
from .gdn_prefill import chunk_gated_delta_rule as chunk_gated_delta_rule
from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
Expand Down
26 changes: 26 additions & 0 deletions flashinfer/cute_dsl/fp4_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,3 +1533,29 @@ def silu_mul_quantize_block_fp4(
activated = silu_mul_16(gate, up)
block_max = max_abs_16(activated)
return quantize_block_fp4(activated, block_max, global_scale_val)


# =============================================================================
# ReLU2 Activation — ReLU(x)² for non-gated MoE (Nemotron-Super)
# =============================================================================


@cute.jit
def relu2_16(x: cute.Tensor) -> cute.Tensor:
    """Apply the squared-ReLU activation elementwise to 16 float32 values.

    ReLU²(x) = max(0, x)² — the non-gated MoE activation (used by
    Nemotron-Super per the section header) in place of gated SwiGLU.

    Args:
        x: Tensor holding 16 float32 elements (indexed 0..15).

    Returns:
        A freshly allocated 16-element rmem (register-memory) tensor with
        ReLU²(x[i]) in each slot; the input is not modified.
    """
    out = cute.make_rmem_tensor((16,), Float32)
    # range_constexpr unrolls the loop at trace/compile time, so each of
    # the 16 elements becomes straight-line code in the JIT-ed kernel.
    for i in cutlass.range_constexpr(16):
        v = fmax_f32(x[i], Float32(0.0))  # ReLU: clamp negatives to 0
        out[i] = v * v                    # then square
    return out


@cute.jit
def relu2_quantize_block_fp4(
    x: cute.Tensor,
    global_scale_val: Float32,
) -> Tuple[Uint64, Uint8]:
    """Fused ReLU² activation + FP4 block quantization for 16 float32 values.

    Non-gated counterpart of silu_mul_quantize_block_fp4: applies
    ReLU²(x) = max(0, x)², computes the block-wise absolute maximum of the
    activated values, and quantizes the 16-element block to FP4 using that
    block max together with the global scale.

    Args:
        x: Tensor of 16 float32 pre-activation values.
        global_scale_val: Global scale factor forwarded to
            quantize_block_fp4.

    Returns:
        Whatever quantize_block_fp4 returns for the activated block — a
        (Uint64, Uint8) pair; presumably the 16 packed FP4 values and the
        per-block scale byte (NOTE(review): confirm against
        quantize_block_fp4's contract).
    """
    activated = relu2_16(x)
    # The per-block amax of the *activated* values drives the FP4 scaling.
    block_max = max_abs_16(activated)
    return quantize_block_fp4(activated, block_max, global_scale_val)
23 changes: 22 additions & 1 deletion flashinfer/cute_dsl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ctypes
import functools
import importlib.util
import warnings
from typing import Union, Tuple

import cutlass
Expand Down Expand Up @@ -123,7 +124,27 @@ def get_max_active_clusters(cluster_size: int) -> int:
Returns:
Maximum number of active clusters supported by hardware.
"""
return get_hardware_info().get_max_active_clusters(cluster_size)
try:
return get_hardware_info().get_max_active_clusters(cluster_size)
except Exception as exc:
# nvidia_cutlass_dsl's hardware probe (cuKernelGetFunction) can fail
# in spawned subprocesses (e.g. vLLM EngineCore) when the CUDA driver
# API context is not current at first use, even if the PyTorch CUDA
# runtime is initialised. Fall back to the GPU's physical SM count,
# which is a safe upper bound: callers that clamp to sm_count (such
# as the SM120 MoE dispatch's ``min(get_max_active_clusters(1),
# sm_count)``) are unaffected; other callers under-parallelize
# slightly when per-CTA resources allow more than one cluster per
# SM, but never over-request (which could deadlock a resident grid).
warnings.warn(
f"cutlass.get_max_active_clusters failed "
f"({type(exc).__name__}: {exc}); falling back to sm_count. "
f"This can happen in spawned subprocesses where the CUDA driver "
f"API context is not current.",
RuntimeWarning,
stacklevel=2,
)
return get_num_sm(torch.device("cuda"))


# WAR for CuTeDSL make_ptr implementation for flashinfer
Expand Down
4 changes: 4 additions & 0 deletions flashinfer/fused_moe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
from .cute_dsl import (
cute_dsl_fused_moe_nvfp4,
CuteDslMoEWrapper,
b12x_fused_moe,
B12xMoEWrapper,
)

_cute_dsl_available = True
Expand Down Expand Up @@ -84,4 +86,6 @@
__all__ += [
"cute_dsl_fused_moe_nvfp4",
"CuteDslMoEWrapper",
"b12x_fused_moe",
"B12xMoEWrapper",
]
6 changes: 6 additions & 0 deletions flashinfer/fused_moe/cute_dsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
cute_dsl_fused_moe_nvfp4,
CuteDslMoEWrapper,
)
from .b12x_moe import (
b12x_fused_moe,
B12xMoEWrapper,
)

__all__ = [
"is_cute_dsl_available",
Expand All @@ -32,4 +36,6 @@
__all__ += [
"cute_dsl_fused_moe_nvfp4",
"CuteDslMoEWrapper",
"b12x_fused_moe",
"B12xMoEWrapper",
]
Loading
Loading