Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 106 additions & 64 deletions benchmarks/bench_moe_deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,24 @@
"""

import argparse
from contextlib import contextmanager
from dataclasses import dataclass
import numpy as np
import torch


class cuda_profiler_range:
    """Context manager wrapping a region in cudaProfilerStart/Stop plus an NVTX range.

    On entry, starts the CUDA profiler and pushes an NVTX range named *name*;
    on exit (normal or via exception), pops the range and stops the profiler.
    Exceptions are not suppressed.
    """

    def __init__(self, name):
        # Label shown for the NVTX range in profiling tools (e.g. Nsight Systems).
        self._name = name

    def __enter__(self):
        torch.cuda.cudart().cudaProfilerStart()
        torch.cuda.nvtx.range_push(self._name)
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        # Tear down in reverse order of setup, even if the body raised.
        torch.cuda.nvtx.range_pop()
        torch.cuda.cudart().cudaProfilerStop()
        return False


@dataclass
class DeepSeekConfig:
hidden_size: int = 7168
Expand Down Expand Up @@ -78,28 +91,44 @@ def is_sm100_family():
return props.major == 10


def calc_tflops(n, ms, num_local_experts=None):
"""Calculate TFLOPS for MoE computation.
# Effective bytes per element for each dtype, used in bandwidth accounting.
NVFP4_BYTES = 9 / 16  # 0.5 bytes value + 1/16 byte block scale
BF16_BYTES = 2  # bfloat16: 2 bytes per element


def calc_tflops(local_tokens, ms):
"""Calculate TFLOPS using actual routed token count.

With EP, only tokens routed to local experts are computed.
Assumes uniform routing distribution across experts.
FC1: [M, H] x [H, 2I]
FC2: [M, I] x [I, H]
FLOPs = 2 * local_tokens * (H*2I + I*H) = 6 * local_tokens * H * I
"""
if num_local_experts is None:
num_local_experts = CFG.num_experts
H = CFG.hidden_size
I = CFG.intermediate_size
flops = local_tokens * (2 * H * 2 * I + 2 * I * H)
Comment on lines +105 to +107
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟑 Minor

Rename single-letter I to avoid Ruff E741 lint errors.

I is flagged as ambiguous (E741). Rename to a descriptive identifier (e.g., intermediate_size) in both helpers.

Proposed fix
-    H = CFG.hidden_size
-    I = CFG.intermediate_size
-    flops = local_tokens * (2 * H * 2 * I + 2 * I * H)
+    H = CFG.hidden_size
+    intermediate_size = CFG.intermediate_size
+    flops = local_tokens * (
+        2 * H * 2 * intermediate_size + 2 * intermediate_size * H
+    )
@@
-    H = CFG.hidden_size
-    I = CFG.intermediate_size
+    H = CFG.hidden_size
+    intermediate_size = CFG.intermediate_size
 
-    weight_bytes = active_experts * (H * 2 * I + I * H) * NVFP4_BYTES
+    weight_bytes = (
+        active_experts * (H * 2 * intermediate_size + intermediate_size * H) * NVFP4_BYTES
+    )

Also applies to: 119-123

🧰 Tools
πŸͺ› Ruff (0.15.5)

[error] 107-107: Ambiguous variable name: I

(E741)

πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_moe_deepseek.py` around lines 106 - 108, Rename the
ambiguous single-letter variable I to a descriptive name (e.g.,
intermediate_size) wherever it's used: replace I with intermediate_size in the
block that assigns H = CFG.hidden_size and I = CFG.intermediate_size and in the
subsequent FLOPS computation flops = local_tokens * (2 * H * 2 * I + 2 * I * H),
and also update the same rename in the later helper usages around the second
occurrence (lines 119-123) so all references (intermediate_size, H,
local_tokens, flops) remain consistent.

return flops / (ms * 1e-3) / 1e12
Comment on lines +98 to +108
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟑 Minor

Guard throughput computations against non-positive latency.

If ms <= 0, both helpers can raise/diverge during division. Add a small guard to keep benchmarking output resilient.

Proposed fix
 def calc_tflops(local_tokens, ms):
@@
-    H = CFG.hidden_size
-    I = CFG.intermediate_size
-    flops = local_tokens * (2 * H * 2 * I + 2 * I * H)
+    if ms <= 0:
+        return float("nan")
+    H = CFG.hidden_size
+    I = CFG.intermediate_size
+    flops = local_tokens * (2 * H * 2 * I + 2 * I * H)
     return flops / (ms * 1e-3) / 1e12
 
 
 def calc_bw(local_tokens, active_experts, ms):
@@
-    H = CFG.hidden_size
-    I = CFG.intermediate_size
+    if ms <= 0:
+        return float("nan")
+    H = CFG.hidden_size
+    I = CFG.intermediate_size

Also applies to: 112-133

🧰 Tools
πŸͺ› Ruff (0.15.5)

[error] 107-107: Ambiguous variable name: I

(E741)

πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_moe_deepseek.py` around lines 99 - 109, The calc_tflops
function (and the other throughput helper in the same file that computes
TFLOPS/throughput) must guard against non-positive latency: clamp ms to a small
positive epsilon (e.g. ms = max(ms, 1e-6)) or return 0.0 immediately when ms <=
0 to avoid divide-by-zero or negative throughput, and apply the same fix to the
corresponding helper at lines 112-133 so both functions consistently handle ms
<= 0.


