Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 2 additions & 37 deletions rust/lance-index/src/vector/v3/shuffler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use std::sync::Arc;
use arrow::{array::AsArray, compute::sort_to_indices};
use arrow_array::{RecordBatch, UInt32Array};
use arrow_schema::Schema;
use future::try_join_all;
use futures::prelude::*;
use lance_arrow::{RecordBatchExt, SchemaExt};
use lance_core::{
Expand Down Expand Up @@ -69,7 +68,6 @@ pub struct IvfShuffler {
num_partitions: usize,

// options
buffer_size: usize,
precomputed_shuffle_buffers: Option<Vec<String>>,
}

Expand All @@ -79,16 +77,10 @@ impl IvfShuffler {
object_store: Arc::new(ObjectStore::local()),
output_dir,
num_partitions,
buffer_size: 4096,
precomputed_shuffle_buffers: None,
}
}

/// Builder-style setter: consumes the shuffler, overrides its
/// `buffer_size` option with the given value, and returns it.
pub fn with_buffer_size(self, buffer_size: usize) -> Self {
    let mut shuffler = self;
    shuffler.buffer_size = buffer_size;
    shuffler
}

pub fn with_precomputed_shuffle_buffers(
mut self,
precomputed_shuffle_buffers: Option<Vec<String>>,
Expand Down Expand Up @@ -163,43 +155,16 @@ impl Shuffler for IvfShuffler {
})
.buffered(get_num_compute_intensive_cpus());

// part_id: | 0 | 1 | 2 |
// partition_buffers: |[batch,batch,..]|[batch,batch,..]|[batch,batch,..]|
let mut partition_buffers = vec![Vec::new(); num_partitions];

let mut counter = 0;
let mut total_loss = 0.0;
while let Some(shuffled) = parallel_sort_stream.next().await {
let (shuffled, loss) = shuffled?;
total_loss += loss;

for (part_id, batches) in shuffled.into_iter().enumerate() {
let part_batches = &mut partition_buffers[part_id];
part_batches.extend(batches);
}

counter += 1;

// do flush
if counter % self.buffer_size == 0 {
let mut futs = vec![];
for (part_id, writer) in writers.iter_mut().enumerate() {
let batches = &partition_buffers[part_id];
if !batches.is_empty() {
partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::<usize>();
futs.push(writer.write_batches(batches.iter()));
writers[part_id].write_batches(batches.iter()).await?;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do this in a follow-up, but it might be nice to still do all the writes in parallel, e.g. keep the `futs` Vec. `shuffled` is a `Vec`, not any kind of stream/iterator, so the data is all in memory already. (I think the important point is getting rid of the `if counter % self.buffer_size == 0` check.)

let mut futs = vec![];
if !batches.is_empty() {
    futs.push(writers[part_id].write_batches(batches.iter()));
}
try_join_all(futs).await?;

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, updated

}
try_join_all(futs).await?;

partition_buffers.iter_mut().for_each(|b| b.clear());
}
}

// final flush
for (part_id, batches) in partition_buffers.into_iter().enumerate() {
let writer = &mut writers[part_id];
partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::<usize>();
for batch in batches.iter() {
writer.write_batch(batch).await?;
}
}

Expand Down