diff --git a/contrib/torch_utils.py b/contrib/torch_utils.py index 790c295e48..e371932c9f 100644 --- a/contrib/torch_utils.py +++ b/contrib/torch_utils.py @@ -33,7 +33,7 @@ def swig_ptr_from_UInt8Tensor(x): assert x.is_contiguous() assert x.dtype == torch.uint8 return faiss.cast_integer_to_uint8_ptr( - x.storage().data_ptr() + x.storage_offset()) + x.untyped_storage().data_ptr() + x.storage_offset()) def swig_ptr_from_HalfTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ @@ -41,28 +41,28 @@ def swig_ptr_from_HalfTensor(x): assert x.dtype == torch.float16 # no canonical half type in C/C++ return faiss.cast_integer_to_void_ptr( - x.storage().data_ptr() + x.storage_offset() * 2) + x.untyped_storage().data_ptr() + x.storage_offset() * 2) def swig_ptr_from_FloatTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() assert x.dtype == torch.float32 return faiss.cast_integer_to_float_ptr( - x.storage().data_ptr() + x.storage_offset() * 4) + x.untyped_storage().data_ptr() + x.storage_offset() * 4) def swig_ptr_from_IntTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() assert x.dtype == torch.int32, 'dtype=%s' % x.dtype return faiss.cast_integer_to_int_ptr( - x.storage().data_ptr() + x.storage_offset() * 4) + x.untyped_storage().data_ptr() + x.storage_offset() * 4) def swig_ptr_from_IndicesTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() assert x.dtype == torch.int64, 'dtype=%s' % x.dtype return faiss.cast_integer_to_idx_t_ptr( - x.storage().data_ptr() + x.storage_offset() * 8) + x.untyped_storage().data_ptr() + x.storage_offset() * 8) @contextlib.contextmanager def using_stream(res, pytorch_stream=None):