Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 106 additions & 64 deletions benchmarks/bench_moe_deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,24 @@
"""

import argparse
from contextlib import contextmanager
from dataclasses import dataclass
import numpy as np
import torch


class cuda_profiler_range:
    """Context manager for CUDA profiler + NVTX range.

    On entry starts the CUDA profiler and pushes an NVTX range named
    *name*; on exit pops the range and stops the profiler, even if the
    body raised. Exceptions are never suppressed.
    """

    def __init__(self, name):
        # Label shown for this range in the profiler timeline.
        self._name = name

    def __enter__(self):
        torch.cuda.cudart().cudaProfilerStart()
        torch.cuda.nvtx.range_push(self._name)

    def __exit__(self, exc_type, exc_value, traceback):
        torch.cuda.nvtx.range_pop()
        torch.cuda.cudart().cudaProfilerStop()
        return False  # propagate any exception from the with-body


@dataclass
class DeepSeekConfig:
hidden_size: int = 7168
Expand Down Expand Up @@ -79,28 +92,44 @@ def is_sm100_family():
return props.major == 10


def calc_tflops(n, ms, num_local_experts=None):
"""Calculate TFLOPS for MoE computation.
NVFP4_BYTES = 9 / 16  # 0.5 bytes per value + 1/16 byte of shared block-scale metadata
BF16_BYTES = 2  # bfloat16: 2 bytes per element
Comment on lines +95 to +96
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Bandwidth calculation assumes NVFP4 activations, but BF16 is also supported.

The calc_bw function assumes FC1 input is NVFP4 (local_tokens * H * NVFP4_BYTES), but according to trtllm_fp4_block_scale_moe documentation, hidden_states can be bfloat16, mxfp8, or nvfp4. For BF16 inputs, the bandwidth calculation would underestimate actual memory traffic.

Consider either:

  1. Adding a parameter to specify input dtype
  2. Documenting that this calculation assumes NVFP4 inputs

