Commit 92c15f7

Adding benchmark. Applying pre-commit
1 parent dc38bed commit 92c15f7

File tree

4 files changed: +653 -271 lines changed
Lines changed: 350 additions & 0 deletions
@@ -0,0 +1,350 @@
#!/usr/bin/env python3
"""
Benchmark script to measure the overhead of API logging at different levels.

This script creates decorated and undecorated versions of a test function
(torch.matmul) and compares their performance to accurately measure logging overhead.

Why torch.matmul instead of bmm_fp8?
- bmm_fp8 is already decorated in the FlashInfer source code
- Using it would cause double-decoration and inaccurate results
- torch.matmul gives us a clean baseline to measure pure decorator overhead

Usage:
    # Set the logging level before running
    export FLASHINFER_APILOG_LEVEL=2
    python bench_logging_overhead.py

    # Or run with different levels
    FLASHINFER_APILOG_LEVEL=0 python bench_logging_overhead.py
    FLASHINFER_APILOG_LEVEL=1 python bench_logging_overhead.py
    FLASHINFER_APILOG_LEVEL=2 python bench_logging_overhead.py
    FLASHINFER_APILOG_LEVEL=3 python bench_logging_overhead.py

    # Or use the helper script to run all levels
    bash benchmark_all_levels.sh
"""

import os
import sys
import time
import torch
import numpy as np
from typing import List, Tuple

# Get logging level BEFORE importing flashinfer
LOGGING_LEVEL = int(os.environ.get("FLASHINFER_APILOG_LEVEL", "0"))
LOG_DEST = os.environ.get("FLASHINFER_APILOG_DEST", "/tmp/flashinfer_benchmark_log.txt")

# Import the decorator
try:
    from flashinfer.api_logging import flashinfer_api_log
except ImportError as e:
    print(f"Error: Could not import flashinfer: {e}")
    print("Make sure flashinfer is installed.")
    exit(1)


# Create two versions of a test function:
# 1. Undecorated (baseline)
# 2. Decorated (with logging)
#
# We use a simple torch.matmul instead of bmm_fp8 because bmm_fp8 is already
# decorated in the source code, which would cause double-decoration.


def test_matmul_undecorated(A, B):
    """Undecorated version - baseline for comparison."""
    return torch.matmul(A, B)


@flashinfer_api_log
def test_matmul_decorated(A, B):
    """Decorated version - with API logging."""
    return torch.matmul(A, B)


class BenchmarkResults:
    """Store and display benchmark results."""

    def __init__(self):
        self.undecorated_times = []
        self.decorated_times = []

    def set_undecorated(self, times: List[float]):
        """Set benchmark results for undecorated function."""
        self.undecorated_times = times

    def set_decorated(self, times: List[float]):
        """Set benchmark results for decorated function."""
        self.decorated_times = times

    def print_summary(self, logging_level: int):
        """Print a summary of benchmark results."""
        print("\n" + "=" * 80)
        print("BENCHMARK RESULTS")
        print("=" * 80)

        undecorated_mean = np.mean(self.undecorated_times)
        undecorated_std = np.std(self.undecorated_times)

        decorated_mean = np.mean(self.decorated_times)
        decorated_std = np.std(self.decorated_times)

        overhead_abs = (decorated_mean - undecorated_mean) * 1000  # ms
        overhead_pct = (
            ((decorated_mean - undecorated_mean) / undecorated_mean * 100)
            if undecorated_mean > 0
            else 0
        )

        print(
            f"\n{'Version':<20} {'Mean (ms)':<12} {'Std (ms)':<12} {'Median (ms)':<12}"
        )
        print("-" * 80)
        print(
            f"{'Undecorated':<20} {undecorated_mean * 1000:<12.4f} {undecorated_std * 1000:<12.4f} {np.median(self.undecorated_times) * 1000:<12.4f}"
        )
        print(
            f"{'Decorated':<20} {decorated_mean * 1000:<12.4f} {decorated_std * 1000:<12.4f} {np.median(self.decorated_times) * 1000:<12.4f}"
        )

        print("\n" + "=" * 80)
        print("OVERHEAD ANALYSIS")
        print("=" * 80)
        print(f"\nLogging Level: {logging_level}")
        print(f"Absolute overhead: {overhead_abs:.4f} ms")
        print(f"Relative overhead: {overhead_pct:.2f}%")

        print("\n" + "=" * 80)
        print("DETAILED STATISTICS")
        print("=" * 80)

        print("\nUndecorated (baseline):")
        print(f" Mean: {undecorated_mean * 1000:.4f} ms")
        print(f" Median: {np.median(self.undecorated_times) * 1000:.4f} ms")
        print(f" Std: {undecorated_std * 1000:.4f} ms")
        print(f" Min: {np.min(self.undecorated_times) * 1000:.4f} ms")
        print(f" Max: {np.max(self.undecorated_times) * 1000:.4f} ms")

        print("\nDecorated (with logging):")
        print(f" Mean: {decorated_mean * 1000:.4f} ms")
        print(f" Median: {np.median(self.decorated_times) * 1000:.4f} ms")
        print(f" Std: {decorated_std * 1000:.4f} ms")
        print(f" Min: {np.min(self.decorated_times) * 1000:.4f} ms")
        print(f" Max: {np.max(self.decorated_times) * 1000:.4f} ms")


def setup_test_inputs(
    batch_size: int = 32,
    m: int = 512,
    n: int = 512,
    k: int = 512,
    device: str = "cuda:0",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Set up test inputs for matmul.

    Parameters
    ----------
    batch_size : int
        Batch size for the matrix multiplication
    m, n, k : int
        Matrix dimensions
    device : str
        Device to use

    Returns
    -------
    A, B : torch.Tensor
        Input tensors for matrix multiplication
    """
    # Create random tensors
    A = torch.randn(batch_size, m, k, dtype=torch.float16, device=device)
    B = torch.randn(batch_size, k, n, dtype=torch.float16, device=device)

    return A, B


def warmup(func, A, B, num_warmup: int = 10):
    """Warm up the GPU and JIT compilation."""
    for _ in range(num_warmup):
        _ = func(A, B)
    torch.cuda.synchronize()


def benchmark_function(
    func, func_name: str, A, B, num_iterations: int = 100
) -> List[float]:
    """
    Benchmark a specific function.

    Parameters
    ----------
    func : callable
        Function to benchmark
    func_name : str
        Name of the function (for display)
    A, B : torch.Tensor
        Input tensors for matrix multiplication
    num_iterations : int
        Number of iterations to run

    Returns
    -------
    List[float]
        List of execution times in seconds
    """
    print(f"\nBenchmarking: {func_name}")
    print(f" Running {num_iterations} iterations...")

    times = []

    for _ in range(num_iterations):
        # Synchronize before timing
        torch.cuda.synchronize()

        # Time the execution
        start = time.perf_counter()
        _ = func(A, B)
        torch.cuda.synchronize()
        end = time.perf_counter()

        elapsed = end - start
        times.append(elapsed)

    print(f" Complete. Mean time: {np.mean(times) * 1000:.4f} ms")

    return times


def main():
    """Main benchmark function."""
    print("=" * 80)
    print("FlashInfer API Logging Overhead Benchmark")
    print("=" * 80)

    # Display logging configuration
    print("\nLogging Configuration:")
    print(f" FLASHINFER_APILOG_LEVEL = {LOGGING_LEVEL}")
    print(f" FLASHINFER_APILOG_DEST = {LOG_DEST}")

    # Get level name
    level_names = {
        0: "No logging (zero-overhead)",
        1: "Function name only",
        2: "Name + inputs/outputs + metadata",
        3: "Name + inputs/outputs + metadata + statistics",
    }
    print(f" Level description: {level_names.get(LOGGING_LEVEL, 'Unknown')}")

    # Check if CUDA is available
    if not torch.cuda.is_available():
        print("\nError: CUDA is not available. This benchmark requires a CUDA device.")
        exit(1)

    device = "cuda:0"
    print(f"\nDevice: {device}")
    print(f"Device Name: {torch.cuda.get_device_name(device)}")

    # Set up test inputs
    print("\nSetting up test inputs...")
    batch_size = 32
    m, n, k = 128, 128, 128
    print(f" Batch size: {batch_size}")
    print(f" Matrix dimensions: [{batch_size}, {m}, {k}] @ [{batch_size}, {k}, {n}]")

    A, B = setup_test_inputs(batch_size, m, n, k, device)

    # Benchmark parameters
    num_iterations = 100
    print("\nBenchmark parameters:")
    print(f" Iterations: {num_iterations}")
    print(" Warmup iterations: 10")

    # Clear log file before starting
    if os.path.exists(LOG_DEST):
        os.remove(LOG_DEST)

    print("\n" + "=" * 80)
    print("WARMUP PHASE")
    print("=" * 80)

    # Warm up undecorated version
    print("\nWarming up undecorated version...")
    warmup(test_matmul_undecorated, A, B, num_warmup=10)
    print(" Complete.")

    # Warm up decorated version
    print("\nWarming up decorated version...")
    warmup(test_matmul_decorated, A, B, num_warmup=10)
    print(" Complete.")

    print("\n" + "=" * 80)
    print("BENCHMARK PHASE")
    print("=" * 80)

    # Store results
    results = BenchmarkResults()

    # Benchmark undecorated version
    undecorated_times = benchmark_function(
        test_matmul_undecorated, "Undecorated (baseline)", A, B, num_iterations
    )
    results.set_undecorated(undecorated_times)

    # Benchmark decorated version
    decorated_times = benchmark_function(
        test_matmul_decorated,
        f"Decorated (logging level {LOGGING_LEVEL})",
        A,
        B,
        num_iterations,
    )
    results.set_decorated(decorated_times)

    # Print summary
    results.print_summary(LOGGING_LEVEL)

    # Check log file size
    if LOGGING_LEVEL > 0 and os.path.exists(LOG_DEST):
        log_size = os.path.getsize(LOG_DEST)
        print("\n" + "=" * 80)
        print("LOG FILE INFO")
        print("=" * 80)
        print(f"Log file: {LOG_DEST}")
        print(f"Log size: {log_size / 1024:.2f} KB ({log_size} bytes)")
        print(f"Iterations logged: {num_iterations}")
        print(f"Bytes per iteration: {log_size / num_iterations:.2f}")

        # Cleanup option
        cleanup_log = os.environ.get("CLEANUP_LOG", "true").lower() == "true"
        if cleanup_log:
            os.remove(LOG_DEST)
            print("\n Log file removed (set CLEANUP_LOG=false to keep it)")
        else:
            print(f"\n Log file preserved at {LOG_DEST}")

    print("\n" + "=" * 80)
    print("RECOMMENDATIONS")
    print("=" * 80)
    print("\nTo benchmark other levels, run:")
    for level in [0, 1, 2, 3]:
        if level != LOGGING_LEVEL:
            print(f" FLASHINFER_APILOG_LEVEL={level} python {sys.argv[0]}")

    print("\n" + "=" * 80)
    print("Benchmark complete!")
    print("=" * 80)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n\nBenchmark interrupted by user.")
    except Exception as e:
        print(f"\n\nError during benchmark: {e}")
        import traceback

        traceback.print_exc()
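
Note: the docstring above points to a helper script, benchmark_all_levels.sh, whose contents are not shown in this diff. As a rough illustration only (not the committed helper), an equivalent driver can be sketched in Python; the file name bench_logging_overhead.py is taken from the usage notes, and everything else here is an assumption:

import os
import subprocess
import sys

# Run the benchmark once per logging level. Each level gets a fresh child
# process whose environment carries FLASHINFER_APILOG_LEVEL.
for level in (0, 1, 2, 3):
    env = dict(os.environ, FLASHINFER_APILOG_LEVEL=str(level))
    print(f"\n=== FLASHINFER_APILOG_LEVEL={level} ===")
    subprocess.run([sys.executable, "bench_logging_overhead.py"], env=env, check=True)

Because each level runs in a fresh process, the script's requirement that the logging level be set before flashinfer is imported is satisfied automatically.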

0 commit comments
