
Commit adcc5dd

perf: improve sampling/mask/softmax performance (part 1/2) (#2044)
## 📌 Description

This is the first part of the performance-improvement work for the sampling/mask/softmax operators. In this PR we defer the cross-thread reduction until the end of the loop (similar to how FlashAttention-2 handles the softmax denominator), which reduces the number of shuffle and thread-synchronization instructions. In the second part, we will implement the Radix TopK algorithm to improve top-k mask-logits performance when K is small.

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used my preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Tests**
  * Added a benchmarking suite for sampling and softmax operations, with performance-comparison and visualization tools.
* **Chores**
  * Optimized internal kernel execution strategies for improved performance.
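As a rough illustration of the deferred-reduction idea (a NumPy sketch, not the CUDA kernel changed by this commit; `num_lanes` and the strided partitioning stand in for the threads of a warp): each lane keeps only a private running maximum and rescaled partial sum while streaming over its slice, and the cross-lane reduction happens once at the end, mirroring FlashAttention-2's online-softmax denominator handling.

```python
# Sketch of the deferred cross-thread reduction (NumPy stand-in, not the
# commit's CUDA kernel; `num_lanes` models the threads of a warp).
import numpy as np

def deferred_online_softmax(x: np.ndarray, num_lanes: int = 32) -> np.ndarray:
    lane_max = np.full(num_lanes, -np.inf)  # per-lane running maximum
    lane_sum = np.zeros(num_lanes)          # per-lane rescaled partial sum
    # Streaming loop: each lane walks a strided slice of the row and updates
    # only its private state -- no cross-lane shuffle or sync per iteration.
    for lane in range(num_lanes):
        for v in x[lane::num_lanes]:
            m = max(lane_max[lane], v)
            lane_sum[lane] = lane_sum[lane] * np.exp(lane_max[lane] - m) + np.exp(v - m)
            lane_max[lane] = m
    # Deferred step: a single cross-lane reduction after the loop, analogous
    # to merging per-thread softmax denominators at the end.
    row_max = lane_max.max()
    row_sum = (lane_sum * np.exp(lane_max - row_max)).sum()
    return np.exp(x - row_max) / row_sum

x = np.random.randn(1 << 12)
ref = np.exp(x - x.max())
ref /= ref.sum()
assert np.allclose(deferred_online_softmax(x), ref)
```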
1 parent 63cf562 · commit adcc5dd

File tree

3 files changed: +431, -124 lines

benchmarks/bench_sampling.py

Lines changed: 80 additions & 0 deletions
@@ -220,6 +220,86 @@ def main():
                     f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
                 )
 
+    print("---")
+    print("top-p renorm probs")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            torch.manual_seed(42)
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for p in [0.1, 0.5, 0.9]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    probs = torch.softmax(logits, dim=-1)
+                    measurements = bench_gpu_time(
+                        lambda: flashinfer.sampling.top_p_renorm_probs(probs, p),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+
+                    io = probs.numel() * probs.element_size() * 2
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
+
+    print("---")
+    print("top-k renorm probs")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            torch.manual_seed(42)
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for k in [10, 100, 1000, 5000]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    probs = torch.softmax(logits, dim=-1)
+                    measurements = bench_gpu_time(
+                        lambda: flashinfer.sampling.top_k_renorm_probs(probs, k),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+
+                    io = probs.numel() * probs.element_size() * 2
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
+
+    print("---")
+    print("top-k mask logits")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            torch.manual_seed(42)
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for k in [10, 100, 1000, 5000]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    measurements = bench_gpu_time(
+                        lambda: flashinfer.sampling.top_k_mask_logits(logits, k),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+
+                    io = logits.numel() * logits.element_size() * 2
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
+
 
 if __name__ == "__main__":
     main()
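A note on the effective-bandwidth numbers these loops print: `io` counts one read plus one write of the tensor in bytes, and since 1 MB/ms equals 1 GB/s, `io * 1e-6 / ms` is already in GB/s. A quick sanity check with an assumed (not measured) 1 ms timing:

```python
# Illustrative unit check for the bandwidth formula (hypothetical timing,
# not a number from this commit).
vocab_size, batch_size, elem_bytes = 128512, 512, 4  # fp32 probabilities
io = batch_size * vocab_size * elem_bytes * 2        # one read + one write, in bytes
ms = 1.0                                             # assumed kernel time in ms
print(f"{io * 1e-6 / ms:.1f} GB/s")                  # -> 526.4 GB/s
```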

benchmarks/bench_softmax.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
#!/usr/bin/env python3
"""
Benchmark script comparing torch.softmax vs flashinfer.softmax performance.
Creates a heatmap showing speedup across different batch sizes and hidden dimensions.
"""

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple
import flashinfer
from flashinfer.testing.utils import bench_gpu_time


@torch.inference_mode()
def benchmark_torch_softmax(logits: torch.Tensor) -> float:
    """Benchmark torch's native softmax."""
    measurements = bench_gpu_time(
        lambda: torch.softmax(logits, dim=-1),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
    )
    return np.median(measurements)


@torch.inference_mode()
def benchmark_flashinfer_softmax(logits: torch.Tensor) -> float:
    """Benchmark flashinfer's softmax."""
    measurements = bench_gpu_time(
        lambda: flashinfer.sampling.softmax(logits, temperature=None, enable_pdl=False),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
    )
    return np.median(measurements)


def run_benchmark(
    batch_sizes: List[int], hidden_sizes: List[int]
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Run benchmarks for all combinations of batch_size and hidden_size.

    Returns:
        torch_times: 2D array of torch softmax times (ms)
        flashinfer_times: 2D array of flashinfer softmax times (ms)
        speedups: 2D array of speedup ratios (torch_time / flashinfer_time)
    """
    n_batch = len(batch_sizes)
    n_hidden = len(hidden_sizes)

    torch_times = np.zeros((n_batch, n_hidden))
    flashinfer_times = np.zeros((n_batch, n_hidden))
    speedups = np.zeros((n_batch, n_hidden))

    print("Running benchmarks...")
    print("=" * 100)
    print(
        f"{'Batch Size':<12} {'Hidden Size':<12} {'Torch (ms)':<15} "
        f"{'FlashInfer (ms)':<18} {'Speedup':<10} {'Bandwidth (GB/s)':<18}"
    )
    print("=" * 100)

    for i, batch_size in enumerate(batch_sizes):
        for j, hidden_size in enumerate(hidden_sizes):
            # Generate random logits
            torch.manual_seed(42)
            logits = torch.randn(
                batch_size, hidden_size, device="cuda", dtype=torch.float32
            )

            # Benchmark torch softmax
            torch_time_ms = benchmark_torch_softmax(logits)
            torch_times[i, j] = torch_time_ms

            # Benchmark flashinfer softmax
            flashinfer_time_ms = benchmark_flashinfer_softmax(logits)
            flashinfer_times[i, j] = flashinfer_time_ms

            # Calculate speedup
            speedup = torch_time_ms / flashinfer_time_ms
            speedups[i, j] = speedup

            # Calculate effective bandwidth (read + write)
            io_bytes = logits.numel() * logits.element_size() * 2
            bandwidth_gb_s = io_bytes * 1e-6 / flashinfer_time_ms

            print(
                f"{batch_size:<12} {hidden_size:<12} {torch_time_ms:<15.4f} "
                f"{flashinfer_time_ms:<18.4f} {speedup:<10.2f}x {bandwidth_gb_s:<18.2f}"
            )

    print("=" * 100)
    return torch_times, flashinfer_times, speedups


def plot_heatmap(
    speedups: np.ndarray,
    batch_sizes: List[int],
    hidden_sizes: List[int],
    save_path: str = "softmax_speedup_heatmap.png",
):
    """Create and save a heatmap of speedup values."""
    # Create figure
    fig, ax = plt.subplots(figsize=(12, 8))

    # Create heatmap
    sns.heatmap(
        speedups,
        annot=True,
        fmt=".2f",
        cmap="RdYlGn",
        center=1.0,
        cbar_kws={"label": "Speedup (x)"},
        xticklabels=[f"{h // 1000}K" for h in hidden_sizes],
        yticklabels=batch_sizes,
        ax=ax,
        vmin=0.5,  # Adjust color scale
        vmax=max(3.0, speedups.max()),  # Dynamic upper bound
    )

    ax.set_xlabel("Hidden Size", fontsize=12, fontweight="bold")
    ax.set_ylabel("Batch Size", fontsize=12, fontweight="bold")
    ax.set_title(
        "FlashInfer Softmax Speedup vs PyTorch (Higher is Better)",
        fontsize=14,
        fontweight="bold",
        pad=20,
    )

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    print(f"\nHeatmap saved to: {save_path}")

    # Also create a performance comparison plot
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # Plot 2: Speedup trends across batch sizes
    for j, hidden_size in enumerate(hidden_sizes):
        ax2.plot(
            batch_sizes,
            speedups[:, j],
            marker="o",
            label=f"Hidden={hidden_size // 1000}K",
            linewidth=2,
        )

    ax2.set_xlabel("Batch Size", fontsize=12, fontweight="bold")
    ax2.set_ylabel("Speedup (x)", fontsize=12, fontweight="bold")
    ax2.set_title("Speedup vs Batch Size", fontsize=13, fontweight="bold")
    ax2.set_xscale("log", base=2)
    ax2.grid(True, alpha=0.3)
    ax2.legend(fontsize=9)
    ax2.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup")

    # Plot 1: Speedup trends across hidden sizes
    for i, batch_size in enumerate(batch_sizes[::2]):  # Sample every other batch size
        idx = i * 2
        ax1.plot(
            [h // 1000 for h in hidden_sizes],
            speedups[idx, :],
            marker="s",
            label=f"Batch={batch_size}",
            linewidth=2,
        )

    ax1.set_xlabel("Hidden Size (K)", fontsize=12, fontweight="bold")
    ax1.set_ylabel("Speedup (x)", fontsize=12, fontweight="bold")
    ax1.set_title("Speedup vs Hidden Size", fontsize=13, fontweight="bold")
    ax1.grid(True, alpha=0.3)
    ax1.legend(fontsize=9)
    ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5)

    plt.tight_layout()
    comparison_path = save_path.replace(".png", "_trends.png")
    plt.savefig(comparison_path, dpi=300, bbox_inches="tight")
    print(f"Trend plots saved to: {comparison_path}")


def main():
    """Main benchmark execution."""
    # Configuration
    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
    hidden_sizes = [32000, 64000, 128000, 256000]

    print("=" * 100)
    print("FlashInfer vs PyTorch Softmax Benchmark")
    print("=" * 100)
    print(f"Batch sizes: {batch_sizes}")
    print(f"Hidden sizes: {hidden_sizes}")
    print(f"Device: {torch.cuda.get_device_name()}")
    print("=" * 100)
    print()

    # Run benchmarks
    _, _, speedups = run_benchmark(batch_sizes, hidden_sizes)

    # Print summary statistics
    print("\nSummary Statistics:")
    print("=" * 100)
    print(f"Average speedup: {np.mean(speedups):.2f}x")
    print(f"Median speedup: {np.median(speedups):.2f}x")
    print(f"Min speedup: {np.min(speedups):.2f}x")
    print(f"Max speedup: {np.max(speedups):.2f}x")
    print("=" * 100)

    # Generate heatmap
    plot_heatmap(speedups, batch_sizes, hidden_sizes)

    print("\nBenchmark complete!")


if __name__ == "__main__":
    main()
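For reference, a hypothetical way to reuse these helpers for a smaller sweep (assuming the script is importable on `sys.path` and a CUDA device is available; the sizes and output path below are illustrative, not from the commit):

```python
# Hypothetical driver reusing the benchmark helpers for a reduced sweep.
from bench_softmax import plot_heatmap, run_benchmark

batch_sizes, hidden_sizes = [1, 32, 256], [32000, 128000]
_, _, speedups = run_benchmark(batch_sizes, hidden_sizes)
plot_heatmap(speedups, batch_sizes, hidden_sizes, save_path="small_sweep_heatmap.png")
```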
