251 changes: 150 additions & 101 deletions contrib/big_batch_search.py
@@ -8,6 +8,10 @@
 import os
 from multiprocessing.pool import ThreadPool
 import threading
+import _thread
+from queue import Queue
+import traceback
+import datetime
 
 import numpy as np
 import faiss
@@ -60,14 +64,21 @@ def toc(self):
 
     def report(self, l):
         if self.verbose == 1 or (
-                l > 1000 and time.time() < self.t_display + 1.0):
+            self.verbose == 2 and (
+                l > 1000 and time.time() < self.t_display + 1.0
+            )
+        ):
             return
+        t = time.time() - self.t0
         print(
-            f"[{time.time()-self.t0:.1f} s] list {l}/{self.index.nlist} "
+            f"[{t:.1f} s] list {l}/{self.index.nlist} "
             f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
             f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
-            f"wait {self.t_accu[4]:.3f}",
-            end="\r", flush=True
+            f"wait {self.t_accu[4]:.3f} "
+            f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} "
+            f"mem {faiss.get_mem_usage_kb()}",
+            end="\r" if self.verbose <= 2 else "\n",
+            flush=True,
         )
         self.t_display = time.time()

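Note: the new `eta` field extrapolates total runtime linearly from the fraction of inverted lists processed so far. A minimal standalone sketch of the same estimate (function and variable names here are illustrative, not part of the patch):

```python
import datetime

def eta(elapsed_s, lists_done, nlist):
    # project total time from the fraction of lists done, subtract elapsed
    remaining_s = elapsed_s * nlist / (lists_done + 1) - elapsed_s
    return datetime.timedelta(seconds=remaining_s)

# after 2 minutes spent on ~a quarter of 1000 lists: roughly 6 minutes left
print(eta(120.0, 250, 1000))  # 0:05:58.08...
```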
@@ -141,24 +152,25 @@ def add_results_to_heap(self, q_subset, D, list_ids, I):
     def sizes_in_checkpoint(self):
         return (self.xq.shape, self.index.nprobe, self.index.nlist)
 
-    def write_checkpoint(self, fname, cur_list_no):
+    def write_checkpoint(self, fname, completed):
         # write to temp file then move to final file
         tmpname = fname + ".tmp"
-        pickle.dump(
-            {
-                "sizes": self.sizes_in_checkpoint(),
-                "cur_list_no": cur_list_no,
-                "rh": (self.rh.D, self.rh.I),
-            }, open(tmpname, "wb"), -1
-        )
+        with open(tmpname, "wb") as f:
+            pickle.dump(
+                {
+                    "sizes": self.sizes_in_checkpoint(),
+                    "completed": completed,
+                    "rh": (self.rh.D, self.rh.I),
+                }, f, -1)
         os.replace(tmpname, fname)
 
     def read_checkpoint(self, fname):
-        ckp = pickle.load(open(fname, "rb"))
+        with open(fname, "rb") as f:
+            ckp = pickle.load(f)
         assert ckp["sizes"] == self.sizes_in_checkpoint()
         self.rh.D[:] = ckp["rh"][0]
         self.rh.I[:] = ckp["rh"][1]
-        return ckp["cur_list_no"]
+        return ckp["completed"]
 
 
 class BlockComputer:
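Note: `write_checkpoint` keeps the write-to-temp-then-rename idiom, now with a `with` block so the file handle is closed before the rename. Since `os.replace` is an atomic rename (on the same filesystem), a crash mid-write leaves the previous checkpoint intact. A minimal sketch of the pattern on its own (`atomic_dump` is an illustrative name, not part of the patch):

```python
import os
import pickle

def atomic_dump(obj, fname):
    # a reader sees either the old complete file or the new complete file,
    # never a partially written one
    tmpname = fname + ".tmp"
    with open(tmpname, "wb") as f:
        pickle.dump(obj, f, -1)  # -1 selects the highest pickle protocol
    os.replace(tmpname, fname)   # atomic rename over the destination
```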
@@ -225,11 +237,11 @@ def big_batch_search(
         verbose=0,
         threaded=0,
         use_float16=False,
-        prefetch_threads=8,
-        computation_threads=0,
+        prefetch_threads=1,
+        computation_threads=1,
         q_assign=None,
         checkpoint=None,
-        checkpoint_freq=64,
+        checkpoint_freq=7200,
         start_list=0,
         end_list=None,
         crash_at=-1
@@ -251,7 +263,7 @@
 
     threaded=0: sequential execution
     threaded=1: prefetch next bucket while computing the current one
-    threaded>1: prefetch this many buckets at a time.
+    threaded=2: prefetch prefetch_threads buckets at a time.
 
     compute_threads>1: the knn function will get an additional thread_no that
     tells which worker should handle this.
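Note: to make the new modes concrete, a pipelined call might look like the sketch below. The parameter values and checkpoint path are illustrative; `index` is assumed to be a trained, populated IVF index and `xq` a query matrix, neither shown here:

```python
from faiss.contrib import big_batch_search

# threaded=2 selects the queue-based pipeline introduced in this PR:
# prefetch_threads workers prepare buckets while computation_threads
# workers run the block k-NN computation
D, I = big_batch_search.big_batch_search(
    index, xq, k=10,
    method="knn_function",
    threaded=2,
    prefetch_threads=4,
    computation_threads=2,
    checkpoint="/tmp/bbs.ckpt",  # optional; enables crash recovery
    checkpoint_freq=3600,        # now a period in seconds, not in lists
)
```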
@@ -311,12 +323,13 @@
     if end_list is None:
         end_list = index.nlist
 
+    completed = set()
     if checkpoint is not None:
         assert (start_list, end_list) == (0, index.nlist)
         if os.path.exists(checkpoint):
             print("recovering checkpoint", checkpoint)
-            start_list = bbs.read_checkpoint(checkpoint)
-            print(" start at list", start_list)
+            completed = bbs.read_checkpoint(checkpoint)
+            print(" already completed", len(completed))
         else:
             print("no checkpoint: starting from scratch")

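Note: the checkpoint now stores the set of completed inverted lists rather than a single resume cursor, because buckets finish out of order once several workers run concurrently. Scheduling then simply skips members of the set (compare `if i not in completed` in the task manager below); a toy illustration:

```python
# buckets complete out of order, so progress is a set of list ids,
# not a single "resume from list N" counter
completed = {0, 1, 2, 5, 8}  # e.g. recovered from a checkpoint
todo = [i for i in range(10) if i not in completed]
print(todo)  # [3, 4, 6, 7, 9]
```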
@@ -363,94 +376,130 @@ def add_results_and_prefetch(to_add, l):
         bbs.add_results_to_heap(*to_add)
         pool.close()
     else:
-        # run by batches with parallel prefetch and parallel comp
-        list_step = threaded
-        assert start_list % list_step == 0
 
-        if prefetch_threads == 0:
-            prefetch_map = map
-        else:
-            prefetch_pool = ThreadPool(prefetch_threads)
-            prefetch_map = prefetch_pool.map
-
-        if computation_threads > 0:
-            comp_pool = ThreadPool(computation_threads)
-
-        def add_results_and_prefetch_batch(to_add, l):
-            def add_results(to_add):
-                for ta in to_add: # this one cannot be run in parallel...
-                    if ta is not None:
-                        bbs.add_results_to_heap(*ta)
-            if prefetch_threads == 0:
-                add_results(to_add)
-            else:
-                add_a = prefetch_pool.apply_async(add_results, (to_add, ))
-            next_lists = range(l, min(l + list_step, index.nlist))
-            res = list(prefetch_map(bbs.prepare_bucket, next_lists))
-            if prefetch_threads > 0:
-                add_a.get()
-            return res
-
-        # used only when computation_threads > 1
-        thread_id_to_seq_lock = threading.Lock()
-        thread_id_to_seq = {}
-
-        def do_comp(bucket):
-            (q_subset, xq_l, list_ids, xb_l) = bucket
+        def task_manager_thread(
+            task,
+            pool_size,
+            start_task,
+            end_task,
+            completed,
+            output_queue,
+            input_queue,
+        ):
             try:
-                tid = thread_id_to_seq[threading.get_ident()]
-            except KeyError:
-                with thread_id_to_seq_lock:
-                    tid = len(thread_id_to_seq)
-                    thread_id_to_seq[threading.get_ident()] = tid
-            D, I = comp.block_search(xq_l, xb_l, list_ids, k, thread_id=tid)
-            return q_subset, D, list_ids, I
-
-        prefetched_buckets = add_results_and_prefetch_batch([], start_list)
-        to_add = []
-        pool = ThreadPool(1)
-        prefetched_buckets_a = None
-
-        # loop over inverted lists
-        for l in range(start_list, end_list, list_step):
-            bbs.report(l)
-            buckets = prefetched_buckets
-            prefetched_buckets_a = pool.apply_async(
-                add_results_and_prefetch_batch, (to_add, l + list_step))
-
-            bbs.start_t_accu()
-
-            to_add = []
-            if computation_threads == 0:
-                for q_subset, xq_l, list_ids, xb_l in buckets:
-                    D, I = comp.block_search(xq_l, xb_l, list_ids, k)
-                    to_add.append((q_subset, D, list_ids, I))
-            else:
-                to_add = list(comp_pool.map(do_comp, buckets))
-
-            bbs.stop_t_accu(2)
+                with ThreadPool(pool_size) as pool:
+                    res = [pool.apply_async(
+                        task,
+                        args=(i, output_queue, input_queue))
+                        for i in range(start_task, end_task)
+                        if i not in completed]
+                    for r in res:
+                        r.get()
+                    pool.close()
+                    pool.join()
+                    output_queue.put(None)
+            except:
+                traceback.print_exc()
+                _thread.interrupt_main()
+                raise
+
+        def task_manager(*args):
+            task_manager = threading.Thread(
+                target=task_manager_thread,
+                args=args,
+            )
+            task_manager.daemon = True
+            task_manager.start()
+            return task_manager
+
+        def prepare_task(task_id, output_queue, input_queue=None):
+            try:
+                # print(f"Prepare start: {task_id}")
+                q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
+                output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
+                # print(f"Prepare end: {task_id}")
+            except:
+                traceback.print_exc()
+                _thread.interrupt_main()
+                raise
+
+        def compute_task(task_id, output_queue, input_queue):
+            try:
+                # print(f"Compute start: {task_id}")
+                t_wait = 0
+                while True:
+                    t0 = time.time()
+                    input_value = input_queue.get()
+                    t_wait += time.time() - t0
+                    if input_value is None:
+                        # signal for other compute tasks
+                        input_queue.put(None)
+                        break
+                    centroid, q_subset, xq_l, list_ids, xb_l = input_value
+                    # print(f'Compute work start: task {task_id}, centroid {centroid}')
+                    t0 = time.time()
+                    if computation_threads > 1:
+                        D, I = comp.block_search(
+                            xq_l, xb_l, list_ids, k, thread_id=task_id
+                        )
+                    else:
+                        D, I = comp.block_search(xq_l, xb_l, list_ids, k)
+                    t_compute = time.time() - t0
+                    # print(f'Compute work end: task {task_id}, centroid {centroid}')
+                    t0 = time.time()
+                    output_queue.put(
+                        (centroid, t_wait, t_compute, q_subset, D, list_ids, I)
+                    )
+                    t_wait = time.time() - t0
+                # print(f"Compute end: {task_id}")
+            except:
+                traceback.print_exc()
+                _thread.interrupt_main()
+                raise
+
+        prepare_to_compute_queue = Queue(2)
+        compute_to_main_queue = Queue(2)
+        compute_task_manager = task_manager(
+            compute_task,
+            computation_threads,
+            0,
+            computation_threads,
+            set(),
+            compute_to_main_queue,
+            prepare_to_compute_queue,
+        )
+        prepare_task_manager = task_manager(
+            prepare_task,
+            prefetch_threads,
+            start_list,
+            end_list,
+            completed,
+            prepare_to_compute_queue,
+            None,
+        )
 
+        t_checkpoint = time.time()
+        while True:
+            value = compute_to_main_queue.get()
+            if not value:
+                break
+            centroid, t_wait, t_compute, q_subset, D, list_ids, I = value
             # to test checkpointing
-            if l == crash_at:
+            if centroid == crash_at:
                 1 / 0
-
-            bbs.start_t_accu()
-            prefetched_buckets = prefetched_buckets_a.get()
-            bbs.stop_t_accu(4)
-
+            bbs.t_accu[2] += t_compute
+            bbs.t_accu[4] += t_wait
+            bbs.add_results_to_heap(q_subset, D, list_ids, I)
+            completed.add(centroid)
+            bbs.report(centroid)
             if checkpoint is not None:
-                if (l // list_step) % checkpoint_freq == 0:
-                    print("writing checkpoint %s" % l)
-                    bbs.write_checkpoint(checkpoint, l)
+                if time.time() - t_checkpoint > checkpoint_freq:
+                    print("writing checkpoint")
+                    bbs.write_checkpoint(checkpoint, completed)
+                    t_checkpoint = time.time()
 
-        # flush add
-        for ta in to_add:
-            bbs.add_results_to_heap(*ta)
-        pool.close()
-        if prefetch_threads != 0:
-            prefetch_pool.close()
-        if computation_threads != 0:
-            comp_pool.close()
+        prepare_task_manager.join()
+        compute_task_manager.join()
 
     bbs.tic("finalize heap")
     bbs.rh.finalize()
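Note: the rewritten `threaded > 1` path is a producer/consumer pipeline: prepare workers feed a bounded queue, compute workers consume it and feed the main thread, and a `None` sentinel is re-queued so that every consumer, and finally the main loop, sees end-of-stream. Each worker also traps exceptions and calls `_thread.interrupt_main()`, so a failure in a daemon thread surfaces in the main thread instead of hanging the search. A stripped-down sketch of the same shutdown protocol (all names illustrative):

```python
import threading
from queue import Queue

def producer(n_items, out_q):
    for i in range(n_items):
        out_q.put(i)      # bounded queue: blocks while consumers lag behind
    out_q.put(None)       # sentinel: no more work

def consumer(in_q, out_q):
    while True:
        item = in_q.get()
        if item is None:
            in_q.put(None)        # re-queue so sibling consumers stop too
            break
        out_q.put(item * item)    # stand-in for the real block computation

work_q, result_q = Queue(2), Queue(2)
threads = [threading.Thread(target=producer, args=(5, work_q))]
threads += [threading.Thread(target=consumer, args=(work_q, result_q))
            for _ in range(2)]
for t in threads:
    t.start()
for _ in range(5):
    print(result_q.get())  # squares of 0..4, order may vary
for t in threads:
    t.join()
```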
14 changes: 7 additions & 7 deletions tests/test_contrib.py
@@ -9,6 +9,7 @@
 import platform
 import os
 import random
+import tempfile
 
 from faiss.contrib import datasets
 from faiss.contrib import inspect_tools
@@ -507,7 +508,7 @@ def do_test(self, factory_string, metric=faiss.METRIC_L2):
         Dref, Iref = index.search(ds.get_queries(), k)
         # faiss.omp_set_num_threads(1)
         for method in ("pairwise_distances", "knn_function", "index"):
-            for threaded in 0, 1, 3, 8:
+            for threaded in 0, 1, 2:
                 Dnew, Inew = big_batch_search.big_batch_search(
                     index, ds.get_queries(),
                     k, method=method,
@@ -537,16 +538,15 @@ def test_checkpoint(self):
         index.nprobe = 5
         Dref, Iref = index.search(ds.get_queries(), k)
 
-        r = random.randrange(1 << 60)
-        checkpoint = "/tmp/test_big_batch_checkpoint.%d" % r
+        checkpoint = tempfile.mktemp()
         try:
             # First big batch search
             try:
                 Dnew, Inew = big_batch_search.big_batch_search(
                     index, ds.get_queries(),
                     k, method="knn_function",
-                    threaded=4,
-                    checkpoint=checkpoint, checkpoint_freq=4,
+                    threaded=2,
+                    checkpoint=checkpoint, checkpoint_freq=0.1,
                     crash_at=20
                 )
             except ZeroDivisionError:
@@ -557,8 +557,8 @@ def test_checkpoint(self):
             Dnew, Inew = big_batch_search.big_batch_search(
                 index, ds.get_queries(),
                 k, method="knn_function",
-                threaded=4,
-                checkpoint=checkpoint, checkpoint_freq=4
+                threaded=2,
+                checkpoint=checkpoint, checkpoint_freq=5
             )
         self.assertLess((Inew != Iref).sum() / Iref.size, 1e-4)
         np.testing.assert_almost_equal(Dnew, Dref, decimal=4)
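Note on the test: `tempfile.mktemp()` avoids the old hard-coded `/tmp` path but is deprecated in the standard library, since the returned name can be claimed by another process before the file is created. A race-free variant for this pattern would be a private temp directory (a sketch, not part of the patch); the path must not pre-exist, because `big_batch_search` treats any existing file at `checkpoint` as a checkpoint to recover:

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    checkpoint = os.path.join(tmpdir, "bbs.ckpt")  # path does not exist yet
    ...  # run both big_batch_search calls with checkpoint=checkpoint
# the directory and the checkpoint inside it are removed automatically,
# replacing the test's try/finally cleanup
```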