# Fraction of work done locally (assuming uniform distribution)
local_fraction = num_local_experts / CFG.num_experts

flops = (
n
* CFG.top_k
* local_fraction # Only local expert pairs are computed
* (
2 * CFG.hidden_size * 2 * CFG.intermediate_size
+ 2 * CFG.intermediate_size * CFG.hidden_size
)
def calc_bw(local_tokens, active_experts, ms):
"""Calculate achieved bandwidth in TB/s for MoE FC1 + FC2.

Weights are read once per active expert.
FC1: nvfp4 input [M, H] x nvfp4 weight [H, 2I] -> nvfp4 output [M, 2I]
FC2: nvfp4 input [M, I] x nvfp4 weight [I, H] -> bf16 output [M, H]
"""
H = CFG.hidden_size
I = CFG.intermediate_size

weight_bytes = active_experts * (H * 2 * I + I * H) * NVFP4_BYTES

# Note: each implementation differs in how data is read/written within the module.
# So here, we only account for the MoE module's read/write bytes.
act_bytes = (
local_tokens * H * NVFP4_BYTES # FC1 input read
+ local_tokens * H * BF16_BYTES # FC2 output write
)
return flops / (ms * 1e-3) / 1e12

total_bytes = weight_bytes + act_bytes
return total_bytes / (ms * 1e-3) / 1e12


def interleave(x, gs=64):
Expand Down Expand Up @@ -342,15 +371,16 @@ def run(x, x_sf, router_logits, routing_bias, topk_values, topk_indices):
"topk_indices": ti,
}

times = bench_gpu_time(
run,
dry_run_iters=warmup,
repeat_iters=iters,
cold_l2_cache=True,
enable_cupti=use_cupti,
use_cuda_graph=use_cuda_graph,
input_kwargs=input_kwargs,
)
with cuda_profiler_range("bench_cute_dsl"):
times = bench_gpu_time(
run,
dry_run_iters=warmup,
repeat_iters=iters,
cold_l2_cache=True,
enable_cupti=use_cupti,
use_cuda_graph=use_cuda_graph,
input_kwargs=input_kwargs,
)
return np.median(times)


Expand Down Expand Up @@ -449,15 +479,16 @@ def run(hidden, sf, router_logits, routing_bias, topk_values, topk_indices):
"topk_indices": ti,
}

times = bench_gpu_time(
run,
dry_run_iters=warmup,
repeat_iters=iters,
cold_l2_cache=True,
enable_cupti=use_cupti,
use_cuda_graph=use_cuda_graph,
input_kwargs=input_kwargs,
)
with cuda_profiler_range("bench_cutlass"):
times = bench_gpu_time(
run,
dry_run_iters=warmup,
repeat_iters=iters,
cold_l2_cache=True,
enable_cupti=use_cupti,
use_cuda_graph=use_cuda_graph,
input_kwargs=input_kwargs,
)
return np.median(times)


Expand Down Expand Up @@ -575,15 +606,16 @@ def run(routing_logits, routing_bias, hidden_states, hidden_states_scale):
"hidden_states_scale": hsc,
}

times = bench_gpu_time(
run,
dry_run_iters=warmup,
repeat_iters=iters,
cold_l2_cache=True,
enable_cupti=use_cupti,
use_cuda_graph=use_cuda_graph,
input_kwargs=input_kwargs,
)
with cuda_profiler_range("bench_trtllm"):
times = bench_gpu_time(
run,
dry_run_iters=warmup,
repeat_iters=iters,
cold_l2_cache=True,
enable_cupti=use_cupti,
use_cuda_graph=use_cuda_graph,
input_kwargs=input_kwargs,
)
return np.median(times)


Expand Down Expand Up @@ -817,6 +849,7 @@ class BenchResult:
tokens: int
latency_ms: float
tflops: float
bw_tb_s: float


def run_benchmark(
Expand Down Expand Up @@ -935,15 +968,18 @@ def _benchmark_single(
),
}

