@@ -52,200 +52,200 @@ def init_seed_top_p_sampling(*args, **kwargs):
 
 @torch.inference_mode()
 def main():
-    # print("---")
-    # print("naive sampling")
-    # for vocab_size in [128512]:
-    #     for batch_size in [1, 16, 32, 64, 128, 256, 512]:
-    #         for distrib in [
-    #             normal_distribution(1),
-    #             normal_distribution(5),
-    #             gumbel_distribution(0.1),
-    #             gumbel_distribution(1),
-    #         ]:
-    #             for deterministic in [True, False]:
-    #                 logits = distrib((batch_size, vocab_size), device="cuda")
-    #                 probs = torch.softmax(logits, dim=-1)
-    #                 samples = torch.zeros(
-    #                     batch_size, dtype=torch.int32, device=probs.device
-    #                 )
-    #                 measurements = bench_gpu_time(
-    #                     lambda: init_seed_sampling(probs, deterministic=deterministic),
-    #                     dry_run_time_ms=100,
-    #                     repeat_time_ms=1000,
-    #                 )
-    #                 ms = np.median(measurements)
-
-    #                 io = (
-    #                     probs.numel() * probs.element_size()
-    #                     + samples.numel() * samples.element_size()
-    #                 )
-    #                 bandwidth = io * 1e-6 / ms
-    #                 print(
-    #                     f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
-    #                 )
-
-    # print("---")
-    # print("top-k sampling")
-    # for vocab_size in [128512]:
-    #     for batch_size in [1, 16, 32, 64, 128, 256, 512]:
-    #         for distrib in [
-    #             normal_distribution(1),
-    #             normal_distribution(5),
-    #             gumbel_distribution(0.1),
-    #             gumbel_distribution(1),
-    #         ]:
-    #             for deterministic in [True, False]:
-    #                 for k in [10, 100, 1000, 5000]:
-    #                     logits = distrib((batch_size, vocab_size), device="cuda")
-    #                     probs = torch.softmax(logits, dim=-1)
-    #                     samples = torch.zeros(
-    #                         batch_size, dtype=torch.int32, device=probs.device
-    #                     )
-    #                     measurements = bench_gpu_time(
-    #                         lambda: init_seed_top_k_sampling(
-    #                             probs, k, deterministic=deterministic
-    #                         ),
-    #                         dry_run_time_ms=100,
-    #                         repeat_time_ms=1000,
-    #                     )
-    #                     ms = np.median(measurements)
-
-    #                     io = (
-    #                         probs.numel() * probs.element_size()
-    #                         + samples.numel() * samples.element_size()
-    #                     )
-    #                     bandwidth = io * 1e-6 / ms
-    #                     print(
-    #                         f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
-    #                     )
-
-    # print("---")
-    # print("top-p sampling")
-
-    # for vocab_size in [128512]:
-    #     for batch_size in [1, 16, 32, 64, 128, 256, 512]:
-    #         for distrib in [
-    #             normal_distribution(1),
-    #             normal_distribution(5),
-    #             gumbel_distribution(0.1),
-    #             gumbel_distribution(1),
-    #         ]:
-    #             for deterministic in [True, False]:
-    #                 for p in [0.1, 0.5, 0.9]:
-    #                     logits = distrib((batch_size, vocab_size), device="cuda")
-    #                     probs = torch.softmax(logits, dim=-1)
-    #                     samples = torch.zeros(
-    #                         batch_size, dtype=torch.int32, device=probs.device
-    #                     )
-    #                     measurements = bench_gpu_time(
-    #                         lambda: init_seed_top_p_sampling(
-    #                             probs, p, deterministic=deterministic
-    #                         ),
-    #                         dry_run_time_ms=100,
-    #                         repeat_time_ms=1000,
-    #                     )
-    #                     ms = np.median(measurements)
-
-    #                     io = (
-    #                         probs.numel() * probs.element_size()
-    #                         + samples.numel() * samples.element_size()
-    #                     )
-    #                     bandwidth = io * 1e-6 / ms
-    #                     print(
-    #                         f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
-    #                     )
-
-    # print("---")
-    # print("sampling from softmax(logits)")
-    # for vocab_size in [128512]:
-    #     for batch_size in [1, 16, 32, 64, 128, 256, 512]:
-    #         for distrib in [
-    #             normal_distribution(1),
-    #             normal_distribution(5),
-    #             gumbel_distribution(0.1),
-    #             gumbel_distribution(1),
-    #         ]:
-    #             for deterministic in [True, False]:
-    #                 logits = distrib((batch_size, vocab_size), device="cuda")
-    #                 samples = torch.zeros(
-    #                     batch_size, dtype=torch.int32, device=logits.device
-    #                 )
-    #                 measurements = bench_gpu_time(
-    #                     lambda: init_seed_sampling_from_softmax_logits(
-    #                         logits, samples, deterministic=deterministic
-    #                     ),
-    #                     dry_run_time_ms=100,
-    #                     repeat_time_ms=1000,
-    #                 )
-    #                 ms = np.median(measurements)
-    #                 io = (
-    #                     logits.numel() * logits.element_size()
-    #                     + samples.numel() * samples.element_size()
-    #                 )
-    #                 bandwidth = io * 1e-6 / ms
-    #                 print(
-    #                     f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
-    #                 )
-
-    # print("---")
-    # print("sampling from logits")
-    # for vocab_size in [128512]:
-    #     for batch_size in [1, 16, 32, 64, 128, 256, 512]:
-    #         for distrib in [
-    #             normal_distribution(1),
-    #             normal_distribution(5),
-    #             gumbel_distribution(0.1),
-    #             gumbel_distribution(1),
-    #         ]:
-    #             for deterministic in [True, False]:
-    #                 logits = distrib((batch_size, vocab_size), device="cuda")
-    #                 samples = torch.zeros(
-    #                     batch_size, dtype=torch.int32, device=logits.device
-    #                 )
-    #                 measurements = bench_gpu_time(
-    #                     lambda: init_seed_sampling_from_logits(
-    #                         logits, samples, deterministic=deterministic
-    #                     ),
-    #                     dry_run_time_ms=100,
-    #                     repeat_time_ms=1000,
-    #                 )
-    #                 ms = np.median(measurements)
-
-    #                 io = (
-    #                     logits.numel() * logits.element_size()
-    #                     + samples.numel() * samples.element_size()
-    #                 )
-    #                 bandwidth = io * 1e-6 / ms
-    #                 print(
-    #                     f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
-    #                 )
-
-    # print("---")
-    # print("top-p renorm probs")
-    # for vocab_size in [128512]:
-    #     for batch_size in [1, 16, 32, 64, 128, 256, 512]:
-    #         torch.manual_seed(42)
-    #         for distrib in [
-    #             normal_distribution(1),
-    #             normal_distribution(5),
-    #             gumbel_distribution(0.1),
-    #             gumbel_distribution(1),
-    #         ]:
-    #             for p in [0.1, 0.5, 0.9]:
-    #                 logits = distrib((batch_size, vocab_size), device="cuda")
-    #                 probs = torch.softmax(logits, dim=-1)
-    #                 measurements = bench_gpu_time(
-    #                     lambda: flashinfer.sampling.top_p_renorm_probs(probs, p),
-    #                     dry_run_time_ms=100,
-    #                     repeat_time_ms=1000,
-    #                 )
-    #                 ms = np.median(measurements)
-
-    #                 io = probs.numel() * probs.element_size() * 2
-    #                 bandwidth = io * 1e-6 / ms
-    #                 print(
-    #                     f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
-    #                 )
+    print("---")
+    print("naive sampling")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for deterministic in [True, False]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    probs = torch.softmax(logits, dim=-1)
+                    samples = torch.zeros(
+                        batch_size, dtype=torch.int32, device=probs.device
+                    )
+                    measurements = bench_gpu_time(
+                        lambda: init_seed_sampling(probs, deterministic=deterministic),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+
+                    io = (
+                        probs.numel() * probs.element_size()
+                        + samples.numel() * samples.element_size()
+                    )
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
+
+    print("---")
+    print("top-k sampling")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for deterministic in [True, False]:
+                    for k in [10, 100, 1000, 5000]:
+                        logits = distrib((batch_size, vocab_size), device="cuda")
+                        probs = torch.softmax(logits, dim=-1)
+                        samples = torch.zeros(
+                            batch_size, dtype=torch.int32, device=probs.device
+                        )
+                        measurements = bench_gpu_time(
+                            lambda: init_seed_top_k_sampling(
+                                probs, k, deterministic=deterministic
+                            ),
+                            dry_run_time_ms=100,
+                            repeat_time_ms=1000,
+                        )
+                        ms = np.median(measurements)
+
+                        io = (
+                            probs.numel() * probs.element_size()
+                            + samples.numel() * samples.element_size()
+                        )
+                        bandwidth = io * 1e-6 / ms
+                        print(
+                            f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                        )
+
+    print("---")
+    print("top-p sampling")
+
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for deterministic in [True, False]:
+                    for p in [0.1, 0.5, 0.9]:
+                        logits = distrib((batch_size, vocab_size), device="cuda")
+                        probs = torch.softmax(logits, dim=-1)
+                        samples = torch.zeros(
+                            batch_size, dtype=torch.int32, device=probs.device
+                        )
+                        measurements = bench_gpu_time(
+                            lambda: init_seed_top_p_sampling(
+                                probs, p, deterministic=deterministic
+                            ),
+                            dry_run_time_ms=100,
+                            repeat_time_ms=1000,
+                        )
+                        ms = np.median(measurements)
+
+                        io = (
+                            probs.numel() * probs.element_size()
+                            + samples.numel() * samples.element_size()
+                        )
+                        bandwidth = io * 1e-6 / ms
+                        print(
+                            f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                        )
+
+    print("---")
+    print("sampling from softmax(logits)")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for deterministic in [True, False]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    samples = torch.zeros(
+                        batch_size, dtype=torch.int32, device=logits.device
+                    )
+                    measurements = bench_gpu_time(
+                        lambda: init_seed_sampling_from_softmax_logits(
+                            logits, samples, deterministic=deterministic
+                        ),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+                    io = (
+                        logits.numel() * logits.element_size()
+                        + samples.numel() * samples.element_size()
+                    )
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
+
+    print("---")
+    print("sampling from logits")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for deterministic in [True, False]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    samples = torch.zeros(
+                        batch_size, dtype=torch.int32, device=logits.device
+                    )
+                    measurements = bench_gpu_time(
+                        lambda: init_seed_sampling_from_logits(
+                            logits, samples, deterministic=deterministic
+                        ),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+
+                    io = (
+                        logits.numel() * logits.element_size()
+                        + samples.numel() * samples.element_size()
+                    )
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
+
+    print("---")
+    print("top-p renorm probs")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            torch.manual_seed(42)
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for p in [0.1, 0.5, 0.9]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    probs = torch.softmax(logits, dim=-1)
+                    measurements = bench_gpu_time(
+                        lambda: flashinfer.sampling.top_p_renorm_probs(probs, p),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+
+                    io = probs.numel() * probs.element_size() * 2
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
 
     print("---")
     print("top-k renorm probs")
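
A note on the units in the prints above: `io` is a byte count and `ms` is the median kernel time in milliseconds, so `io * 1e-6 / ms` is MB per ms, which is numerically equal to GB/s, and `ms * 1e3` is the duration in microseconds. The renorm benchmark counts `probs` twice because the kernel reads the input distribution and writes a renormalized one. A quick check of the arithmetic with hypothetical numbers (the timing value here is illustrative, not a measurement):

# Worked example of the effective-bandwidth formula used in the loops above.
# batch_size=512, vocab_size=128512, float32 probs read + int32 samples written.
io = 512 * 128512 * 4 + 512 * 4  # bytes moved
ms = 0.25                        # assumed median kernel time in milliseconds
bandwidth = io * 1e-6 / ms       # bytes * 1e-6 = MB; MB / ms = GB/s
print(f"{bandwidth:.2f} GB/s")   # -> 1052.78 GB/s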
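`bench_gpu_time` comes from flashinfer's benchmarking utilities. For readers who want to reproduce the loops without it, below is a minimal stand-in sketch using CUDA events, assuming the semantics suggested by the arguments (warm up for roughly `dry_run_time_ms`, then collect per-launch timings in milliseconds for roughly `repeat_time_ms`); the real helper may behave differently.

import torch

def bench_gpu_time_sketch(fn, dry_run_time_ms=100, repeat_time_ms=1000):
    # Hypothetical stand-in for flashinfer's bench_gpu_time: times fn() with
    # CUDA events and returns a list of per-launch durations in milliseconds.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    def time_once():
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end)  # milliseconds

    # Warmup phase: run until roughly dry_run_time_ms has elapsed.
    elapsed = 0.0
    while elapsed < dry_run_time_ms:
        elapsed += time_once()

    # Measurement phase: collect samples for roughly repeat_time_ms.
    measurements = []
    elapsed = 0.0
    while elapsed < repeat_time_ms:
        t = time_once()
        measurements.append(t)
        elapsed += t
    return measurements

# Example usage mirroring the benchmark loops:
# logits = torch.randn(256, 128512, device="cuda")
# measurements = bench_gpu_time_sketch(lambda: torch.softmax(logits, dim=-1))

Taking the median of the per-launch samples, as the script does, keeps the reported duration robust to occasional outliers from clock ramp-up or other work sharing the device.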