Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions rust/lance-index/src/vector/pq/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
use core::panic;
use std::cmp::{max, min};

use super::{num_centroids, utils::get_sub_vector_centroids};
use lance_core::assume_eq;
use lance_linalg::distance::{dot_distance_batch, l2_distance_batch, Dot, L2};
use lance_linalg::simd::u8::u8x16;
use lance_linalg::simd::{Shuffle, SIMD};
use lance_table::utils::LanceIteratorExtension;

use super::{num_centroids, utils::get_sub_vector_centroids};

// for quantizing the distance table, we need to know the max possible distance,
// so we perform a flat search on the first `FLAT_NUM_4BIT_PQ` rows.
Expand Down Expand Up @@ -43,16 +41,17 @@ pub fn build_distance_table_l2_impl<const NUM_BITS: u32, T: L2>(
let dimension = query.len();
let sub_vector_length = dimension / num_sub_vectors;
let num_centroids = 2_usize.pow(NUM_BITS);
query
.chunks_exact(sub_vector_length)
.enumerate()
.flat_map(|(i, sub_vec)| {
let subvec_centroids =
get_sub_vector_centroids::<NUM_BITS, _>(codebook, dimension, num_sub_vectors, i);
l2_distance_batch(sub_vec, subvec_centroids, sub_vector_length)
})
.exact_size(num_sub_vectors * num_centroids)
.collect()
let mut result = Vec::with_capacity(num_sub_vectors * num_centroids);
for (i, sub_vec) in query.chunks_exact(sub_vector_length).enumerate() {
let subvec_centroids =
get_sub_vector_centroids::<NUM_BITS, _>(codebook, dimension, num_sub_vectors, i);
result.extend(l2_distance_batch(
sub_vec,
subvec_centroids,
sub_vector_length,
));
}
result
}

/// Build a Distance Table from the query to each PQ centroid
Expand All @@ -79,16 +78,17 @@ pub fn build_distance_table_dot_impl<const NUM_BITS: u32, T: Dot>(
let dimension = query.len();
let sub_vector_length = dimension / num_sub_vectors;
let num_centroids = 2_usize.pow(NUM_BITS);
query
.chunks_exact(sub_vector_length)
.enumerate()
.flat_map(|(i, sub_vec)| {
let subvec_centroids =
get_sub_vector_centroids::<NUM_BITS, _>(codebook, dimension, num_sub_vectors, i);
dot_distance_batch(sub_vec, subvec_centroids, sub_vector_length)
})
.exact_size(num_sub_vectors * num_centroids)
.collect()
let mut result = Vec::with_capacity(num_sub_vectors * num_centroids);
for (i, sub_vec) in query.chunks_exact(sub_vector_length).enumerate() {
let subvec_centroids =
get_sub_vector_centroids::<NUM_BITS, _>(codebook, dimension, num_sub_vectors, i);
result.extend(dot_distance_batch(
sub_vec,
subvec_centroids,
sub_vector_length,
));
}
result
}

/// Compute L2 distance from the query to all code.
Expand Down