Commit 8c57a0f

Add efficiency scripts (#197)

1 parent 9ef147a commit 8c57a0f

4 files changed: +400 -0 lines changed

efficiency/README.md

Lines changed: 1 addition & 0 deletions

This folder contains the code for batch size checks and basic efficiency tests. Every result was spot-checked individually for each model to ensure proper batch size and runtime estimates. You can review these results on the [efficiency-spotchecks](https://github.com/AnswerDotAI/ModernBERT/tree/efficiency-spotchecks/efficiency) branch.
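
For reference, a minimal sketch of invoking the benchmark from Python rather than the CLI (assumes it is run from this folder on a Linux machine with a CUDA GPU visible; the model name is just an example):

```python
# Minimal usage sketch: benchmark a single model programmatically.
# Works on Linux (fork start method) since the script spawns worker processes.
from multiprocess_bench import run_inference_benchmark

times = run_inference_benchmark("bert-base-uncased", use_xformers=False, n_iters=10)
print(times)  # per-dataset mean/std runtimes plus the batch size used
```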

efficiency/multiprocess_bench.py

Lines changed: 385 additions & 0 deletions

```python
import copy
import warnings
warnings.filterwarnings('ignore')

import transformers
import torch
import random
import numpy as np
import time
from transformers import AutoModel, AutoTokenizer
import srsly
import os
import gc
from multiprocessing import Process, Queue

def create_fixed_short_dataset(tokenizer, num_samples=8192):
    tokens = torch.randint(100, 16000, (num_samples, 512))
    mask = torch.ones(num_samples, 512)
    return {
        'input_ids': tokens.long(),
        'attention_mask': mask.float()
    }

def create_fixed_long_dataset(tokenizer, num_samples=8192):
    tokens = torch.randint(100, 16000, (num_samples, 8192))
    mask = torch.ones(num_samples, 8192)
    return {
        'input_ids': tokens.long(),
        'attention_mask': mask.float()
    }

def create_variable_short_dataset(tokenizer, num_samples=8192):
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    np.random.seed(42)
    random.seed(42)
    # Sample per-sequence lengths from a normal distribution, clamped to [16, 512]
    lengths = torch.normal(mean=256, std=64, size=(num_samples,)).int().clamp(16, 512)
    tokens_list = []
    masks_list = []
    for length in lengths:
        tokens = torch.randint(100, 16000, (length.item(),))
        mask = torch.ones(length.item())
        padded_tokens = torch.full((512,), tokenizer.pad_token_id, dtype=torch.long)
        padded_mask = torch.zeros(512)
        padded_tokens[:length] = tokens
        padded_mask[:length] = mask
        tokens_list.append(padded_tokens)
        masks_list.append(padded_mask)

    return {
        'input_ids': torch.stack(tokens_list),
        'attention_mask': torch.stack(masks_list)
    }

def create_variable_long_dataset(tokenizer, num_samples=8192):
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    np.random.seed(42)
    random.seed(42)
    # Sample per-sequence lengths from a normal distribution, clamped to [16, 8192]
    lengths = torch.normal(mean=4096, std=1024, size=(num_samples,)).int().clamp(16, 8192)
    tokens_list = []
    masks_list = []
    for length in lengths:
        tokens = torch.randint(100, 16000, (length.item(),))
        mask = torch.ones(length.item())
        padded_tokens = torch.full((8192,), tokenizer.pad_token_id, dtype=torch.long)
        padded_mask = torch.zeros(8192)
        padded_tokens[:length] = tokens
        padded_mask[:length] = mask
        tokens_list.append(padded_tokens)
        masks_list.append(padded_mask)

    return {
        'input_ids': torch.stack(tokens_list),
        'attention_mask': torch.stack(masks_list)
    }

def create_all_datasets(tokenizer, num_samples=8192):
    return {
        'fixed_short': create_fixed_short_dataset(tokenizer, num_samples),
        'variable_short': create_variable_short_dataset(tokenizer, num_samples),
        'fixed_long': create_fixed_long_dataset(tokenizer, num_samples),
        'variable_long': create_variable_long_dataset(tokenizer, num_samples)
    }

def test_batch_size_worker(q, model_name, input_ids, attention_mask, bsize, device, use_xformers):
    """
    Worker that:
    1. Loads the model
    2. Tries the given batch size
    3. Reports success or failure
    """
    try:
        if 'gte' in model_name.lower() and use_xformers:
            model = AutoModel.from_pretrained(
                model_name,
                trust_remote_code=True,
                local_files_only=False
            )
            model.config.use_memory_efficient_attention = True
        else:
            model = AutoModel.from_pretrained(
                model_name,
                trust_remote_code=True,
                local_files_only=False
            )
        model = model.to(device)

        with torch.no_grad():
            batch_ids = input_ids[:bsize].to(device)
            batch_mask = attention_mask[:bsize].to(device)
            model(input_ids=batch_ids, attention_mask=batch_mask)
        q.put(('success', True))
    except RuntimeError:
        # CUDA OOM surfaces as a RuntimeError: report the batch size as failed
        q.put(('success', False))
    except Exception as e:
        q.put(('error', str(e)))

def find_max_batch_size_worker(q, model_name, input_ids, attention_mask, device, use_xformers):
    """
    Worker that runs the batch size finding logic.
    Each attempt is run in its own worker to ensure full memory isolation.
    """

    def try_batch_size(bsize):
        print(f"Attempting batch size: {bsize}")
        # Spawn a worker for each attempt
        attempt_q = Queue()
        p = Process(
            target=test_batch_size_worker,
            args=(attempt_q, model_name, input_ids, attention_mask, bsize, device, use_xformers)
        )
        p.start()
        p.join()
        result = attempt_q.get()
        p = None
        if result[0] == 'error':
            # If there's an error unrelated to OOM, raise it
            print(f"Error occurred: {result[1]}")
            raise RuntimeError(result[1])
        success = result[1]
        print(f"Batch size {bsize}: {'succeeded' if success else 'failed'}")

        print("Clearing CUDA cache and garbage collection")
        torch.cuda.empty_cache()
        gc.collect()
        return success

    try:
        print("\nStarting batch size search...")
        batch_size = 1024
        print("\nPhase 1: Increasing batch size until OOM")
        # Increase by 16 until OOM or max 4096
        while try_batch_size(batch_size) and batch_size < 4096:
            batch_size += 16
            print(f"Increasing to {batch_size}")

        print("\nPhase 2: Backing off by 64 until stable")
        # Back off by 64 until stable
        while not try_batch_size(batch_size) and batch_size > 64:
            batch_size -= 64
            print(f"Decreasing to {batch_size}")

        # If still not working, try smaller decrements
        if not try_batch_size(batch_size):
            print("\nPhase 3: Fine-tuning with smaller decrements")
            while not try_batch_size(batch_size) and batch_size > 4:
                batch_size -= 4
                print(f"Fine-tuning decrease to {batch_size}")
            if batch_size <= 4 and not try_batch_size(batch_size):
                print("Attempting minimum batch size of 1")
                batch_size = 1
                if not try_batch_size(batch_size):
                    raise RuntimeError("Cannot find a working batch size.")

        print("\nPhase 4: Final optimization")

        # Grow back up with progressively smaller increments: 32, 16, 8, 4, 2
        for step in (32, 16, 8, 4, 2):
            test_size = batch_size + step
            while test_size < 4096:
                if not try_batch_size(test_size):
                    break
                batch_size = test_size
                test_size += step
                print(f"Testing increment to {test_size}")

        final_batch_size = min(batch_size, 4096)
        if final_batch_size > 8:
            # Leave a small safety margin below the largest working size
            final_batch_size -= 4
        print(f"\nFinal batch size determined: {final_batch_size}")
        q.put(('success', final_batch_size))
    except Exception as e:
        print(f"Error in batch size search: {str(e)}")
        q.put(('error', str(e)))


def inference_worker(q, model_name, dataset_name, input_ids, attention_mask, max_batch_size, n_iters, device, use_xformers):
    """
    Worker to run inference multiple times and report mean/std of times.
    Model loading is done here to isolate memory usage.
    """
    try:
        if 'gte' in model_name.lower() and use_xformers:
            model = AutoModel.from_pretrained(
                model_name,
                trust_remote_code=True,
                local_files_only=False
            )
            model.config.use_memory_efficient_attention = True
        else:
            model = AutoModel.from_pretrained(
                model_name,
                trust_remote_code=True,
                local_files_only=False
            )
        model = model.to(device)
        model.eval()

        times = []
        for _ in range(n_iters):
            start_time = time.time()
            with torch.no_grad():
                for i in range(0, len(input_ids), max_batch_size):
                    batch_ids = input_ids[i:i+max_batch_size].clone().to(device)
                    batch_mask = attention_mask[i:i+max_batch_size].clone().to(device)
                    model(input_ids=batch_ids, attention_mask=batch_mask)
            end_time = time.time()
            times.append(end_time - start_time)

        mean_time = np.mean(times)
        std_time = np.std(times)
        q.put((dataset_name, mean_time, std_time, max_batch_size))
    except Exception as e:
        q.put(('error', str(e)))


def run_inference_benchmark(model_name, use_xformers=False, n_iters=10, gpu=0):
    # The gpu argument is kept for the CLI; device selection is handled via
    # CUDA_VISIBLE_DEVICES (see the run script)
    device = 'cuda'
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, local_files_only=False)
    datasets = create_all_datasets(tokenizer, 4096)

    processing_times = {}
    fixed_batch_sizes = {}

    # Ensure a clean GPU state before starting
    torch.cuda.empty_cache()
    gc.collect()

    for dataset_name, dataset in datasets.items():
        input_ids = dataset['input_ids']
        attention_mask = dataset['attention_mask'].int()

        if dataset_name.startswith('fixed_'):
            # Run batch size finding in its own worker
            q = Queue()
            p = Process(
                target=find_max_batch_size_worker,
                args=(q, model_name, input_ids, attention_mask, device, use_xformers)
            )
            p.start()
            p.join()
            result = q.get()
            p = None
            if result[0] == 'error':
                print(f"Error finding batch size for {dataset_name}: {result[1]}")
                torch.cuda.empty_cache()
                gc.collect()
                continue
            max_batch_size = result[1]
            fixed_batch_sizes[dataset_name] = max_batch_size
        else:
            # Use batch size from corresponding fixed dataset
            fixed_name = 'fixed_' + dataset_name.split('_')[1]
            if fixed_name not in fixed_batch_sizes:
                print(f"No batch size found for {fixed_name}, skipping {dataset_name}")
                torch.cuda.empty_cache()
                gc.collect()
                continue
            max_batch_size = fixed_batch_sizes[fixed_name]

        torch.cuda.empty_cache()
        gc.collect()

        # Run inference in its own worker
        q = Queue()
        p = Process(
            target=inference_worker,
            args=(q, model_name, dataset_name, input_ids, attention_mask, max_batch_size, n_iters, device, use_xformers)
        )
        p.start()
        p.join()
        result = q.get()
        p = None
        if result[0] == 'error':
            print(f"Error during inference for {dataset_name}: {result[1]}")
            torch.cuda.empty_cache()
            gc.collect()
            continue

        dataset_name_ret, mean_time, std_time, bsize = result
        processing_times[dataset_name_ret] = {
            'mean': mean_time,
            'std': std_time,
            'max_batch_size': bsize
        }
        print(f"{dataset_name_ret} -> {mean_time:.2f} ± {std_time:.2f} sec (batch_size: {bsize})")

        torch.cuda.empty_cache()
        gc.collect()

    print("\nProcessing Time Summary:")
    print("-" * 50)
    print(f"\n{model_name} Model:")
    for dataset_name, metrics in processing_times.items():
        print(f"{dataset_name}: {metrics['mean']:.2f} ± {metrics['std']:.2f} seconds (batch_size: {metrics['max_batch_size']})")

    try:
        # Creating the directory also ensures parent folders exist for
        # namespaced model names such as "Alibaba-NLP/gte-base-en-v1.5"
        if use_xformers:
            os.makedirs(f"results/{model_name}_xformers", exist_ok=True)
            srsly.write_json(f"results/{model_name}_xformers_inference_times.json", processing_times)
        else:
            os.makedirs(f"results/{model_name}", exist_ok=True)
            srsly.write_json(f"results/{model_name}_inference_times.json", processing_times)
    except Exception as e:
        print(f"Error saving results: {e}")

    return processing_times


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Run inference benchmark')
    parser.add_argument('--gpu', type=int, default=0, help='GPU number to use')
    parser.add_argument('--model', type=str, default="GTE", help='Model name to benchmark')
    parser.add_argument('--xformers', action='store_true', help='Use XFormers')

    args = parser.parse_args()
    processing_times = run_inference_benchmark(model_name=args.model, use_xformers=args.xformers, gpu=args.gpu)
```
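
The per-attempt process isolation used above follows a standard pattern: every CUDA allocation belongs to a short-lived child process, so an OOM on one attempt cannot leave fragmented or leaked memory behind for later attempts. A stripped-down sketch of that pattern (illustrative names only, not part of the script):

```python
# Sketch of the isolation pattern: run each risky CUDA attempt in a child
# process so all of its GPU memory is reclaimed by the OS when it exits.
from multiprocessing import Process, Queue

def risky_attempt(q, batch_size):
    try:
        # ... allocate tensors / run a forward pass at this batch size ...
        q.put(('success', True))
    except RuntimeError:  # CUDA OOM surfaces as RuntimeError
        q.put(('success', False))

q = Queue()
p = Process(target=risky_attempt, args=(q, 1024))
p.start()
p.join()
status, ok = q.get()  # parent process never touched the GPU
```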
Lines changed: 7 additions & 0 deletions

```bash
export CUDA_VISIBLE_DEVICES=0 && python multiprocess_bench.py --model Alibaba-NLP/gte-base-en-v1.5 > gte_inference_times.log 2>&1 &
export CUDA_VISIBLE_DEVICES=1 && python multiprocess_bench.py --model ModernBERT/bert24-base-v2-learning-rate-decay-v2-50B-3-best-and-last-avg > bert24_inference_times.log 2>&1 &
export CUDA_VISIBLE_DEVICES=2 && python multiprocess_bench.py --model bert-base-uncased > bert_inference_times.log 2>&1 &
export CUDA_VISIBLE_DEVICES=3 && python multiprocess_bench.py --model roberta-base > roberta_inference_times.log 2>&1 &
export CUDA_VISIBLE_DEVICES=4 && python multiprocess_bench.py --model microsoft/deberta-v3-base > debertav3_inference_times.log 2>&1 &
export CUDA_VISIBLE_DEVICES=5 && python multiprocess_bench.py --model nomic-ai/nomic-bert-2048 > nomicbert_inference_times.log 2>&1 &
export CUDA_VISIBLE_DEVICES=6 && python multiprocess_bench.py --model Alibaba-NLP/gte-base-en-v1.5 --xformers > gte_xformers_inference_times.log 2>&1 &
```
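
Each line pins one benchmark run to a single GPU via CUDA_VISIBLE_DEVICES, backgrounds it with `&`, and captures stdout and stderr in a per-model log, so all seven runs proceed in parallel on GPUs 0 through 6.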