Also applies to: 112-132

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_moe_deepseek.py` around lines 95 - 96, The bandwidth calc in
calc_bw currently assumes NVFP4 activations (using NVFP4_BYTES) which
underestimates traffic when hidden_states can be BF16 or other dtypes per
trtllm_fp4_block_scale_moe; update calc_bw to accept an input dtype parameter
(e.g., input_dtype or act_bytes) or a flag and use the corresponding byte-size
constant (BF16_BYTES, NVFP4_BYTES, etc.) when computing local_tokens * H *
<bytes>, and update any callers (or default behavior) so FC1 bandwidth uses the
selected dtype; alternatively, add a clear docstring on calc_bw and top-level
comment stating the NVFP4-only assumption if you choose not to change behavior.



def calc_tflops(local_tokens, ms):
    """Calculate TFLOPS using the actual routed token count.

    With expert parallelism only the (token, expert) pairs routed to
    local experts are computed:
      FC1: [M, H] x [H, 2I]
      FC2: [M, I] x [I, H]
    FLOPs = 2 * local_tokens * (H*2I + I*H) = 6 * local_tokens * H * I

    Args:
        local_tokens: number of token-expert pairs routed to local experts.
        ms: measured latency in milliseconds.

    Returns:
        Achieved throughput in TFLOPS.
    """
    H = CFG.hidden_size
    I = CFG.intermediate_size
    # 2 FLOPs per MAC; FC1 produces 2I outputs (gate + up projections).
    flops = local_tokens * (2 * H * 2 * I + 2 * I * H)
    return flops / (ms * 1e-3) / 1e12

# Fraction of work done locally (assuming uniform distribution)
local_fraction = num_local_experts / CFG.num_experts

flops = (
n
* CFG.top_k
* local_fraction # Only local expert pairs are computed
* (
2 * CFG.hidden_size * 2 * CFG.intermediate_size
+ 2 * CFG.intermediate_size * CFG.hidden_size
)
def calc_bw(local_tokens, active_experts, ms, fc1_input_bytes=None):
    """Calculate achieved bandwidth in TB/s for MoE FC1 + FC2.

    Weights are read once per active expert.
    FC1: input [M, H] x nvfp4 weight [H, 2I] -> nvfp4 output [M, 2I]
    FC2: nvfp4 input [M, I] x nvfp4 weight [I, H] -> bf16 output [M, H]

    Args:
        local_tokens: number of token-expert pairs routed to local experts.
        active_experts: number of local experts with at least one token.
        ms: measured latency in milliseconds.
        fc1_input_bytes: bytes per element of the FC1 input activations.
            Defaults to NVFP4_BYTES; pass BF16_BYTES when hidden_states
            are bfloat16, otherwise the traffic is underestimated.

    Returns:
        Achieved memory bandwidth in TB/s.
    """
    if fc1_input_bytes is None:
        fc1_input_bytes = NVFP4_BYTES
    H = CFG.hidden_size
    I = CFG.intermediate_size

    # FC1 + FC2 weights are streamed once per active expert.
    weight_bytes = active_experts * (H * 2 * I + I * H) * NVFP4_BYTES

    act_bytes = (
        local_tokens * H * fc1_input_bytes  # FC1 input read
        + local_tokens * 2 * I * NVFP4_BYTES  # FC1 output write
        + local_tokens * I * NVFP4_BYTES  # FC2 input read
        + local_tokens * H * BF16_BYTES  # FC2 output write
    )

    total_bytes = weight_bytes + act_bytes
    return total_bytes / (ms * 1e-3) / 1e12


def interleave(x, gs=64):
Expand Down Expand Up @@ -343,15 +372,16 @@ def run(x, x_sf, router_logits, routing_bias, topk_values, topk_indices):
"topk_indices": ti,
}

times = bench_gpu_time(
run,
dry_run_iters=warmup,
repeat_iters=iters,
cold_l2_cache=True,
enable_cupti=use_cupti,
use_cuda_graph=use_cuda_graph,
input_kwargs=input_kwargs,
)
with cuda_profiler_range("bench_cute_dsl"):
times = bench_gpu_time(
run,
dry_run_iters=warmup,
repeat_iters=iters,
cold_l2_cache=True,
enable_cupti=use_cupti,
use_cuda_graph=use_cuda_graph,
input_kwargs=input_kwargs,
)
return np.median(times)


Expand Down Expand Up @@ -450,15 +480,16 @@ def run(hidden, sf, router_logits, routing_bias, topk_values, topk_indices):
"topk_indices": ti,
}

times = bench_gpu_time(
run,
dry_run_iters=warmup,
repeat_iters=iters,
cold_l2_cache=True,
enable_cupti=use_cupti,
use_cuda_graph=use_cuda_graph,
input_kwargs=input_kwargs,
)
with cuda_profiler_range("bench_cutlass"):
times = bench_gpu_time(
run,
dry_run_iters=warmup,
repeat_iters=iters,
cold_l2_cache=True,
enable_cupti=use_cupti,
use_cuda_graph=use_cuda_graph,
input_kwargs=input_kwargs,
)
return np.median(times)


Expand Down Expand Up @@ -576,15 +607,16 @@ def run(routing_logits, routing_bias, hidden_states, hidden_states_scale):
"hidden_states_scale": hsc,
}

times = bench_gpu_time(
run,
dry_run_iters=warmup,
repeat_iters=iters,
cold_l2_cache=True,
enable_cupti=use_cupti,
use_cuda_graph=use_cuda_graph,
input_kwargs=input_kwargs,
)
with cuda_profiler_range("bench_trtllm"):
times = bench_gpu_time(
run,
dry_run_iters=warmup,
repeat_iters=iters,
cold_l2_cache=True,
enable_cupti=use_cupti,
use_cuda_graph=use_cuda_graph,
input_kwargs=input_kwargs,
)
return np.median(times)


Expand Down Expand Up @@ -818,6 +850,7 @@ class BenchResult:
tokens: int
latency_ms: float
tflops: float
bw_tb_s: float


def run_benchmark(
Expand Down Expand Up @@ -936,15 +969,18 @@ def _benchmark_single(
),
}

# Build results
# Build results using actual routed token counts
local_tokens = histogram_record["local_tokens"]
active_experts = histogram_record["active_local_experts"]
results = []
for backend, latency in lat.items():
results.append(
BenchResult(
backend=backend,
tokens=n,
latency_ms=latency,
tflops=calc_tflops(n, latency, num_local),
tflops=calc_tflops(local_tokens, latency),
bw_tb_s=calc_bw(local_tokens, active_experts, latency),
)
)
return results, histogram_record
Expand All @@ -958,9 +994,9 @@ def _print_header(
routing_bias_scale,
):
"""Print benchmark header."""
print("\n" + "=" * 120)
print("\n" + "=" * 142)
print(f"DeepSeek-V3 MoE Benchmark: CuteDSL vs CUTLASS vs TRTLLM (EP={ep_config})")
print("=" * 120)
print("=" * 142)
print(
f"Model: hidden={CFG.hidden_size}, intermediate={CFG.intermediate_size}, "
f"experts={CFG.num_experts}, top_k={CFG.top_k}"
Expand All @@ -975,28 +1011,28 @@ def _print_header(
f"Routing bias scale: {routing_bias_scale} "
f"(larger values tend to create expert imbalance)"
)
print("-" * 120)
print("-" * 142)
print(
f"{'Tokens':>6} | "
f"{'CuteDSL':^15} | "
f"{'CUTLASS':^15} | "
f"{'TRTLLM':^15} | "
f"{'CuteDSL':^22} | "
f"{'CUTLASS':^22} | "
f"{'TRTLLM':^22} | "
f"{'Speedup (CuteDSL/X)':^18} | "
f"{'Winner':^8} | "
f"{'Active':^7} | "
f"{'Stats':^14}"
f"{'Tokens/slot':^14}"
)
print(
f"{'':>6} | "
f"{'ms':>7} {'TFLOPS':>7} | "
f"{'ms':>7} {'TFLOPS':>7} | "
f"{'ms':>7} {'TFLOPS':>7} | "
f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | "
f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | "
f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | "
f"{'CUTLASS':>9} {'TRTLLM':>9} | "
f"{'':^8} | "
f"{'experts':^7} | "
f"{'min/max/median':^14}"
f"{'min/median/max':^14}"
)
print("-" * 120)
print("-" * 142)


def _print_row(results, histogram_record):
Expand All @@ -1015,14 +1051,14 @@ def _print_row(results, histogram_record):
active_experts = f"{histogram_record['active_local_experts']:>3}"
stats = (
f"{histogram_record['min_count']:>3}/"
f"{histogram_record['max_count']:>3}/"
f"{histogram_record['median_count']:>7.2f}"
f"{histogram_record['median_count']:>5.1f}/"
f"{histogram_record['max_count']:>4}"
)
print(
f"{cute.tokens:>6} | "
f"{cute.latency_ms:>7.3f} {cute.tflops:>7.1f} | "
f"{cutlass.latency_ms:>7.3f} {cutlass.tflops:>7.1f} | "
f"{trtllm.latency_ms:>7.3f} {trtllm.tflops:>7.1f} | "
f"{cute.latency_ms:>7.3f} {cute.tflops:>7.1f} {cute.bw_tb_s:>6.1f} | "
f"{cutlass.latency_ms:>7.3f} {cutlass.tflops:>7.1f} {cutlass.bw_tb_s:>6.1f} | "
f"{trtllm.latency_ms:>7.3f} {trtllm.tflops:>7.1f} {trtllm.bw_tb_s:>6.1f} | "
f"{speedup_cutlass:>8.2f}x {speedup_trtllm:>8.2f}x | "
f"{winner:^8} | "
f"{active_experts:>7} | "
Expand All @@ -1032,8 +1068,12 @@ def _print_row(results, histogram_record):

def _print_footer(ep_config, num_local):
"""Print benchmark footer."""
print("-" * 120)
print("-" * 142)
print("Speedup > 1.0 means CuteDSL is faster than that backend")
print(
f"TFLOPS/BW use actual routed token counts. "
f"BW assumes nvfp4 = {NVFP4_BYTES:.4f} B/elem, bf16 = {BF16_BYTES} B/elem."
)


def _collect_expert_histogram(inputs, num_local, local_offset):
Expand All @@ -1060,6 +1100,7 @@ def _collect_expert_histogram(inputs, num_local, local_offset):
)
local_hist = expert_hist[local_offset : local_offset + num_local]
local_hist_f32 = local_hist.to(torch.float32)
local_tokens = int(local_hist.sum().item())
active_local_experts = int((local_hist > 0).sum().item())
if local_hist.numel() > 0:
min_count = int(local_hist.min().item())
Expand All @@ -1071,6 +1112,7 @@ def _collect_expert_histogram(inputs, num_local, local_offset):
median_count = 0.0

return {
"local_tokens": local_tokens,
"active_local_experts": active_local_experts,
"min_count": min_count,
"max_count": max_count,
Expand Down
Loading
Loading