# Build results
# Build results using actual routed token counts
local_tokens = histogram_record["local_tokens"]
active_experts = histogram_record["active_local_experts"]
results = []
for backend, latency in lat.items():
results.append(
BenchResult(
backend=backend,
tokens=n,
latency_ms=latency,
tflops=calc_tflops(n, latency, num_local),
tflops=calc_tflops(local_tokens, latency),
bw_tb_s=calc_bw(local_tokens, active_experts, latency),
)
)
return results, histogram_record
Expand All @@ -957,9 +993,9 @@ def _print_header(
routing_bias_scale,
):
"""Print benchmark header."""
print("\n" + "=" * 120)
print("\n" + "=" * 142)
print(f"DeepSeek-V3 MoE Benchmark: CuteDSL vs CUTLASS vs TRTLLM (EP={ep_config})")
print("=" * 120)
print("=" * 142)
print(
f"Model: hidden={CFG.hidden_size}, intermediate={CFG.intermediate_size}, "
f"experts={CFG.num_experts}, top_k={CFG.top_k}"
Expand All @@ -974,28 +1010,28 @@ def _print_header(
f"Routing bias scale: {routing_bias_scale} "
f"(larger values tend to create expert imbalance)"
)
print("-" * 120)
print("-" * 142)
print(
f"{'Tokens':>6} | "
f"{'CuteDSL':^15} | "
f"{'CUTLASS':^15} | "
f"{'TRTLLM':^15} | "
f"{'CuteDSL':^22} | "
f"{'CUTLASS':^22} | "
f"{'TRTLLM':^22} | "
f"{'Speedup (CuteDSL/X)':^18} | "
f"{'Winner':^8} | "
f"{'Active':^7} | "
f"{'Stats':^14}"
f"{'Tokens/slot':^14}"
)
print(
f"{'':>6} | "
f"{'ms':>7} {'TFLOPS':>7} | "
f"{'ms':>7} {'TFLOPS':>7} | "
f"{'ms':>7} {'TFLOPS':>7} | "
f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | "
f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | "
f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | "
f"{'CUTLASS':>9} {'TRTLLM':>9} | "
f"{'':^8} | "
f"{'experts':^7} | "
f"{'min/max/median':^14}"
f"{'min/median/max':^14}"
)
print("-" * 120)
print("-" * 142)


def _print_row(results, histogram_record):
Expand All @@ -1014,14 +1050,14 @@ def _print_row(results, histogram_record):
active_experts = f"{histogram_record['active_local_experts']:>3}"
stats = (
f"{histogram_record['min_count']:>3}/"
f"{histogram_record['max_count']:>3}/"
f"{histogram_record['median_count']:>7.2f}"
f"{histogram_record['median_count']:>5.1f}/"
f"{histogram_record['max_count']:>4}"
)
print(
f"{cute.tokens:>6} | "
f"{cute.latency_ms:>7.3f} {cute.tflops:>7.1f} | "
f"{cutlass.latency_ms:>7.3f} {cutlass.tflops:>7.1f} | "
f"{trtllm.latency_ms:>7.3f} {trtllm.tflops:>7.1f} | "
f"{cute.latency_ms:>7.3f} {cute.tflops:>7.1f} {cute.bw_tb_s:>6.1f} | "
f"{cutlass.latency_ms:>7.3f} {cutlass.tflops:>7.1f} {cutlass.bw_tb_s:>6.1f} | "
f"{trtllm.latency_ms:>7.3f} {trtllm.tflops:>7.1f} {trtllm.bw_tb_s:>6.1f} | "
f"{speedup_cutlass:>8.2f}x {speedup_trtllm:>8.2f}x | "
f"{winner:^8} | "
f"{active_experts:>7} | "
Expand All @@ -1031,8 +1067,12 @@ def _print_row(results, histogram_record):

def _print_footer(ep_config, num_local):
"""Print benchmark footer."""
print("-" * 120)
print("-" * 142)
print("Speedup > 1.0 means CuteDSL is faster than that backend")
print(
f"TFLOPS/BW use actual routed token counts. "
f"BW assumes nvfp4 = {NVFP4_BYTES:.4f} B/elem, bf16 = {BF16_BYTES} B/elem."
)


def _collect_expert_histogram(inputs, num_local, local_offset):
Expand All @@ -1059,6 +1099,7 @@ def _collect_expert_histogram(inputs, num_local, local_offset):
)
local_hist = expert_hist[local_offset : local_offset + num_local]
local_hist_f32 = local_hist.to(torch.float32)
local_tokens = int(local_hist.sum().item())
active_local_experts = int((local_hist > 0).sum().item())
if local_hist.numel() > 0:
min_count = int(local_hist.min().item())
Expand All @@ -1070,6 +1111,7 @@ def _collect_expert_histogram(inputs, num_local, local_offset):
median_count = 0.0

return {
"local_tokens": local_tokens,
"active_local_experts": active_local_experts,
"min_count": min_count,
"max_count": max_count,
Expand Down
Loading
Loading