
Commit 491dd3c

upd
1 parent 0a36fec commit 491dd3c

File tree

3 files changed (+440 -235 lines)


benchmarks/bench_sampling.py

Lines changed: 194 additions & 194 deletions
@@ -52,200 +52,200 @@ def init_seed_top_p_sampling(*args, **kwargs):
 
 @torch.inference_mode()
 def main():
-    # print("---")
-    # print("naive sampling")
-    # for vocab_size in [128512]:
-    #     for batch_size in [1, 16, 32, 64, 128, 256, 512]:
-    #         for distrib in [
-    #             normal_distribution(1),
-    #             normal_distribution(5),
-    #             gumbel_distribution(0.1),
-    #             gumbel_distribution(1),
-    #         ]:
-    #             for deterministic in [True, False]:
-    #                 logits = distrib((batch_size, vocab_size), device="cuda")
-    #                 probs = torch.softmax(logits, dim=-1)
-    #                 samples = torch.zeros(
-    #                     batch_size, dtype=torch.int32, device=probs.device
-    #                 )
-    #                 measurements = bench_gpu_time(
-    #                     lambda: init_seed_sampling(probs, deterministic=deterministic),
-    #                     dry_run_time_ms=100,
-    #                     repeat_time_ms=1000,
-    #                 )
-    #                 ms = np.median(measurements)
-
-    #                 io = (
-    #                     probs.numel() * probs.element_size()
-    #                     + samples.numel() * samples.element_size()
-    #                 )
-    #                 bandwidth = io * 1e-6 / ms
-    #                 print(
-    #                     f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
-    #                 )
-
-    # print("---")
-    # print("top-k sampling")
-    # for vocab_size in [128512]:
-    #     for batch_size in [1, 16, 32, 64, 128, 256, 512]:
-    #         for distrib in [
-    #             normal_distribution(1),
-    #             normal_distribution(5),
-    #             gumbel_distribution(0.1),
-    #             gumbel_distribution(1),
-    #         ]:
-    #             for deterministic in [True, False]:
-    #                 for k in [10, 100, 1000, 5000]:
-    #                     logits = distrib((batch_size, vocab_size), device="cuda")
-    #                     probs = torch.softmax(logits, dim=-1)
-    #                     samples = torch.zeros(
-    #                         batch_size, dtype=torch.int32, device=probs.device
-    #                     )
-    #                     measurements = bench_gpu_time(
-    #                         lambda: init_seed_top_k_sampling(
-    #                             probs, k, deterministic=deterministic
-    #                         ),
-    #                         dry_run_time_ms=100,
-    #                         repeat_time_ms=1000,
-    #                     )
-    #                     ms = np.median(measurements)
-
-    #                     io = (
-    #                         probs.numel() * probs.element_size()
-    #                         + samples.numel() * samples.element_size()
-    #                     )
-    #                     bandwidth = io * 1e-6 / ms
-    #                     print(
-    #                         f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
-    #                     )
-
-    # print("---")
-    # print("top-p sampling")
-
-    # for vocab_size in [128512]:
-    #     for batch_size in [1, 16, 32, 64, 128, 256, 512]:
-    #         for distrib in [
-    #             normal_distribution(1),
-    #             normal_distribution(5),
-    #             gumbel_distribution(0.1),
-    #             gumbel_distribution(1),
-    #         ]:
-    #             for deterministic in [True, False]:
-    #                 for p in [0.1, 0.5, 0.9]:
-    #                     logits = distrib((batch_size, vocab_size), device="cuda")
-    #                     probs = torch.softmax(logits, dim=-1)
-    #                     samples = torch.zeros(
-    #                         batch_size, dtype=torch.int32, device=probs.device
-    #                     )
-    #                     measurements = bench_gpu_time(
-    #                         lambda: init_seed_top_p_sampling(
-    #                             probs, p, deterministic=deterministic
-    #                         ),
-    #                         dry_run_time_ms=100,
-    #                         repeat_time_ms=1000,
-    #                     )
-    #                     ms = np.median(measurements)
-
-    #                     io = (
-    #                         probs.numel() * probs.element_size()
-    #                         + samples.numel() * samples.element_size()
-    #                     )
-    #                     bandwidth = io * 1e-6 / ms
-    #                     print(
-    #                         f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
-    #                     )
-
-    # print("---")
-    # print("sampling from softmax(logits)")
-    # for vocab_size in [128512]:
-    #     for batch_size in [1, 16, 32, 64, 128, 256, 512]:
-    #         for distrib in [
-    #             normal_distribution(1),
-    #             normal_distribution(5),
-    #             gumbel_distribution(0.1),
-    #             gumbel_distribution(1),
-    #         ]:
-    #             for deterministic in [True, False]:
-    #                 logits = distrib((batch_size, vocab_size), device="cuda")
-    #                 samples = torch.zeros(
-    #                     batch_size, dtype=torch.int32, device=logits.device
-    #                 )
-    #                 measurements = bench_gpu_time(
-    #                     lambda: init_seed_sampling_from_softmax_logits(
-    #                         logits, samples, deterministic=deterministic
-    #                     ),
-    #                     dry_run_time_ms=100,
-    #                     repeat_time_ms=1000,
-    #                 )
-    #                 ms = np.median(measurements)
-    #                 io = (
-    #                     logits.numel() * logits.element_size()
-    #                     + samples.numel() * samples.element_size()
-    #                 )
-    #                 bandwidth = io * 1e-6 / ms
-    #                 print(
-    #                     f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
-    #                 )
-
-    # print("---")
-    # print("sampling from logits")
-    # for vocab_size in [128512]:
-    #     for batch_size in [1, 16, 32, 64, 128, 256, 512]:
-    #         for distrib in [
-    #             normal_distribution(1),
-    #             normal_distribution(5),
-    #             gumbel_distribution(0.1),
-    #             gumbel_distribution(1),
-    #         ]:
-    #             for deterministic in [True, False]:
-    #                 logits = distrib((batch_size, vocab_size), device="cuda")
-    #                 samples = torch.zeros(
-    #                     batch_size, dtype=torch.int32, device=logits.device
-    #                 )
-    #                 measurements = bench_gpu_time(
-    #                     lambda: init_seed_sampling_from_logits(
-    #                         logits, samples, deterministic=deterministic
-    #                     ),
-    #                     dry_run_time_ms=100,
-    #                     repeat_time_ms=1000,
-    #                 )
-    #                 ms = np.median(measurements)
-
-    #                 io = (
-    #                     logits.numel() * logits.element_size()
-    #                     + samples.numel() * samples.element_size()
-    #                 )
-    #                 bandwidth = io * 1e-6 / ms
-    #                 print(
-    #                     f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
-    #                 )
-
-    # print("---")
-    # print("top-p renorm probs")
-    # for vocab_size in [128512]:
-    #     for batch_size in [1, 16, 32, 64, 128, 256, 512]:
-    #         torch.manual_seed(42)
-    #         for distrib in [
-    #             normal_distribution(1),
-    #             normal_distribution(5),
-    #             gumbel_distribution(0.1),
-    #             gumbel_distribution(1),
-    #         ]:
-    #             for p in [0.1, 0.5, 0.9]:
-    #                 logits = distrib((batch_size, vocab_size), device="cuda")
-    #                 probs = torch.softmax(logits, dim=-1)
-    #                 measurements = bench_gpu_time(
-    #                     lambda: flashinfer.sampling.top_p_renorm_probs(probs, p),
-    #                     dry_run_time_ms=100,
-    #                     repeat_time_ms=1000,
-    #                 )
-    #                 ms = np.median(measurements)
-
-    #                 io = probs.numel() * probs.element_size() * 2
-    #                 bandwidth = io * 1e-6 / ms
-    #                 print(
-    #                     f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
-    #                 )
+    print("---")
+    print("naive sampling")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for deterministic in [True, False]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    probs = torch.softmax(logits, dim=-1)
+                    samples = torch.zeros(
+                        batch_size, dtype=torch.int32, device=probs.device
+                    )
+                    measurements = bench_gpu_time(
+                        lambda: init_seed_sampling(probs, deterministic=deterministic),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+
+                    io = (
+                        probs.numel() * probs.element_size()
+                        + samples.numel() * samples.element_size()
+                    )
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
+
+    print("---")
+    print("top-k sampling")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for deterministic in [True, False]:
+                    for k in [10, 100, 1000, 5000]:
+                        logits = distrib((batch_size, vocab_size), device="cuda")
+                        probs = torch.softmax(logits, dim=-1)
+                        samples = torch.zeros(
+                            batch_size, dtype=torch.int32, device=probs.device
+                        )
+                        measurements = bench_gpu_time(
+                            lambda: init_seed_top_k_sampling(
+                                probs, k, deterministic=deterministic
+                            ),
+                            dry_run_time_ms=100,
+                            repeat_time_ms=1000,
+                        )
+                        ms = np.median(measurements)
+
+                        io = (
+                            probs.numel() * probs.element_size()
+                            + samples.numel() * samples.element_size()
+                        )
+                        bandwidth = io * 1e-6 / ms
+                        print(
+                            f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                        )
+
+    print("---")
+    print("top-p sampling")
+
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for deterministic in [True, False]:
+                    for p in [0.1, 0.5, 0.9]:
+                        logits = distrib((batch_size, vocab_size), device="cuda")
+                        probs = torch.softmax(logits, dim=-1)
+                        samples = torch.zeros(
+                            batch_size, dtype=torch.int32, device=probs.device
+                        )
+                        measurements = bench_gpu_time(
+                            lambda: init_seed_top_p_sampling(
+                                probs, p, deterministic=deterministic
+                            ),
+                            dry_run_time_ms=100,
+                            repeat_time_ms=1000,
+                        )
+                        ms = np.median(measurements)
+
+                        io = (
+                            probs.numel() * probs.element_size()
+                            + samples.numel() * samples.element_size()
+                        )
+                        bandwidth = io * 1e-6 / ms
+                        print(
+                            f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                        )
+
+    print("---")
+    print("sampling from softmax(logits)")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for deterministic in [True, False]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    samples = torch.zeros(
+                        batch_size, dtype=torch.int32, device=logits.device
+                    )
+                    measurements = bench_gpu_time(
+                        lambda: init_seed_sampling_from_softmax_logits(
+                            logits, samples, deterministic=deterministic
+                        ),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+                    io = (
+                        logits.numel() * logits.element_size()
+                        + samples.numel() * samples.element_size()
+                    )
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
+
+    print("---")
+    print("sampling from logits")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for deterministic in [True, False]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    samples = torch.zeros(
+                        batch_size, dtype=torch.int32, device=logits.device
+                    )
+                    measurements = bench_gpu_time(
+                        lambda: init_seed_sampling_from_logits(
+                            logits, samples, deterministic=deterministic
+                        ),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+
+                    io = (
+                        logits.numel() * logits.element_size()
+                        + samples.numel() * samples.element_size()
+                    )
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
+
+    print("---")
+    print("top-p renorm probs")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            torch.manual_seed(42)
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for p in [0.1, 0.5, 0.9]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    probs = torch.softmax(logits, dim=-1)
+                    measurements = bench_gpu_time(
+                        lambda: flashinfer.sampling.top_p_renorm_probs(probs, p),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+
+                    io = probs.numel() * probs.element_size() * 2
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
 
     print("---")
     print("top-k renorm probs")

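The bench_gpu_time helper is driven by time budgets rather than iteration counts: a dry_run_time_ms warm-up, then timed repetitions for roughly repeat_time_ms, returning per-call durations in milliseconds that the script reduces with np.median. As a rough illustration of that contract only (bench_gpu_time_sketch is a hypothetical stand-in, not the repository's implementation), a CUDA-event version might look like:

import torch

def bench_gpu_time_sketch(fn, dry_run_time_ms=100, repeat_time_ms=1000):
    # Hypothetical stand-in for bench_gpu_time: warm up until the dry-run
    # budget is spent, then time fn with CUDA events until the repeat
    # budget is spent, returning a list of per-call durations in ms.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    def timed_call():
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()  # events must complete before elapsed_time
        return start.elapsed_time(end)  # milliseconds

    elapsed = 0.0
    while elapsed < dry_run_time_ms:  # warm-up phase, results discarded
        elapsed += timed_call()

    measurements, elapsed = [], 0.0
    while elapsed < repeat_time_ms:  # measurement phase
        ms = timed_call()
        measurements.append(ms)
        elapsed += ms
    return measurements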
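The distrib factories (normal_distribution, gumbel_distribution) are defined earlier in bench_sampling.py, outside this hunk. Since the loops call distrib((batch_size, vocab_size), device="cuda") and log distrib.__name__, they presumably return named sampler closures; a plausible sketch under that assumption, not the file's actual definitions:

import torch

def normal_distribution(std):
    def distrib(shape, device):
        return torch.randn(shape, device=device) * std
    distrib.__name__ = f"normal_distribution(std={std})"
    return distrib

def gumbel_distribution(beta):
    def distrib(shape, device):
        # Inverse-CDF sampling: if U ~ Uniform(0, 1), then
        # -beta * log(-log(U)) ~ Gumbel(0, beta).
        u = torch.rand(shape, device=device)
        return -beta * torch.log(-torch.log(u))
    distrib.__name__ = f"gumbel_distribution(beta={beta})"
    return distrib

# Matches the call shape used in main():
# logits = distrib((batch_size, vocab_size), device="cuda")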