-
Notifications
You must be signed in to change notification settings - Fork 197
Python API for CAGRA+HNSW #246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 9 commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
09eb3ab
c api and tests
divyegala 8bc035e
remove unneeded comment
divyegala f8327f5
Update ann_hnsw_c.cu
divyegala 6c0df11
Update ann_hnsw_c.cu
divyegala 8860a09
rename test
divyegala 081eba5
passing python tests
divyegala 5ba4fad
documentation
divyegala aa6c057
Merge branch 'branch-24.10' into hnsw-python-api
divyegala 0c3d053
more docs
divyegala b47f92f
merging upstream
divyegala 53bcf5d
Merge branch 'branch-24.10' into hnsw-python-api
cjnolet 0c2d082
passing tests
divyegala 360ac24
Merge branch 'branch-24.10' into hnsw-python-api
cjnolet 8c53e22
address review
divyegala 09de6d3
fix merge conflicts
divyegala ef98a4e
address review
divyegala 97215f2
revert some changes
divyegala 4acd22b
fix failing tests
divyegala 6f86848
Merge branch 'branch-24.10' into hnsw-python-api
cjnolet 006e77c
add some stream syncs in nn_descent
divyegala 4d36c80
Merge branch 'branch-24.10' into hnsw-python-api
divyegala 8d4d1a2
add more syncs, use thrust_policy
divyegala 0409d12
Revert "add some stream syncs in nn_descent"
divyegala 366af06
Revert "add more syncs, use thrust_policy"
divyegala ad40942
1000 rows in test
divyegala 6460d2a
Merge remote-tracking branch 'upstream/branch-24.10' into hnsw-python…
divyegala File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -267,6 +267,15 @@ cuvsError_t cuvsCagraIndexCreate(cuvsCagraIndex_t* index); | |
| */ | ||
| cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index); | ||
|
|
||
| /** | ||
| * @brief Get dimension of the CAGRA index | ||
| * | ||
| * @param[in] index CAGRA index | ||
| * @param[out] dim return dimension of the index | ||
| * @return cuvsError_t | ||
| */ | ||
| cuvsError_t cuvsCagraIndexDim(cuvsCagraIndex_t index, int* dim); | ||
|
|
||
| /** | ||
| * @} | ||
| */ | ||
|
|
@@ -337,7 +346,10 @@ cuvsError_t cuvsCagraBuild(cuvsResources_t res, | |
| * It is also important to note that the CAGRA Index must have been built | ||
| * with the same type of `queries`, such that `index.dtype.code == | ||
| * queries.dl_tensor.dtype.code` Types for input are: | ||
| * 1. `queries`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` | ||
| * 1. `queries`: | ||
| * a. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` | ||
| * b. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8` | ||
| * c. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8` | ||
| * 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 32` | ||
| * 3. `distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` | ||
| * | ||
|
|
@@ -394,7 +406,7 @@ cuvsError_t cuvsCagraSearch(cuvsResources_t res, | |
| * | ||
| * Experimental, both the API and the serialization format are subject to change. | ||
| * | ||
| * @code{.cpp} | ||
| * @code{.c} | ||
| * #include <cuvs/neighbors/cagra.h> | ||
| * | ||
| * // Create cuvsResources_t | ||
|
|
@@ -416,6 +428,32 @@ cuvsError_t cuvsCagraSerialize(cuvsResources_t res, | |
| cuvsCagraIndex_t index, | ||
| bool include_dataset); | ||
|
|
||
| /** | ||
| * Save the CAGRA index to file in hnswlib format. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we add a note here that this must be loaded by the (patched) hnswlib wrappers inside cuVS and it can't just be loaded with hnswlib? |
||
| * | ||
| * Experimental, both the API and the serialization format are subject to change. | ||
| * | ||
| * @code{.c} | ||
| * #include <cuvs/core/c_api.h> | ||
| * #include <cuvs/neighbors/cagra.h> | ||
| * | ||
| * // Create cuvsResources_t | ||
| * cuvsResources_t res; | ||
| * cuvsError_t res_create_status = cuvsResourcesCreate(&res); | ||
| * | ||
| * // create an index with `cuvsCagraBuild` | ||
| * cuvsCagraSerializeToHnswlib(res, "/path/to/index", index); | ||
| * @endcode | ||
| * | ||
| * @param[in] res cuvsResources_t opaque C handle | ||
| * @param[in] filename the file name for saving the index | ||
| * @param[in] index CAGRA index | ||
| * | ||
| */ | ||
| cuvsError_t cuvsCagraSerializeToHnswlib(cuvsResources_t res, | ||
| const char* filename, | ||
| cuvsCagraIndex_t index); | ||
|
|
||
| /** | ||
| * Load index from file. | ||
| * | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,206 @@ | ||
| /* | ||
| * Copyright (c) 2024, NVIDIA CORPORATION. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <cuvs/core/c_api.h> | ||
| #include <cuvs/distance/distance.h> | ||
| #include <dlpack/dlpack.h> | ||
| #include <stdbool.h> | ||
| #include <stdint.h> | ||
|
|
||
| #ifdef __cplusplus | ||
| extern "C" { | ||
| #endif | ||
|
|
||
| /** | ||
| * @defgroup hnsw_c_search_params C API for hnswlib wrapper search params | ||
| * @{ | ||
| */ | ||
|
|
||
| struct cuvsHnswSearchParams { | ||
| int32_t ef; | ||
| int32_t num_threads; | ||
| }; | ||
|
|
||
| typedef struct cuvsHnswSearchParams* cuvsHnswSearchParams_t; | ||
|
|
||
| /** | ||
| * @brief Allocate HNSW search params, and populate with default values | ||
| * | ||
| * @param[in] params cuvsHnswSearchParams_t to allocate | ||
| * @return cuvsError_t | ||
| */ | ||
| cuvsError_t cuvsHnswSearchParamsCreate(cuvsHnswSearchParams_t* params); | ||
|
|
||
| /** | ||
| * @brief De-allocate HNSW search params | ||
| * | ||
| * @param[in] params cuvsHnswSearchParams_t to de-allocate | ||
| * @return cuvsError_t | ||
| */ | ||
| cuvsError_t cuvsHnswSearchParamsDestroy(cuvsHnswSearchParams_t params); | ||
|
|
||
| /** | ||
| * @} | ||
| */ | ||
|
|
||
| /** | ||
| * @defgroup hnsw_c_index C API for hnswlib wrapper index | ||
| * @{ | ||
| */ | ||
|
|
||
| /** | ||
| * @brief Struct to hold address of cuvs::neighbors::Hnsw::index and its active trained dtype | ||
| * | ||
| */ | ||
| typedef struct { | ||
| uintptr_t addr; | ||
| DLDataType dtype; | ||
|
|
||
| } cuvsHnswIndex; | ||
|
|
||
| typedef cuvsHnswIndex* cuvsHnswIndex_t; | ||
|
|
||
| /** | ||
| * @brief Allocate HNSW index | ||
| * | ||
| * @param[in] index cuvsHnswIndex_t to allocate | ||
| * @return cuvsError_t | ||
| */ | ||
| cuvsError_t cuvsHnswIndexCreate(cuvsHnswIndex_t* index); | ||
|
|
||
| /** | ||
| * @brief De-allocate HNSW index | ||
| * | ||
| * @param[in] index cuvsHnswIndex_t to de-allocate | ||
| * @return cuvsError_t | ||
| */ | ||
| cuvsError_t cuvsHnswIndexDestroy(cuvsHnswIndex_t index); | ||
|
|
||
| /** | ||
| * @} | ||
| */ | ||
|
|
||
| /** | ||
| * @defgroup hnsw_c_index_search C API for hnswlib wrapper index search | ||
| * @{ | ||
| */ | ||
| /** | ||
| * @brief Search a HNSW index with a `DLManagedTensor` which has underlying | ||
| * `DLDeviceType` equal to `kDLCPU`, `kDLCUDAHost`, or `kDLCUDAManaged`. | ||
| * It is also important to note that the HNSW Index must have been built | ||
| * with the same type of `queries`, such that `index.dtype.code == | ||
| * queries.dl_tensor.dtype.code` | ||
| * Supported types for input are: | ||
| * 1. `queries`: | ||
| * a. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` | ||
| * b. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8` | ||
| * c. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8` | ||
| * 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 64` | ||
| * 3. `distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` | ||
| * | ||
| * @code{.c} | ||
| * #include <cuvs/core/c_api.h> | ||
| * #include <cuvs/neighbors/hnsw.h> | ||
| * | ||
| * // Create cuvsResources_t | ||
| * cuvsResources_t res; | ||
| * cuvsError_t res_create_status = cuvsResourcesCreate(&res); | ||
| * | ||
| * // Assume a populated `DLManagedTensor` type here | ||
| * DLManagedTensor dataset; | ||
| * DLManagedTensor queries; | ||
| * DLManagedTensor neighbors; | ||
| * DLManagedTensor distances; | ||
| * | ||
| * // Create default search params | ||
| * cuvsHnswSearchParams_t params; | ||
| * cuvsError_t params_create_status = cuvsHnswSearchParamsCreate(&params); | ||
| * | ||
| * // Search the `index` loaded using `cuvsHnswDeserialize` | ||
| * cuvsError_t search_status = cuvsHnswSearch(res, params, index, &queries, &neighbors, | ||
| * &distances); | ||
| * | ||
| * // de-allocate `params` and `res` | ||
| * cuvsError_t params_destroy_status = cuvsHnswSearchParamsDestroy(params); | ||
| * cuvsError_t res_destroy_status = cuvsResourcesDestroy(res); | ||
| * @endcode | ||
| * | ||
| * @param[in] res cuvsResources_t opaque C handle | ||
| * @param[in] params cuvsHnswSearchParams_t used to search Hnsw index | ||
| * @param[in] index cuvsHnswIndex which has been loaded by `cuvsHnswDeserialize` | ||
| * @param[in] queries DLManagedTensor* queries dataset to search | ||
| * @param[out] neighbors DLManagedTensor* output `k` neighbors for queries | ||
| * @param[out] distances DLManagedTensor* output `k` distances for queries | ||
| */ | ||
| cuvsError_t cuvsHnswSearch(cuvsResources_t res, | ||
| cuvsHnswSearchParams_t params, | ||
| cuvsHnswIndex_t index, | ||
| DLManagedTensor* queries, | ||
| DLManagedTensor* neighbors, | ||
| DLManagedTensor* distances); | ||
|
|
||
| /** | ||
| * @} | ||
| */ | ||
|
|
||
| /** | ||
| * @defgroup hnsw_c_serialize HNSW C-API serialize functions | ||
| * @{ | ||
| */ | ||
|
|
||
| /** | ||
| * Load hnswlib index from file which was serialized from a CAGRA index. | ||
| * | ||
| * Experimental, both the API and the serialization format are subject to change. | ||
| * | ||
| * @code{.c} | ||
| * #include <cuvs/core/c_api.h> | ||
| * #include <cuvs/neighbors/cagra.h> | ||
| * #include <cuvs/neighbors/hnsw.h> | ||
| * | ||
| * // Create cuvsResources_t | ||
| * cuvsResources_t res; | ||
| * cuvsError_t res_create_status = cuvsResourcesCreate(&res); | ||
| * | ||
| * // create an index with `cuvsCagraBuild` | ||
| * cuvsCagraSerializeToHnswlib(res, "/path/to/index", index); | ||
| * | ||
| * // Load the serialized CAGRA index from file as an hnswlib index | ||
| * // The index should have the same dtype as the one used to build the CAGRA index | ||
| * cuvsHnswIndex_t hnsw_index; | ||
| * cuvsHnswIndexCreate(&hnsw_index); | ||
| * hnsw_index->dtype = index->dtype; | ||
| * cuvsHnswDeserialize(res, "/path/to/index", dim, metric, hnsw_index); | ||
| * @endcode | ||
| * | ||
| * @param[in] res cuvsResources_t opaque C handle | ||
| * @param[in] filename the name of the file that stores the index | ||
| * @param[in] dim dimensions of the training dataset | ||
| * @param[in] metric distance metric to search. Supported metrics ("L2Expanded") | ||
| * @param[out] index HNSW index loaded from disk | ||
| */ | ||
| cuvsError_t cuvsHnswDeserialize(cuvsResources_t res, | ||
| const char* filename, | ||
| int dim, | ||
| cuvsDistanceType metric, | ||
| cuvsHnswIndex_t index); | ||
| /** | ||
| * @} | ||
| */ | ||
|
|
||
| #ifdef __cplusplus | ||
| } | ||
| #endif |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.