Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion contrib/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,9 @@ def torch_replacement_sa_decode(self, codes, x=None):
if issubclass(the_class, faiss.Index):
handle_torch_Index(the_class)


# allows torch tensor usage with bfKnn
def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1):
def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1, use_raft=False):
if type(xb) is np.ndarray:
# Forward to faiss __init__.py base method
return faiss.knn_gpu_numpy(res, xq, xb, k, D, I, metric, device)
Expand Down Expand Up @@ -574,6 +575,7 @@ def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRI
args.outIndices = I_ptr
args.outIndicesType = I_type
args.device = device
args.use_raft = use_raft

with using_stream(res):
faiss.bfKnn(res, args)
Expand Down
20 changes: 13 additions & 7 deletions faiss/gpu/test/torch_test_contrib_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_sa_encode_decode(self):
return

class TestTorchUtilsKnnGpu(unittest.TestCase):
def test_knn_gpu(self):
def test_knn_gpu(self, use_raft=False):
torch.manual_seed(10)
d = 32
nb = 1024
Expand Down Expand Up @@ -286,7 +286,7 @@ def test_knn_gpu(self):
else:
xb_c = xb_np

D, I = faiss.knn_gpu(res, xq_c, xb_c, k)
D, I = faiss.knn_gpu(res, xq_c, xb_c, k, use_raft=use_raft)

self.assertTrue(torch.equal(torch.from_numpy(I), gt_I))
self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1e-4)
Expand All @@ -312,15 +312,15 @@ def test_knn_gpu(self):
xb_c = to_column_major_torch(xb)
assert not xb_c.is_contiguous()

D, I = faiss.knn_gpu(res, xq_c, xb_c, k)
D, I = faiss.knn_gpu(res, xq_c, xb_c, k, use_raft=use_raft)

self.assertTrue(torch.equal(I.cpu(), gt_I))
self.assertLess((D.cpu() - gt_D).abs().max(), 1e-4)

# test on subset
try:
# This internally uses the current pytorch stream
D, I = faiss.knn_gpu(res, xq_c[6:8], xb_c, k)
D, I = faiss.knn_gpu(res, xq_c[6:8], xb_c, k, use_raft=use_raft)
except TypeError:
if not xq_row_major:
# then it is expected
Expand All @@ -331,7 +331,13 @@ def test_knn_gpu(self):
self.assertTrue(torch.equal(I.cpu(), gt_I[6:8]))
self.assertLess((D.cpu() - gt_D[6:8]).abs().max(), 1e-4)

def test_knn_gpu_datatypes(self):
@unittest.skipUnless(
"RAFT" in faiss.get_compile_options(),
"only if RAFT is compiled in")
def test_knn_gpu_raft(self):
self.test_knn_gpu(use_raft=True)

def test_knn_gpu_datatypes(self, use_raft=False):
torch.manual_seed(10)
d = 10
nb = 1024
Expand All @@ -354,7 +360,7 @@ def test_knn_gpu_datatypes(self):
D = torch.zeros(nq, k, device=xb_c.device, dtype=torch.float32)
I = torch.zeros(nq, k, device=xb_c.device, dtype=torch.int32)

faiss.knn_gpu(res, xq_c, xb_c, k, D, I)
faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_raft=use_raft)

self.assertTrue(torch.equal(I.long().cpu(), gt_I))
self.assertLess((D.float().cpu() - gt_D).abs().max(), 1.5e-3)
Expand All @@ -366,7 +372,7 @@ def test_knn_gpu_datatypes(self):
xb_c = xb.half().numpy()
xq_c = xq.half().numpy()

faiss.knn_gpu(res, xq_c, xb_c, k, D, I)
faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_raft=use_raft)

self.assertTrue(torch.equal(torch.from_numpy(I).long(), gt_I))
self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1.5e-3)
Expand Down