diff --git a/rust/lance/benches/take.rs b/rust/lance/benches/take.rs index 68d9c963ef9..ec078d0f636 100644 --- a/rust/lance/benches/take.rs +++ b/rust/lance/benches/take.rs @@ -376,7 +376,7 @@ fn bench_sample(c: &mut Criterion) { let schema = schema.clone(); let dataset = dataset.clone(); async move { - dataset.sample(sample_size, &schema).await.unwrap(); + dataset.sample(sample_size, &schema, None).await.unwrap(); } }) }, diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 02f4b28e047..404bf6849d1 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -54,7 +54,7 @@ use roaring::RoaringBitmap; use rowids::get_row_id_index; use serde::{Deserialize, Serialize}; use std::borrow::Cow; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::fmt::Debug; use std::ops::Range; use std::pin::Pin; @@ -1466,7 +1466,8 @@ impl Dataset { row_indices: &[u64], column: impl AsRef<str>, ) -> Result<Vec<BlobFile>> { - let row_addrs = row_offsets_to_row_addresses(self, row_indices).await?; + let fragments = self.get_fragments(); + let row_addrs = row_offsets_to_row_addresses(&fragments, row_indices).await?; blob::take_blobs_by_addresses(self, &row_addrs, column.as_ref()).await } @@ -1484,14 +1485,74 @@ impl Dataset { /// Randomly sample `n` rows from the dataset. /// + /// If `fragment_ids` is provided, sampling is limited to rows from those + /// fragments in the current dataset version. + /// /// The returned rows are in row-id order (not random order), which allows /// the underlying take operation to use an efficient sorted code path. 
- pub async fn sample(&self, n: usize, projection: &Schema) -> Result<RecordBatch> { + pub async fn sample( + &self, + n: usize, + projection: &Schema, + fragment_ids: Option<&[u32]>, + ) -> Result<RecordBatch> { use rand::seq::IteratorRandom; - let num_rows = self.count_rows(None).await?; - let mut ids = (0..num_rows as u64).choose_multiple(&mut rand::rng(), n); - ids.sort_unstable(); - self.take(&ids, projection.clone()).await + + match fragment_ids { + None => { + let num_rows = self.count_rows(None).await?; + let mut ids = (0..num_rows as u64).choose_multiple(&mut rand::rng(), n); + ids.sort_unstable(); + self.take(&ids, projection.clone()).await + } + Some(fragment_ids) => { + if fragment_ids.is_empty() { + return Err(Error::invalid_input( + "Dataset::sample does not accept an empty fragment_ids list".to_string(), + )); + } + + let selected_fragment_ids = fragment_ids.iter().copied().collect::<BTreeSet<_>>(); + let selected_fragments = self + .get_fragments() + .into_iter() + .filter(|fragment| selected_fragment_ids.contains(&(fragment.id() as u32))) + .collect::<Vec<_>>(); + + if selected_fragments.len() != selected_fragment_ids.len() { + let present_fragment_ids = selected_fragments + .iter() + .map(|fragment| fragment.id() as u32) + .collect::<HashSet<_>>(); + let missing_fragment_ids = selected_fragment_ids + .into_iter() + .filter(|fragment_id| !present_fragment_ids.contains(fragment_id)) + .collect::<Vec<_>>(); + return Err(Error::invalid_input(format!( + "Dataset::sample received fragment ids that are not part of the current dataset version: {missing_fragment_ids:?}", + ))); + } + + let num_rows = stream::iter(selected_fragments.iter().cloned()) + .map(|fragment| async move { fragment.count_rows(None).await }) + .buffer_unordered(16) + .try_fold(0_u64, |acc, rows| async move { Ok(acc + rows as u64) }) + .await?; + + let mut offsets = (0..num_rows).choose_multiple(&mut rand::rng(), n); + offsets.sort_unstable(); + + let row_addrs = row_offsets_to_row_addresses(&selected_fragments, &offsets).await?; + let dataset = 
Arc::new(self.clone()); + let projection = Arc::new( + ProjectionRequest::from(projection.clone()) + .into_projection_plan(dataset.clone())?, + ); + TakeBuilder::try_new_from_addresses(dataset, row_addrs, projection)? + .execute() + .await + } + } } /// Delete rows based on a predicate. diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 63bd7884879..8c7590c452f 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -2718,7 +2718,7 @@ impl Scanner { TakeOperation::RowAddrs(addrs) => self.u64s_as_take_input(addrs), TakeOperation::RowOffsets(offsets) => { let mut addrs = - row_offsets_to_row_addresses(self.dataset.as_ref(), &offsets).await?; + row_offsets_to_row_addresses(&self.dataset.get_fragments(), &offsets).await?; addrs.retain(|addr| *addr != RowAddress::TOMBSTONE_ROW); self.u64s_as_take_input(addrs) } diff --git a/rust/lance/src/dataset/take.rs b/rust/lance/src/dataset/take.rs index 73625a171e0..68121410f01 100644 --- a/rust/lance/src/dataset/take.rs +++ b/rust/lance/src/dataset/take.rs @@ -44,11 +44,9 @@ use super::{Dataset, fragment::FileFragment, scanner::DatasetRecordBatchStream}; /// /// If any offsets are beyond the end of the dataset, they will be mapped to a tombstone row address. 
pub(super) async fn row_offsets_to_row_addresses( - dataset: &Dataset, + fragments: &[FileFragment], row_indices: &[u64], ) -> Result<Vec<u64>> { - let fragments = dataset.get_fragments(); - let mut perm = permutation::sort(row_indices); let sorted_offsets = perm.apply_slice(row_indices); @@ -115,7 +113,8 @@ } // First, convert the dataset offsets into row addresses - let addrs = row_offsets_to_row_addresses(dataset, offsets).await?; + let fragments = dataset.get_fragments(); + let addrs = row_offsets_to_row_addresses(&fragments, offsets).await?; let builder = TakeBuilder::try_new_from_addresses( Arc::new(dataset.clone()), diff --git a/rust/lance/src/dataset/tests/dataset_io.rs b/rust/lance/src/dataset/tests/dataset_io.rs index e438e0801ea..2c094e7dc5c 100644 --- a/rust/lance/src/dataset/tests/dataset_io.rs +++ b/rust/lance/src/dataset/tests/dataset_io.rs @@ -1437,6 +1437,118 @@ async fn test_fast_count_rows( ); } +#[rstest] +#[tokio::test] +async fn test_sample_with_fragment_ids( + #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)] + data_storage_version: LanceFileVersion, +) { + let test_uri = TempStrDir::default(); + let data = gen_batch() + .col("i", array::step::<Int32Type>()) + .into_reader_rows(RowCount::from(12), BatchCount::from(1)); + let mut dataset = Dataset::write( + data, + &test_uri, + Some(WriteParams { + max_rows_per_file: 4, + max_rows_per_group: 2, + data_storage_version: Some(data_storage_version), + ..Default::default() + }), + ) + .await + .unwrap(); + + dataset.delete("i IN (1, 9)").await.unwrap(); + + let projection = dataset.schema().project(&["i"]).unwrap(); + let sampled = dataset + .sample(8, &projection, Some(&[0, 0, 2])) + .await + .unwrap(); + let sampled_values = sampled + .column_by_name("i") + .unwrap() + .as_any() + .downcast_ref::<Int32Array>() + .unwrap() + .values() + .to_vec(); + + assert_eq!(sampled_values, vec![0, 2, 3, 8, 10, 11]); +} + +#[rstest] +#[tokio::test] +async fn test_sample_with_empty_fragment_ids_rejected( + 
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)] + data_storage_version: LanceFileVersion, +) { + let test_uri = TempStrDir::default(); + let data = gen_batch() + .col("i", array::step::<Int32Type>()) + .into_reader_rows(RowCount::from(8), BatchCount::from(1)); + let dataset = Dataset::write( + data, + &test_uri, + Some(WriteParams { + max_rows_per_file: 4, + max_rows_per_group: 2, + data_storage_version: Some(data_storage_version), + ..Default::default() + }), + ) + .await + .unwrap(); + + let projection = dataset.schema().project(&["i"]).unwrap(); + let err = dataset.sample(1, &projection, Some(&[])).await.unwrap_err(); + + assert!(matches!(err, Error::InvalidInput { .. })); + assert!( + err.to_string() + .contains("does not accept an empty fragment_ids list") + ); +} + +#[rstest] +#[tokio::test] +async fn test_sample_with_unknown_fragment_ids_rejected( + #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)] + data_storage_version: LanceFileVersion, +) { + let test_uri = TempStrDir::default(); + let data = gen_batch() + .col("i", array::step::<Int32Type>()) + .into_reader_rows(RowCount::from(8), BatchCount::from(1)); + let dataset = Dataset::write( + data, + &test_uri, + Some(WriteParams { + max_rows_per_file: 4, + max_rows_per_group: 2, + data_storage_version: Some(data_storage_version), + ..Default::default() + }), + ) + .await + .unwrap(); + + let projection = dataset.schema().project(&["i"]).unwrap(); + let err = dataset + .sample(1, &projection, Some(&[0, 999])) + .await + .unwrap_err(); + + assert!(matches!(err, Error::InvalidInput { .. 
})); + assert!( + err.to_string() + .contains("not part of the current dataset version") + ); + assert!(err.to_string().contains("999")); +} + #[rstest] #[tokio::test] async fn test_bfloat16_roundtrip( diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index cfc1e8f0dca..bdbdfda3cda 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -3317,7 +3317,7 @@ mod tests { // Sample 2048 random indices and then paste on a column of 9999999's let some_indices = ds - .sample(2048, &(&just_index_col).try_into().unwrap()) + .sample(2048, &(&just_index_col).try_into().unwrap(), None) .await .unwrap(); let some_indices = some_indices.column(0).clone(); diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index 83e010dc1a4..b20d659d6f3 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -96,7 +96,7 @@ async fn estimate_multivector_vectors_per_row( // Try a few random samples first (fast path). let sample_batch_size = std::cmp::min(64, num_rows); for _ in 0..8 { - let batch = dataset.sample(sample_batch_size, &projection).await?; + let batch = dataset.sample(sample_batch_size, &projection, None).await?; let array = get_column_from_batch(&batch, column)?; let list_array = array.as_list::<i32>(); for i in 0..list_array.len() {