From 1537d18147b535ee16e88de7864bf2743bff696d Mon Sep 17 00:00:00 2001 From: Gergely Szilvasy Date: Wed, 14 Jun 2023 06:33:27 -0700 Subject: [PATCH] bbs producer/consumer threading (#2901) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2901 This diff allows each GPU to work independently, a hot centroid (eg. out-of-distribution queries that hit a centroid heavily) will only block the one GPU that is processing it, others will continue to pick up work independently. Reviewed By: mdouze Differential Revision: D46521298 fbshipit-source-id: e2dda1ace0e1380ae362d56ac193752ee8514107 --- contrib/big_batch_search.py | 251 +++++++++++++++++++++--------------- tests/test_contrib.py | 14 +- 2 files changed, 157 insertions(+), 108 deletions(-) diff --git a/contrib/big_batch_search.py b/contrib/big_batch_search.py index ce769d0f60..6b0fd36e91 100644 --- a/contrib/big_batch_search.py +++ b/contrib/big_batch_search.py @@ -8,6 +8,10 @@ import os from multiprocessing.pool import ThreadPool import threading +import _thread +from queue import Queue +import traceback +import datetime import numpy as np import faiss @@ -60,14 +64,21 @@ def toc(self): def report(self, l): if self.verbose == 1 or ( - l > 1000 and time.time() < self.t_display + 1.0): + self.verbose == 2 and ( + l > 1000 and time.time() < self.t_display + 1.0 + ) + ): return + t = time.time() - self.t0 print( - f"[{time.time()-self.t0:.1f} s] list {l}/{self.index.nlist} " + f"[{t:.1f} s] list {l}/{self.index.nlist} " f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} " f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} " - f"wait {self.t_accu[4]:.3f}", - end="\r", flush=True + f"wait {self.t_accu[4]:.3f} " + f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} " + f"mem {faiss.get_mem_usage_kb()}", + end="\r" if self.verbose <= 2 else "\n", + flush=True, ) self.t_display = time.time() @@ -141,24 +152,25 @@ def add_results_to_heap(self, q_subset, D, list_ids, I): def sizes_in_checkpoint(self): return (self.xq.shape, self.index.nprobe, self.index.nlist) - def write_checkpoint(self, fname, cur_list_no): + def write_checkpoint(self, fname, completed): # write to temp file then move to final file tmpname = fname + ".tmp" - pickle.dump( - { - "sizes": self.sizes_in_checkpoint(), - "cur_list_no": cur_list_no, - "rh": (self.rh.D, self.rh.I), - }, open(tmpname, "wb"), -1 - ) + with open(tmpname, "wb") as f: + pickle.dump( + { + "sizes": self.sizes_in_checkpoint(), + "completed": completed, + "rh": (self.rh.D, self.rh.I), + }, f, -1) os.replace(tmpname, fname) def read_checkpoint(self, fname): - ckp = pickle.load(open(fname, "rb")) + with open(fname, "rb") as f: + ckp = pickle.load(f) assert ckp["sizes"] == self.sizes_in_checkpoint() self.rh.D[:] = ckp["rh"][0] self.rh.I[:] = ckp["rh"][1] - return ckp["cur_list_no"] + return ckp["completed"] class BlockComputer: @@ -225,11 +237,11 @@ def big_batch_search( verbose=0, threaded=0, use_float16=False, - prefetch_threads=8, - computation_threads=0, + prefetch_threads=1, + computation_threads=1, q_assign=None, checkpoint=None, - checkpoint_freq=64, + checkpoint_freq=7200, start_list=0, end_list=None, crash_at=-1 @@ -251,7 +263,7 @@ def big_batch_search( threaded=0: sequential execution threaded=1: prefetch next bucket while computing the current one - threaded>1: prefetch this many buckets at a time. + threaded=2: prefetch prefetch_threads buckets at a time. compute_threads>1: the knn function will get an additional thread_no that tells which worker should handle this. @@ -311,12 +323,13 @@ def big_batch_search( if end_list is None: end_list = index.nlist + completed = set() if checkpoint is not None: assert (start_list, end_list) == (0, index.nlist) if os.path.exists(checkpoint): print("recovering checkpoint", checkpoint) - start_list = bbs.read_checkpoint(checkpoint) - print(" start at list", start_list) + completed = bbs.read_checkpoint(checkpoint) + print(" already completed", len(completed)) else: print("no checkpoint: starting from scratch") @@ -363,94 +376,130 @@ def add_results_and_prefetch(to_add, l): bbs.add_results_to_heap(*to_add) pool.close() else: - # run by batches with parallel prefetch and parallel comp - list_step = threaded - assert start_list % list_step == 0 - if prefetch_threads == 0: - prefetch_map = map - else: - prefetch_pool = ThreadPool(prefetch_threads) - prefetch_map = prefetch_pool.map - - if computation_threads > 0: - comp_pool = ThreadPool(computation_threads) - - def add_results_and_prefetch_batch(to_add, l): - def add_results(to_add): - for ta in to_add: # this one cannot be run in parallel... - if ta is not None: - bbs.add_results_to_heap(*ta) - if prefetch_threads == 0: - add_results(to_add) - else: - add_a = prefetch_pool.apply_async(add_results, (to_add, )) - next_lists = range(l, min(l + list_step, index.nlist)) - res = list(prefetch_map(bbs.prepare_bucket, next_lists)) - if prefetch_threads > 0: - add_a.get() - return res - - # used only when computation_threads > 1 - thread_id_to_seq_lock = threading.Lock() - thread_id_to_seq = {} - - def do_comp(bucket): - (q_subset, xq_l, list_ids, xb_l) = bucket + def task_manager_thread( + task, + pool_size, + start_task, + end_task, + completed, + output_queue, + input_queue, + ): try: - tid = thread_id_to_seq[threading.get_ident()] - except KeyError: - with thread_id_to_seq_lock: - tid = len(thread_id_to_seq) - thread_id_to_seq[threading.get_ident()] = tid - D, I = comp.block_search(xq_l, xb_l, list_ids, k, thread_id=tid) - return q_subset, D, list_ids, I - - prefetched_buckets = add_results_and_prefetch_batch([], start_list) - to_add = [] - pool = ThreadPool(1) - prefetched_buckets_a = None - - # loop over inverted lists - for l in range(start_list, end_list, list_step): - bbs.report(l) - buckets = prefetched_buckets - prefetched_buckets_a = pool.apply_async( - add_results_and_prefetch_batch, (to_add, l + list_step)) - - bbs.start_t_accu() - - to_add = [] - if computation_threads == 0: - for q_subset, xq_l, list_ids, xb_l in buckets: - D, I = comp.block_search(xq_l, xb_l, list_ids, k) - to_add.append((q_subset, D, list_ids, I)) - else: - to_add = list(comp_pool.map(do_comp, buckets)) - - bbs.stop_t_accu(2) + with ThreadPool(pool_size) as pool: + res = [pool.apply_async( + task, + args=(i, output_queue, input_queue)) + for i in range(start_task, end_task) + if i not in completed] + for r in res: + r.get() + pool.close() + pool.join() + output_queue.put(None) + except: + traceback.print_exc() + _thread.interrupt_main() + raise + + def task_manager(*args): + task_manager = threading.Thread( + target=task_manager_thread, + args=args, + ) + task_manager.daemon = True + task_manager.start() + return task_manager + + def prepare_task(task_id, output_queue, input_queue=None): + try: + # print(f"Prepare start: {task_id}") + q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id) + output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l)) + # print(f"Prepare end: {task_id}") + except: + traceback.print_exc() + _thread.interrupt_main() + raise + + def compute_task(task_id, output_queue, input_queue): + try: + # print(f"Compute start: {task_id}") + t_wait = 0 + while True: + t0 = time.time() + input_value = input_queue.get() + t_wait += time.time() - t0 + if input_value is None: + # signal for other compute tasks + input_queue.put(None) + break + centroid, q_subset, xq_l, list_ids, xb_l = input_value + # print(f'Compute work start: task {task_id}, centroid {centroid}') + t0 = time.time() + if computation_threads > 1: + D, I = comp.block_search( + xq_l, xb_l, list_ids, k, thread_id=task_id + ) + else: + D, I = comp.block_search(xq_l, xb_l, list_ids, k) + t_compute = time.time() - t0 + # print(f'Compute work end: task {task_id}, centroid {centroid}') + t0 = time.time() + output_queue.put( + (centroid, t_wait, t_compute, q_subset, D, list_ids, I) + ) + t_wait = time.time() - t0 + # print(f"Compute end: {task_id}") + except: + traceback.print_exc() + _thread.interrupt_main() + raise + + prepare_to_compute_queue = Queue(2) + compute_to_main_queue = Queue(2) + compute_task_manager = task_manager( + compute_task, + computation_threads, + 0, + computation_threads, + set(), + compute_to_main_queue, + prepare_to_compute_queue, + ) + prepare_task_manager = task_manager( + prepare_task, + prefetch_threads, + start_list, + end_list, + completed, + prepare_to_compute_queue, + None, + ) + t_checkpoint = time.time() + while True: + value = compute_to_main_queue.get() + if not value: + break + centroid, t_wait, t_compute, q_subset, D, list_ids, I = value # to test checkpointing - if l == crash_at: + if centroid == crash_at: 1 / 0 - - bbs.start_t_accu() - prefetched_buckets = prefetched_buckets_a.get() - bbs.stop_t_accu(4) - + bbs.t_accu[2] += t_compute + bbs.t_accu[4] += t_wait + bbs.add_results_to_heap(q_subset, D, list_ids, I) + completed.add(centroid) + bbs.report(centroid) if checkpoint is not None: - if (l // list_step) % checkpoint_freq == 0: - print("writing checkpoint %s" % l) - bbs.write_checkpoint(checkpoint, l) + if time.time() - t_checkpoint > checkpoint_freq: + print("writing checkpoint") + bbs.write_checkpoint(checkpoint, completed) + t_checkpoint = time.time() - # flush add - for ta in to_add: - bbs.add_results_to_heap(*ta) - pool.close() - if prefetch_threads != 0: - prefetch_pool.close() - if computation_threads != 0: - comp_pool.close() + prepare_task_manager.join() + compute_task_manager.join() bbs.tic("finalize heap") bbs.rh.finalize() diff --git a/tests/test_contrib.py b/tests/test_contrib.py index cfeee6397b..61a5023ddd 100644 --- a/tests/test_contrib.py +++ b/tests/test_contrib.py @@ -9,6 +9,7 @@ import platform import os import random +import tempfile from faiss.contrib import datasets from faiss.contrib import inspect_tools @@ -507,7 +508,7 @@ def do_test(self, factory_string, metric=faiss.METRIC_L2): Dref, Iref = index.search(ds.get_queries(), k) # faiss.omp_set_num_threads(1) for method in ("pairwise_distances", "knn_function", "index"): - for threaded in 0, 1, 3, 8: + for threaded in 0, 1, 2: Dnew, Inew = big_batch_search.big_batch_search( index, ds.get_queries(), k, method=method, @@ -537,16 +538,15 @@ def test_checkpoint(self): index.nprobe = 5 Dref, Iref = index.search(ds.get_queries(), k) - r = random.randrange(1 << 60) - checkpoint = "/tmp/test_big_batch_checkpoint.%d" % r + checkpoint = tempfile.mktemp() try: # First big batch search try: Dnew, Inew = big_batch_search.big_batch_search( index, ds.get_queries(), k, method="knn_function", - threaded=4, - checkpoint=checkpoint, checkpoint_freq=4, + threaded=2, + checkpoint=checkpoint, checkpoint_freq=0.1, crash_at=20 ) except ZeroDivisionError: @@ -557,8 +557,8 @@ def test_checkpoint(self): Dnew, Inew = big_batch_search.big_batch_search( index, ds.get_queries(), k, method="knn_function", - threaded=4, - checkpoint=checkpoint, checkpoint_freq=4 + threaded=2, + checkpoint=checkpoint, checkpoint_freq=5 ) self.assertLess((Inew != Iref).sum() / Iref.size, 1e-4) np.testing.assert_almost_equal(Dnew, Dref, decimal=4)