diff --git a/rust/lance-index/src/vector/flat.rs b/rust/lance-index/src/vector/flat.rs index 296a747136f..65a305d9c37 100644 --- a/rust/lance-index/src/vector/flat.rs +++ b/rust/lance-index/src/vector/flat.rs @@ -143,6 +143,5 @@ pub async fn compute_distance( location: location!(), }) }) - .await - .unwrap() + .await? } diff --git a/rust/lance-linalg/src/distance.rs b/rust/lance-linalg/src/distance.rs index 6e79c7d8b03..84c81fe85ed 100644 --- a/rust/lance-linalg/src/distance.rs +++ b/rust/lance-linalg/src/distance.rs @@ -128,12 +128,18 @@ pub fn multivec_distance( } } - let dists = vectors - .iter() - .map(|v| { - v.map(|v| { + let mut dists = Vec::with_capacity(vectors.len()); + for v in vectors.iter() { + match v { + None => dists.push(f32::NAN), + Some(v) => { let multivector = v.as_fixed_size_list(); - match distance_type { + if multivector.len() == 0 { + dists.push(f32::NAN); + continue; + } + + let sim = match distance_type { DistanceType::Hamming => { let query = query.as_primitive::().values(); query @@ -171,12 +177,12 @@ pub fn multivec_distance( ), _ => unreachable!("missed to check query type"), }, - } - }) - .unwrap_or(f32::NAN) - }) - .map(|sim| 1.0 - sim) - .collect(); + }; + + dists.push(1.0 - sim); + } + } + } Ok(dists) } @@ -204,3 +210,36 @@ where }) .sum() } + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + + use arrow_array::types::Float32Type; + use arrow_array::{Float32Array, ListArray}; + use arrow_buffer::OffsetBuffer; + use arrow_schema::Field; + + #[test] + fn test_multivec_distance_empty_row_is_nan() { + let query: Arc = Arc::new(Float32Array::from_iter_values([1.0_f32, 2.0])); + + let dim = 2; + let values = FixedSizeListArray::from_iter_primitive::( + vec![Some(vec![Some(1.0_f32), Some(2.0)])], + dim, + ); + + // Two rows: first is empty list, second has one sub-vector. + let offsets = OffsetBuffer::from_lengths([0_usize, 1]); + let field = Arc::new(Field::new("item", values.data_type().clone(), true)); + let vectors = ListArray::try_new(field, offsets, Arc::new(values), None).unwrap(); + + let dists = multivec_distance(query.as_ref(), &vectors, DistanceType::Dot).unwrap(); + assert_eq!(dists.len(), 2); + assert!(dists[0].is_nan()); + assert_eq!(dists[1], -4.0); + } +} diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index b1ce7d47cd4..8c62541a519 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -12,7 +12,7 @@ use arrow::datatypes::{Float32Type, UInt32Type, UInt64Type}; use arrow_array::{ builder::{ListBuilder, UInt32Builder}, cast::AsArray, - ArrayRef, RecordBatch, StringArray, + ArrayRef, BooleanArray, RecordBatch, StringArray, }; use arrow_array::{Array, Float32Array, UInt32Array, UInt64Array}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; @@ -231,8 +231,17 @@ impl ExecutionPlan for KNNVectorDistanceExec { let key = key.clone(); let column = column.clone(); async move { - compute_distance(key, dt, &column, batch?) + let batch = compute_distance(key, dt, &column, batch?) .await + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + + let distances = batch[DIST_COL].as_primitive::(); + let mask = BooleanArray::from_iter( + distances + .iter() + .map(|v| Some(v.map(|v| !v.is_nan()).unwrap_or(false))), + ); + arrow::compute::filter_record_batch(&batch, &mask) .map_err(|e| DataFusionError::Execution(e.to_string())) } })