diff --git a/contrib/torch_utils.py b/contrib/torch_utils.py index e371932c9f..18f136e914 100644 --- a/contrib/torch_utils.py +++ b/contrib/torch_utils.py @@ -492,8 +492,9 @@ def torch_replacement_sa_decode(self, codes, x=None): if issubclass(the_class, faiss.Index): handle_torch_Index(the_class) + # allows torch tensor usage with bfKnn -def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1): +def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1, use_raft=False): if type(xb) is np.ndarray: # Forward to faiss __init__.py base method return faiss.knn_gpu_numpy(res, xq, xb, k, D, I, metric, device) @@ -574,6 +575,7 @@ def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRI args.outIndices = I_ptr args.outIndicesType = I_type args.device = device + args.use_raft = use_raft with using_stream(res): faiss.bfKnn(res, args) diff --git a/faiss/gpu/test/torch_test_contrib_gpu.py b/faiss/gpu/test/torch_test_contrib_gpu.py index 0c949c29f2..f7444337f1 100644 --- a/faiss/gpu/test/torch_test_contrib_gpu.py +++ b/faiss/gpu/test/torch_test_contrib_gpu.py @@ -249,7 +249,7 @@ def test_sa_encode_decode(self): return class TestTorchUtilsKnnGpu(unittest.TestCase): - def test_knn_gpu(self): + def test_knn_gpu(self, use_raft=False): torch.manual_seed(10) d = 32 nb = 1024 @@ -286,7 +286,7 @@ def test_knn_gpu(self): else: xb_c = xb_np - D, I = faiss.knn_gpu(res, xq_c, xb_c, k) + D, I = faiss.knn_gpu(res, xq_c, xb_c, k, use_raft=use_raft) self.assertTrue(torch.equal(torch.from_numpy(I), gt_I)) self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1e-4) @@ -312,7 +312,7 @@ def test_knn_gpu(self): xb_c = to_column_major_torch(xb) assert not xb_c.is_contiguous() - D, I = faiss.knn_gpu(res, xq_c, xb_c, k) + D, I = faiss.knn_gpu(res, xq_c, xb_c, k, use_raft=use_raft) self.assertTrue(torch.equal(I.cpu(), gt_I)) self.assertLess((D.cpu() - gt_D).abs().max(), 1e-4) @@ -320,7 +320,7 @@ def test_knn_gpu(self): # test on subset try: # This internally uses the current pytorch stream - D, I = faiss.knn_gpu(res, xq_c[6:8], xb_c, k) + D, I = faiss.knn_gpu(res, xq_c[6:8], xb_c, k, use_raft=use_raft) except TypeError: if not xq_row_major: # then it is expected @@ -331,7 +331,13 @@ def test_knn_gpu(self): self.assertTrue(torch.equal(I.cpu(), gt_I[6:8])) self.assertLess((D.cpu() - gt_D[6:8]).abs().max(), 1e-4) - def test_knn_gpu_datatypes(self): + @unittest.skipUnless( + "RAFT" in faiss.get_compile_options(), + "only if RAFT is compiled in") + def test_knn_gpu_raft(self): + self.test_knn_gpu(use_raft=True) + + def test_knn_gpu_datatypes(self, use_raft=False): torch.manual_seed(10) d = 10 nb = 1024 @@ -354,7 +360,7 @@ def test_knn_gpu_datatypes(self): D = torch.zeros(nq, k, device=xb_c.device, dtype=torch.float32) I = torch.zeros(nq, k, device=xb_c.device, dtype=torch.int32) - faiss.knn_gpu(res, xq_c, xb_c, k, D, I) + faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_raft=use_raft) self.assertTrue(torch.equal(I.long().cpu(), gt_I)) self.assertLess((D.float().cpu() - gt_D).abs().max(), 1.5e-3) @@ -366,7 +372,7 @@ def test_knn_gpu_datatypes(self): xb_c = xb.half().numpy() xq_c = xq.half().numpy() - faiss.knn_gpu(res, xq_c, xb_c, k, D, I) + faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_raft=use_raft) self.assertTrue(torch.equal(torch.from_numpy(I).long(), gt_I)) self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1.5e-3)