From 2c5e9e6092f06bd8b4a0c618c9714e39e4b315dc Mon Sep 17 00:00:00 2001 From: koide3 <31344317+koide3@users.noreply.github.com> Date: Sun, 12 Jan 2025 13:30:09 +0900 Subject: [PATCH] improve batch_knn_search performance (#101) --- src/python/kdtree.cpp | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/python/kdtree.cpp b/src/python/kdtree.cpp index 3d9d8c0..5003935 100644 --- a/src/python/kdtree.cpp +++ b/src/python/kdtree.cpp @@ -129,7 +129,7 @@ void define_kdtree(py::module& m) { std::vector k_indices(pts.rows(), -1); std::vector k_sq_dists(pts.rows(), std::numeric_limits::max()); -#pragma omp parallel for num_threads(num_threads) +#pragma omp parallel for num_threads(num_threads) schedule(guided, 4) for (int i = 0; i < pts.rows(); ++i) { const size_t found = traits::nearest_neighbor_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), &k_indices[i], &k_sq_dists[i]); if (!found) { @@ -154,9 +154,9 @@ void define_kdtree(py::module& m) { Returns ------- - k_indices : numpy.ndarray, shape (n,) + k_indices : numpy.ndarray, shape (n, k) The indices of the nearest neighbors for each input point. If a neighbor was not found, the index is -1. - k_sq_dists : numpy.ndarray, shape (n,) + k_sq_dists : numpy.ndarray, shape (n, k) The squared distances to the nearest neighbors for each input point. )""") @@ -167,16 +167,21 @@ void define_kdtree(py::module& m) { throw std::invalid_argument("pts must have shape (n, 3) or (n, 4)"); } - std::vector> k_indices(pts.rows(), std::vector(k, -1)); - std::vector> k_sq_dists(pts.rows(), std::vector(k, std::numeric_limits::max())); + Eigen::Matrix k_indices(pts.rows(), k); + Eigen::Matrix k_sq_dists(pts.rows(), k); + k_indices.setConstant(-1); + k_sq_dists.setConstant(std::numeric_limits::max()); -#pragma omp parallel for num_threads(num_threads) +#pragma omp parallel for num_threads(num_threads) schedule(guided, 4) for (int i = 0; i < pts.rows(); ++i) { - const size_t found = traits::knn_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), k, k_indices[i].data(), k_sq_dists[i].data()); + size_t* k_indices_begin = k_indices.data() + i * k; + double* k_sq_dists_begin = k_sq_dists.data() + i * k; + + const size_t found = traits::knn_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), k, k_indices_begin, k_sq_dists_begin); if (found < k) { for (size_t j = found; j < k; ++j) { - k_indices[i][j] = -1; - k_sq_dists[i][j] = std::numeric_limits::max(); + k_indices_begin[j] = -1; + k_sq_dists_begin[j] = std::numeric_limits::max(); } } }