-
Notifications
You must be signed in to change notification settings - Fork 899
feat: preparing TRTLLM MoE backend to support more kernels #2794
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,11 +34,24 @@ | |
| """ | ||
|
|
||
| import argparse | ||
| from contextlib import contextmanager | ||
| from dataclasses import dataclass | ||
| import numpy as np | ||
| import torch | ||
|
|
||
|
|
||
| @contextmanager | ||
| def cuda_profiler_range(name): | ||
| """Context manager for CUDA profiler + NVTX range.""" | ||
| torch.cuda.cudart().cudaProfilerStart() | ||
| torch.cuda.nvtx.range_push(name) | ||
| try: | ||
| yield | ||
| finally: | ||
| torch.cuda.nvtx.range_pop() | ||
| torch.cuda.cudart().cudaProfilerStop() | ||
|
|
||
|
|
||
| @dataclass | ||
| class DeepSeekConfig: | ||
| hidden_size: int = 7168 | ||
|
|
@@ -78,28 +91,44 @@ def is_sm100_family(): | |
| return props.major == 10 | ||
|
|
||
|
|
||
| def calc_tflops(n, ms, num_local_experts=None): | ||
| """Calculate TFLOPS for MoE computation. | ||
| NVFP4_BYTES = 9 / 16 # 0.5 bytes value + 1/16 byte block scale | ||
| BF16_BYTES = 2 | ||
|
|
||
|
|
||
| def calc_tflops(local_tokens, ms): | ||
| """Calculate TFLOPS using actual routed token count. | ||
|
|
||
| With EP, only tokens routed to local experts are computed. | ||
| Assumes uniform routing distribution across experts. | ||
| FC1: [M, H] x [H, 2I] | ||
| FC2: [M, I] x [I, H] | ||
| FLOPs = 2 * local_tokens * (H*2I + I*H) = 6 * local_tokens * H * I | ||
| """ | ||
| if num_local_experts is None: | ||
| num_local_experts = CFG.num_experts | ||
| H = CFG.hidden_size | ||
| I = CFG.intermediate_size | ||
| flops = local_tokens * (2 * H * 2 * I + 2 * I * H) | ||
| return flops / (ms * 1e-3) / 1e12 | ||
|
Comment on lines
+98
to
+108
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Guard throughput computations against non-positive latency. If `ms` is zero or negative, the division produces infinity or a misleading value; return NaN instead. Proposed fix def calc_tflops(local_tokens, ms):
@@
- H = CFG.hidden_size
- I = CFG.intermediate_size
- flops = local_tokens * (2 * H * 2 * I + 2 * I * H)
+ if ms <= 0:
+ return float("nan")
+ H = CFG.hidden_size
+ I = CFG.intermediate_size
+ flops = local_tokens * (2 * H * 2 * I + 2 * I * H)
return flops / (ms * 1e-3) / 1e12
def calc_bw(local_tokens, active_experts, ms):
@@
- H = CFG.hidden_size
- I = CFG.intermediate_size
+ if ms <= 0:
+ return float("nan")
+ H = CFG.hidden_size
+ I = CFG.intermediate_sizeAlso applies to: 112-133 π§° Toolsπͺ Ruff (0.15.5)[error] 107-107: Ambiguous variable name: (E741) π€ Prompt for AI Agents |
||
|
|
||
| # Fraction of work done locally (assuming uniform distribution) | ||
| local_fraction = num_local_experts / CFG.num_experts | ||
|
|
||
| flops = ( | ||
| n | ||
| * CFG.top_k | ||
| * local_fraction # Only local expert pairs are computed | ||
| * ( | ||
| 2 * CFG.hidden_size * 2 * CFG.intermediate_size | ||
| + 2 * CFG.intermediate_size * CFG.hidden_size | ||
| ) | ||
| def calc_bw(local_tokens, active_experts, ms): | ||
| """Calculate achieved bandwidth in TB/s for MoE FC1 + FC2. | ||
|
|
||
| Weights are read once per active expert. | ||
| FC1: nvfp4 input [M, H] x nvfp4 weight [H, 2I] -> nvfp4 output [M, 2I] | ||
| FC2: nvfp4 input [M, I] x nvfp4 weight [I, H] -> bf16 output [M, H] | ||
| """ | ||
| H = CFG.hidden_size | ||
| I = CFG.intermediate_size | ||
|
|
||
| weight_bytes = active_experts * (H * 2 * I + I * H) * NVFP4_BYTES | ||
|
|
||
| # Note: each implementation differs in how data is read/written within the module. | ||
| # So here, we only account for the MoE module's read/write bytes. | ||
| act_bytes = ( | ||
| local_tokens * H * NVFP4_BYTES # FC1 input read | ||
| + local_tokens * H * BF16_BYTES # FC2 output write | ||
| ) | ||
| return flops / (ms * 1e-3) / 1e12 | ||
|
|
||
| total_bytes = weight_bytes + act_bytes | ||
| return total_bytes / (ms * 1e-3) / 1e12 | ||
|
|
||
|
|
||
| def interleave(x, gs=64): | ||
|
|
@@ -342,15 +371,16 @@ def run(x, x_sf, router_logits, routing_bias, topk_values, topk_indices): | |
| "topk_indices": ti, | ||
| } | ||
|
|
||
| times = bench_gpu_time( | ||
| run, | ||
| dry_run_iters=warmup, | ||
| repeat_iters=iters, | ||
| cold_l2_cache=True, | ||
| enable_cupti=use_cupti, | ||
| use_cuda_graph=use_cuda_graph, | ||
| input_kwargs=input_kwargs, | ||
| ) | ||
| with cuda_profiler_range("bench_cute_dsl"): | ||
| times = bench_gpu_time( | ||
| run, | ||
| dry_run_iters=warmup, | ||
| repeat_iters=iters, | ||
| cold_l2_cache=True, | ||
| enable_cupti=use_cupti, | ||
| use_cuda_graph=use_cuda_graph, | ||
| input_kwargs=input_kwargs, | ||
| ) | ||
| return np.median(times) | ||
|
|
||
|
|
||
|
|
@@ -449,15 +479,16 @@ def run(hidden, sf, router_logits, routing_bias, topk_values, topk_indices): | |
| "topk_indices": ti, | ||
| } | ||
|
|
||
| times = bench_gpu_time( | ||
| run, | ||
| dry_run_iters=warmup, | ||
| repeat_iters=iters, | ||
| cold_l2_cache=True, | ||
| enable_cupti=use_cupti, | ||
| use_cuda_graph=use_cuda_graph, | ||
| input_kwargs=input_kwargs, | ||
| ) | ||
| with cuda_profiler_range("bench_cutlass"): | ||
| times = bench_gpu_time( | ||
| run, | ||
| dry_run_iters=warmup, | ||
| repeat_iters=iters, | ||
| cold_l2_cache=True, | ||
| enable_cupti=use_cupti, | ||
| use_cuda_graph=use_cuda_graph, | ||
| input_kwargs=input_kwargs, | ||
| ) | ||
| return np.median(times) | ||
|
|
||
|
|
||
|
|
@@ -575,15 +606,16 @@ def run(routing_logits, routing_bias, hidden_states, hidden_states_scale): | |
| "hidden_states_scale": hsc, | ||
| } | ||
|
|
||
| times = bench_gpu_time( | ||
| run, | ||
| dry_run_iters=warmup, | ||
| repeat_iters=iters, | ||
| cold_l2_cache=True, | ||
| enable_cupti=use_cupti, | ||
| use_cuda_graph=use_cuda_graph, | ||
| input_kwargs=input_kwargs, | ||
| ) | ||
| with cuda_profiler_range("bench_trtllm"): | ||
| times = bench_gpu_time( | ||
| run, | ||
| dry_run_iters=warmup, | ||
| repeat_iters=iters, | ||
| cold_l2_cache=True, | ||
| enable_cupti=use_cupti, | ||
| use_cuda_graph=use_cuda_graph, | ||
| input_kwargs=input_kwargs, | ||
| ) | ||
| return np.median(times) | ||
|
|
||
|
|
||
|
|
@@ -817,6 +849,7 @@ class BenchResult: | |
| tokens: int | ||
| latency_ms: float | ||
| tflops: float | ||
| bw_tb_s: float | ||
|
|
||
|
|
||
| def run_benchmark( | ||
|
|
@@ -935,15 +968,18 @@ def _benchmark_single( | |
| ), | ||
| } | ||
|
|
||
| # Build results | ||
| # Build results using actual routed token counts | ||
| local_tokens = histogram_record["local_tokens"] | ||
| active_experts = histogram_record["active_local_experts"] | ||
| results = [] | ||
| for backend, latency in lat.items(): | ||
| results.append( | ||
| BenchResult( | ||
| backend=backend, | ||
| tokens=n, | ||
| latency_ms=latency, | ||
| tflops=calc_tflops(n, latency, num_local), | ||
| tflops=calc_tflops(local_tokens, latency), | ||
| bw_tb_s=calc_bw(local_tokens, active_experts, latency), | ||
| ) | ||
| ) | ||
| return results, histogram_record | ||
|
|
@@ -957,9 +993,9 @@ def _print_header( | |
| routing_bias_scale, | ||
| ): | ||
| """Print benchmark header.""" | ||
| print("\n" + "=" * 120) | ||
| print("\n" + "=" * 142) | ||
| print(f"DeepSeek-V3 MoE Benchmark: CuteDSL vs CUTLASS vs TRTLLM (EP={ep_config})") | ||
| print("=" * 120) | ||
| print("=" * 142) | ||
| print( | ||
| f"Model: hidden={CFG.hidden_size}, intermediate={CFG.intermediate_size}, " | ||
| f"experts={CFG.num_experts}, top_k={CFG.top_k}" | ||
|
|
@@ -974,28 +1010,28 @@ def _print_header( | |
| f"Routing bias scale: {routing_bias_scale} " | ||
| f"(larger values tend to create expert imbalance)" | ||
| ) | ||
| print("-" * 120) | ||
| print("-" * 142) | ||
| print( | ||
| f"{'Tokens':>6} | " | ||
| f"{'CuteDSL':^15} | " | ||
| f"{'CUTLASS':^15} | " | ||
| f"{'TRTLLM':^15} | " | ||
| f"{'CuteDSL':^22} | " | ||
| f"{'CUTLASS':^22} | " | ||
| f"{'TRTLLM':^22} | " | ||
| f"{'Speedup (CuteDSL/X)':^18} | " | ||
| f"{'Winner':^8} | " | ||
| f"{'Active':^7} | " | ||
| f"{'Stats':^14}" | ||
| f"{'Tokens/slot':^14}" | ||
| ) | ||
| print( | ||
| f"{'':>6} | " | ||
| f"{'ms':>7} {'TFLOPS':>7} | " | ||
| f"{'ms':>7} {'TFLOPS':>7} | " | ||
| f"{'ms':>7} {'TFLOPS':>7} | " | ||
| f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | " | ||
| f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | " | ||
| f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | " | ||
| f"{'CUTLASS':>9} {'TRTLLM':>9} | " | ||
| f"{'':^8} | " | ||
| f"{'experts':^7} | " | ||
| f"{'min/max/median':^14}" | ||
| f"{'min/median/max':^14}" | ||
| ) | ||
| print("-" * 120) | ||
| print("-" * 142) | ||
|
|
||
|
|
||
| def _print_row(results, histogram_record): | ||
|
|
@@ -1014,14 +1050,14 @@ def _print_row(results, histogram_record): | |
| active_experts = f"{histogram_record['active_local_experts']:>3}" | ||
| stats = ( | ||
| f"{histogram_record['min_count']:>3}/" | ||
| f"{histogram_record['max_count']:>3}/" | ||
| f"{histogram_record['median_count']:>7.2f}" | ||
| f"{histogram_record['median_count']:>5.1f}/" | ||
| f"{histogram_record['max_count']:>4}" | ||
| ) | ||
| print( | ||
| f"{cute.tokens:>6} | " | ||
| f"{cute.latency_ms:>7.3f} {cute.tflops:>7.1f} | " | ||
| f"{cutlass.latency_ms:>7.3f} {cutlass.tflops:>7.1f} | " | ||
| f"{trtllm.latency_ms:>7.3f} {trtllm.tflops:>7.1f} | " | ||
| f"{cute.latency_ms:>7.3f} {cute.tflops:>7.1f} {cute.bw_tb_s:>6.1f} | " | ||
| f"{cutlass.latency_ms:>7.3f} {cutlass.tflops:>7.1f} {cutlass.bw_tb_s:>6.1f} | " | ||
| f"{trtllm.latency_ms:>7.3f} {trtllm.tflops:>7.1f} {trtllm.bw_tb_s:>6.1f} | " | ||
| f"{speedup_cutlass:>8.2f}x {speedup_trtllm:>8.2f}x | " | ||
| f"{winner:^8} | " | ||
| f"{active_experts:>7} | " | ||
|
|
@@ -1031,8 +1067,12 @@ def _print_row(results, histogram_record): | |
|
|
||
| def _print_footer(ep_config, num_local): | ||
| """Print benchmark footer.""" | ||
| print("-" * 120) | ||
| print("-" * 142) | ||
| print("Speedup > 1.0 means CuteDSL is faster than that backend") | ||
| print( | ||
| f"TFLOPS/BW use actual routed token counts. " | ||
| f"BW assumes nvfp4 = {NVFP4_BYTES:.4f} B/elem, bf16 = {BF16_BYTES} B/elem." | ||
| ) | ||
|
|
||
|
|
||
| def _collect_expert_histogram(inputs, num_local, local_offset): | ||
|
|
@@ -1059,6 +1099,7 @@ def _collect_expert_histogram(inputs, num_local, local_offset): | |
| ) | ||
| local_hist = expert_hist[local_offset : local_offset + num_local] | ||
| local_hist_f32 = local_hist.to(torch.float32) | ||
| local_tokens = int(local_hist.sum().item()) | ||
| active_local_experts = int((local_hist > 0).sum().item()) | ||
| if local_hist.numel() > 0: | ||
| min_count = int(local_hist.min().item()) | ||
|
|
@@ -1070,6 +1111,7 @@ def _collect_expert_histogram(inputs, num_local, local_offset): | |
| median_count = 0.0 | ||
|
|
||
| return { | ||
| "local_tokens": local_tokens, | ||
| "active_local_experts": active_local_experts, | ||
| "min_count": min_count, | ||
| "max_count": max_count, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rename single-letter
Ito avoid Ruff E741 lint errors.Iis flagged as ambiguous (E741). Rename to a descriptive identifier (e.g.,intermediate_size) in both helpers.Proposed fix
Also applies to: 119-123
π§° Tools
πͺ Ruff (0.15.5)
[error] 107-107: Ambiguous variable name:
`I` (E741)
π€ Prompt for AI Agents