diff --git a/rust/lance-index/src/vector/pq/distance.rs b/rust/lance-index/src/vector/pq/distance.rs index a0124012f67..9c6c25bcfc2 100644 --- a/rust/lance-index/src/vector/pq/distance.rs +++ b/rust/lance-index/src/vector/pq/distance.rs @@ -4,13 +4,11 @@ use core::panic; use std::cmp::{max, min}; +use super::{num_centroids, utils::get_sub_vector_centroids}; use lance_core::assume_eq; use lance_linalg::distance::{dot_distance_batch, l2_distance_batch, Dot, L2}; use lance_linalg::simd::u8::u8x16; use lance_linalg::simd::{Shuffle, SIMD}; -use lance_table::utils::LanceIteratorExtension; - -use super::{num_centroids, utils::get_sub_vector_centroids}; // for quantizing the distance table, we need to know the max possible distance, // so we perform a flat search on the first `FLAT_NUM_4BIT_PQ` rows. @@ -43,16 +41,17 @@ pub fn build_distance_table_l2_impl( let dimension = query.len(); let sub_vector_length = dimension / num_sub_vectors; let num_centroids = 2_usize.pow(NUM_BITS); - query - .chunks_exact(sub_vector_length) - .enumerate() - .flat_map(|(i, sub_vec)| { - let subvec_centroids = - get_sub_vector_centroids::(codebook, dimension, num_sub_vectors, i); - l2_distance_batch(sub_vec, subvec_centroids, sub_vector_length) - }) - .exact_size(num_sub_vectors * num_centroids) - .collect() + let mut result = Vec::with_capacity(num_sub_vectors * num_centroids); + for (i, sub_vec) in query.chunks_exact(sub_vector_length).enumerate() { + let subvec_centroids = + get_sub_vector_centroids::(codebook, dimension, num_sub_vectors, i); + result.extend(l2_distance_batch( + sub_vec, + subvec_centroids, + sub_vector_length, + )); + } + result } /// Build a Distance Table from the query to each PQ centroid @@ -79,16 +78,17 @@ pub fn build_distance_table_dot_impl( let dimension = query.len(); let sub_vector_length = dimension / num_sub_vectors; let num_centroids = 2_usize.pow(NUM_BITS); - query - .chunks_exact(sub_vector_length) - .enumerate() - .flat_map(|(i, sub_vec)| { - let subvec_centroids = - get_sub_vector_centroids::(codebook, dimension, num_sub_vectors, i); - dot_distance_batch(sub_vec, subvec_centroids, sub_vector_length) - }) - .exact_size(num_sub_vectors * num_centroids) - .collect() + let mut result = Vec::with_capacity(num_sub_vectors * num_centroids); + for (i, sub_vec) in query.chunks_exact(sub_vector_length).enumerate() { + let subvec_centroids = + get_sub_vector_centroids::(codebook, dimension, num_sub_vectors, i); + result.extend(dot_distance_batch( + sub_vec, + subvec_centroids, + sub_vector_length, + )); + } + result } /// Compute L2 distance from the query to all code.