Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/lance/benches/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ fn bench_sample(c: &mut Criterion) {
let schema = schema.clone();
let dataset = dataset.clone();
async move {
dataset.sample(sample_size, &schema).await.unwrap();
dataset.sample(sample_size, &schema, None).await.unwrap();
}
})
},
Expand Down
75 changes: 68 additions & 7 deletions rust/lance/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ use roaring::RoaringBitmap;
use rowids::get_row_id_index;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::fmt::Debug;
use std::ops::Range;
use std::pin::Pin;
Expand Down Expand Up @@ -1466,7 +1466,8 @@ impl Dataset {
row_indices: &[u64],
column: impl AsRef<str>,
) -> Result<Vec<BlobFile>> {
let row_addrs = row_offsets_to_row_addresses(self, row_indices).await?;
let fragments = self.get_fragments();
let row_addrs = row_offsets_to_row_addresses(&fragments, row_indices).await?;
blob::take_blobs_by_addresses(self, &row_addrs, column.as_ref()).await
}

Expand All @@ -1484,14 +1485,74 @@ impl Dataset {

/// Randomly sample `n` rows from the dataset.
///
/// If `fragment_ids` is provided, sampling is limited to rows from those
/// fragments in the current dataset version.
///
/// The returned rows are in row-id order (not random order), which allows
/// the underlying take operation to use an efficient sorted code path.
pub async fn sample(&self, n: usize, projection: &Schema) -> Result<RecordBatch> {
pub async fn sample(
&self,
n: usize,
projection: &Schema,
fragment_ids: Option<&[u32]>,
) -> Result<RecordBatch> {
use rand::seq::IteratorRandom;
let num_rows = self.count_rows(None).await?;
let mut ids = (0..num_rows as u64).choose_multiple(&mut rand::rng(), n);
ids.sort_unstable();
self.take(&ids, projection.clone()).await

match fragment_ids {
None => {
let num_rows = self.count_rows(None).await?;
let mut ids = (0..num_rows as u64).choose_multiple(&mut rand::rng(), n);
ids.sort_unstable();
self.take(&ids, projection.clone()).await
}
Some(fragment_ids) => {
if fragment_ids.is_empty() {
return Err(Error::invalid_input(
"Dataset::sample does not accept an empty fragment_ids list".to_string(),
));
}

let selected_fragment_ids = fragment_ids.iter().copied().collect::<BTreeSet<_>>();
let selected_fragments = self
.get_fragments()
.into_iter()
.filter(|fragment| selected_fragment_ids.contains(&(fragment.id() as u32)))
.collect::<Vec<_>>();

if selected_fragments.len() != selected_fragment_ids.len() {
let present_fragment_ids = selected_fragments
.iter()
.map(|fragment| fragment.id() as u32)
.collect::<HashSet<_>>();
let missing_fragment_ids = selected_fragment_ids
.into_iter()
.filter(|fragment_id| !present_fragment_ids.contains(fragment_id))
.collect::<Vec<_>>();
return Err(Error::invalid_input(format!(
"Dataset::sample received fragment ids that are not part of the current dataset version: {missing_fragment_ids:?}",
)));
}

let num_rows = stream::iter(selected_fragments.iter().cloned())
.map(|fragment| async move { fragment.count_rows(None).await })
.buffer_unordered(16)
.try_fold(0_u64, |acc, rows| async move { Ok(acc + rows as u64) })
.await?;

let mut offsets = (0..num_rows).choose_multiple(&mut rand::rng(), n);
offsets.sort_unstable();

let row_addrs = row_offsets_to_row_addresses(&selected_fragments, &offsets).await?;
let dataset = Arc::new(self.clone());
let projection = Arc::new(
ProjectionRequest::from(projection.clone())
.into_projection_plan(dataset.clone())?,
);
TakeBuilder::try_new_from_addresses(dataset, row_addrs, projection)?
.execute()
.await
}
}
}

/// Delete rows based on a predicate.
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2718,7 +2718,7 @@ impl Scanner {
TakeOperation::RowAddrs(addrs) => self.u64s_as_take_input(addrs),
TakeOperation::RowOffsets(offsets) => {
let mut addrs =
row_offsets_to_row_addresses(self.dataset.as_ref(), &offsets).await?;
row_offsets_to_row_addresses(&self.dataset.get_fragments(), &offsets).await?;
addrs.retain(|addr| *addr != RowAddress::TOMBSTONE_ROW);
self.u64s_as_take_input(addrs)
}
Expand Down
7 changes: 3 additions & 4 deletions rust/lance/src/dataset/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@ use super::{Dataset, fragment::FileFragment, scanner::DatasetRecordBatchStream};
///
/// If any offsets are beyond the end of the dataset, they will be mapped to a tombstone row address.
pub(super) async fn row_offsets_to_row_addresses(
dataset: &Dataset,
fragments: &[FileFragment],
row_indices: &[u64],
) -> Result<Vec<u64>> {
let fragments = dataset.get_fragments();

let mut perm = permutation::sort(row_indices);
let sorted_offsets = perm.apply_slice(row_indices);

Expand Down Expand Up @@ -115,7 +113,8 @@ pub async fn take(
}

// First, convert the dataset offsets into row addresses
let addrs = row_offsets_to_row_addresses(dataset, offsets).await?;
let fragments = dataset.get_fragments();
let addrs = row_offsets_to_row_addresses(&fragments, offsets).await?;

let builder = TakeBuilder::try_new_from_addresses(
Arc::new(dataset.clone()),
Expand Down
112 changes: 112 additions & 0 deletions rust/lance/src/dataset/tests/dataset_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,118 @@ async fn test_fast_count_rows(
);
}

#[rstest]
#[tokio::test]
async fn test_sample_with_fragment_ids(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
let test_uri = TempStrDir::default();
let data = gen_batch()
.col("i", array::step::<Int32Type>())
.into_reader_rows(RowCount::from(12), BatchCount::from(1));
let mut dataset = Dataset::write(
data,
&test_uri,
Some(WriteParams {
max_rows_per_file: 4,
max_rows_per_group: 2,
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();

dataset.delete("i IN (1, 9)").await.unwrap();

let projection = dataset.schema().project(&["i"]).unwrap();
let sampled = dataset
.sample(8, &projection, Some(&[0, 0, 2]))
.await
.unwrap();
let sampled_values = sampled
.column_by_name("i")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.values()
.to_vec();

assert_eq!(sampled_values, vec![0, 2, 3, 8, 10, 11]);
}

#[rstest]
#[tokio::test]
async fn test_sample_with_empty_fragment_ids_rejected(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
let test_uri = TempStrDir::default();
let data = gen_batch()
.col("i", array::step::<Int32Type>())
.into_reader_rows(RowCount::from(8), BatchCount::from(1));
let dataset = Dataset::write(
data,
&test_uri,
Some(WriteParams {
max_rows_per_file: 4,
max_rows_per_group: 2,
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();
Comment thread
Xuanwo marked this conversation as resolved.

let projection = dataset.schema().project(&["i"]).unwrap();
let err = dataset.sample(1, &projection, Some(&[])).await.unwrap_err();

assert!(matches!(err, Error::InvalidInput { .. }));
assert!(
err.to_string()
.contains("does not accept an empty fragment_ids list")
);
}

#[rstest]
#[tokio::test]
async fn test_sample_with_unknown_fragment_ids_rejected(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
let test_uri = TempStrDir::default();
let data = gen_batch()
.col("i", array::step::<Int32Type>())
.into_reader_rows(RowCount::from(8), BatchCount::from(1));
let dataset = Dataset::write(
data,
&test_uri,
Some(WriteParams {
max_rows_per_file: 4,
max_rows_per_group: 2,
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();
Comment thread
Xuanwo marked this conversation as resolved.

let projection = dataset.schema().project(&["i"]).unwrap();
let err = dataset
.sample(1, &projection, Some(&[0, 999]))
.await
.unwrap_err();

assert!(matches!(err, Error::InvalidInput { .. }));
assert!(
err.to_string()
.contains("not part of the current dataset version")
);
assert!(err.to_string().contains("999"));
}

#[rstest]
#[tokio::test]
async fn test_bfloat16_roundtrip(
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/src/dataset/write/merge_insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3317,7 +3317,7 @@ mod tests {

// Sample 2048 random indices and then paste on a column of 9999999's
let some_indices = ds
.sample(2048, &(&just_index_col).try_into().unwrap())
.sample(2048, &(&just_index_col).try_into().unwrap(), None)
.await
.unwrap();
let some_indices = some_indices.column(0).clone();
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/src/index/vector/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async fn estimate_multivector_vectors_per_row(
// Try a few random samples first (fast path).
let sample_batch_size = std::cmp::min(64, num_rows);
for _ in 0..8 {
let batch = dataset.sample(sample_batch_size, &projection).await?;
let batch = dataset.sample(sample_batch_size, &projection, None).await?;
let array = get_column_from_batch(&batch, column)?;
let list_array = array.as_list::<i32>();
for i in 0..list_array.len() {
Expand Down
Loading