Commit 5aa2869

Added quant benchmark

1 parent fd701cf
File tree

1 file changed: +103 −0 lines changed

benchmarks/kernels/benchmark_quant.py (+103)
import random
import time

import torch

from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


@torch.inference_mode()
def main(num_tokens: int,
         hidden_size: int,
         static_scale: bool,
         quant_dtype: torch.dtype,
         dtype: torch.dtype,
         seed: int = 0,
         do_profile: bool = False,
         num_warmup_iters: int = 5,
         num_iters: int = 100) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device("cuda")

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    # A static per-tensor scale; when None, the op computes the scale
    # dynamically from the input.
    scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if quant_dtype == torch.int8:
                ops.scaled_int8_quant(x, scale)
            else:
                ops.scaled_fp8_quant(x, scale)
        # Wait for all queued kernels to finish before reading the clock.
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':

    def to_torch_dtype(dt):
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError(f"Unsupported dtype: {dt}")

    parser = FlexibleArgumentParser(
        description="Benchmark the quantization (fp8 or int8) kernel.")
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--static-scale", action="store_true")
    parser.add_argument("--quant-dtype",
                        type=str,
                        choices=["fp8", "int8"],
                        default="int8")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument("--num-iters",
                        type=int,
                        default=100,
                        help="Number of benchmark iterations. "
                        "If --profile is set, this number is ignored")

    args = parser.parse_args()
    print(args)

    main(num_tokens=args.num_tokens,
         hidden_size=args.hidden_size,
         static_scale=args.static_scale,
         quant_dtype=to_torch_dtype(args.quant_dtype),
         dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
         seed=args.seed,
         do_profile=args.profile,
         num_warmup_iters=args.num_warmup_iters,
         num_iters=args.num_iters)
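
For intuition, both ops perform scaled quantization of the activation tensor; below is a minimal pure-PyTorch sketch of symmetric per-tensor quantization. This is illustrative only (ref_scaled_quant is a hypothetical helper, not part of vLLM), and the actual CUDA kernels may differ in details such as per-token dynamic scales and rounding/saturation behavior.

def ref_scaled_quant(x, scale=None, quant_dtype=torch.int8):
    # Illustrative sketch of per-tensor scaled quantization; NOT the
    # vLLM kernel, whose exact semantics may differ.
    qmax = 127.0 if quant_dtype == torch.int8 else torch.finfo(
        torch.float8_e4m3fn).max
    if scale is None:
        # Dynamic scale derived from the input's absolute maximum.
        scale = x.abs().max().to(torch.float32) / qmax
    xs = x.to(torch.float32) / scale
    if quant_dtype == torch.int8:
        q = torch.round(xs).clamp_(-128, 127).to(torch.int8)
    else:
        q = xs.clamp_(-qmax, qmax).to(quant_dtype)
    return q, scale

Dividing by the scale maps the input's representable range onto the quantized dtype's range; the clamp guards against overflow when a static scale underestimates the input's magnitude.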

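The benchmark can be invoked from the repository root with the flags defined by the parser above, e.g.:

python benchmarks/kernels/benchmark_quant.py --num-tokens 4096 --hidden-size 8192 --quant-dtype fp8 --static-scale

With --profile, a single iteration is bracketed by cudaProfilerStart/Stop; presumably this is intended for capture under an external profiler, e.g. Nsight Systems with nsys profile --capture-range=cudaProfilerApi (an assumed workflow, not stated in the commit).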