From 1e33f46d8071d8658c60d3af7a8ff717c6da91ce Mon Sep 17 00:00:00 2001 From: Vivek Date: Mon, 16 Feb 2026 18:10:29 -0800 Subject: [PATCH 1/2] feat: add progress monitoring via callbacks for inverted indexes --- rust/lance-index/src/scalar/inverted.rs | 39 +++- .../src/scalar/inverted/builder.rs | 214 ++++++++++++++++-- .../lance-index/src/scalar/inverted/merger.rs | 23 +- rust/lance-index/src/scalar/registry.rs | 23 ++ rust/lance/src/index.rs | 1 + rust/lance/src/index/create.rs | 3 + rust/lance/src/index/scalar.rs | 13 +- 7 files changed, 296 insertions(+), 20 deletions(-) diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs index e8644600513..500f8214a0b 100644 --- a/rust/lance-index/src/scalar/inverted.rs +++ b/rust/lance-index/src/scalar/inverted.rs @@ -28,6 +28,7 @@ use lance_core::Error; use snafu::location; use crate::pbold; +use crate::progress::{noop_progress, IndexBuildProgress}; use crate::{ frag_reuse::FragReuseIndex, scalar::{ @@ -48,6 +49,7 @@ impl InvertedIndexPlugin { index_store: &dyn IndexStore, params: InvertedIndexParams, fragment_ids: Option>, + progress: Arc, ) -> Result { let fragment_mask = fragment_ids.as_ref().and_then(|frag_ids| { if !frag_ids.is_empty() { @@ -62,7 +64,8 @@ impl InvertedIndexPlugin { let details = pbold::InvertedIndexDetails::try_from(¶ms)?; let mut inverted_index = - InvertedIndexBuilder::new_with_fragment_mask(params, fragment_mask); + InvertedIndexBuilder::new_with_fragment_mask(params, fragment_mask) + .with_progress(progress); inverted_index.update(data, index_store).await?; Ok(CreatedIndex { index_details: prost_types::Any::from_msg(&details).unwrap(), @@ -180,8 +183,38 @@ impl ScalarIndexPlugin for InvertedIndexPlugin { source: "must provide training request created by new_training_request".into(), location: location!(), })?; - Self::train_inverted_index(data, index_store, request.parameters.clone(), fragment_ids) - .await + Self::train_inverted_index( + data, + index_store, + request.parameters.clone(), + fragment_ids, + noop_progress(), + ) + .await + } + + async fn train_index_with_progress( + &self, + data: SendableRecordBatchStream, + index_store: &dyn IndexStore, + request: Box, + fragment_ids: Option>, + progress: Arc, + ) -> Result { + let request = (request as Box) + .downcast::() + .map_err(|_| Error::InvalidInput { + source: "must provide training request created by new_training_request".into(), + location: location!(), + })?; + Self::train_inverted_index( + data, + index_store, + request.parameters.clone(), + fragment_ids, + progress, + ) + .await } /// Load an index from storage diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs index 023e9c7d252..1e7e627b28f 100644 --- a/rust/lance-index/src/scalar/inverted/builder.rs +++ b/rust/lance-index/src/scalar/inverted/builder.rs @@ -12,6 +12,7 @@ use crate::scalar::inverted::tokenizer::lance_tokenizer::LanceTokenizer; use crate::scalar::lance_format::LanceIndexStore; use crate::scalar::IndexStore; use crate::vector::graph::OrderedFloat; +use crate::{progress::noop_progress, progress::IndexBuildProgress}; use arrow::array::AsArray; use arrow::datatypes; use arrow_array::{Array, RecordBatch, UInt64Array}; @@ -80,6 +81,7 @@ pub struct InvertedIndexBuilder { _tmpdir: TempDir, local_store: Arc, src_store: Arc, + progress: Arc, } impl InvertedIndexBuilder { @@ -126,9 +128,15 @@ impl InvertedIndexBuilder { src_store, token_set_format, fragment_mask, + progress: noop_progress(), } } + pub fn with_progress(mut self, progress: Arc) -> Self { + self.progress = progress; + self + } + pub async fn update( &mut self, new_data: SendableRecordBatchStream, @@ -147,7 +155,11 @@ impl InvertedIndexBuilder { let new_data = document_input(new_data, doc_col)?; + self.progress + .stage_start("tokenize_docs", None, "rows") + .await?; self.update_index(new_data).await?; + self.progress.stage_complete("tokenize_docs").await?; self.write(dest_store).await?; Ok(()) } @@ -209,6 +221,9 @@ impl InvertedIndexBuilder { while let Some(num_rows) = stream.try_next().await? { total_num_rows += num_rows; if total_num_rows >= last_num_rows + 1_000_000 { + self.progress + .stage_progress("tokenize_docs", total_num_rows as u64) + .await?; log::debug!( "indexed {} documents, elapsed: {:?}, speed: {}rows/s", total_num_rows, @@ -218,6 +233,11 @@ impl InvertedIndexBuilder { last_num_rows = total_num_rows; } } + if total_num_rows > last_num_rows { + self.progress + .stage_progress("tokenize_docs", total_num_rows as u64) + .await?; + } // drop the sender to stop receivers drop(stream); debug_assert_eq!(sender.sender_count(), 1); @@ -306,6 +326,36 @@ impl InvertedIndexBuilder { Ok(()) } + async fn write_metadata_with_progress( + &self, + dest_store: &dyn IndexStore, + partitions: &[u64], + ) -> Result<()> { + let total = if self.fragment_mask.is_none() { + Some(1) + } else { + Some(partitions.len() as u64) + }; + self.progress + .stage_start("write_metadata", total, "files") + .await?; + if self.fragment_mask.is_none() { + self.write_metadata(dest_store, partitions).await?; + self.progress.stage_progress("write_metadata", 1).await?; + } else { + let mut completed = 0; + for &partition_id in partitions { + self.write_part_metadata(dest_store, partition_id).await?; + completed += 1; + self.progress + .stage_progress("write_metadata", completed) + .await?; + } + } + self.progress.stage_complete("write_metadata").await?; + Ok(()) + } + async fn write(&self, dest_store: &dyn IndexStore) -> Result<()> { if self.params.skip_merge { let mut partitions = @@ -314,6 +364,14 @@ impl InvertedIndexBuilder { partitions.extend_from_slice(&self.new_partitions); partitions.sort_unstable(); + self.progress + .stage_start( + "copy_partitions", + Some(partitions.len() as u64), + "partitions", + ) + .await?; + let mut copied = 0; for part in self.partitions.iter() { self.src_store .copy_index_file(&token_file_path(*part), dest_store) @@ -324,6 +382,10 @@ impl InvertedIndexBuilder { self.src_store .copy_index_file(&doc_file_path(*part), dest_store) .await?; + copied += 1; + self.progress + .stage_progress("copy_partitions", copied) + .await?; } for part in self.new_partitions.iter() { self.local_store @@ -335,15 +397,15 @@ impl InvertedIndexBuilder { self.local_store .copy_index_file(&doc_file_path(*part), dest_store) .await?; + copied += 1; + self.progress + .stage_progress("copy_partitions", copied) + .await?; } + self.progress.stage_complete("copy_partitions").await?; - if self.fragment_mask.is_none() { - self.write_metadata(dest_store, &partitions).await?; - } else { - for &partition_id in &partitions { - self.write_part_metadata(dest_store, partition_id).await?; - } - } + self.write_metadata_with_progress(dest_store, &partitions) + .await?; return Ok(()); } @@ -357,21 +419,25 @@ impl InvertedIndexBuilder { .map(|part| PartitionSource::new(self.local_store.clone(), *part)), ) .collect::>(); + self.progress + .stage_start( + "merge_partitions", + Some(partitions.len() as u64), + "partitions", + ) + .await?; let mut merger = SizeBasedMerger::new( dest_store, partitions, *LANCE_FTS_TARGET_SIZE << 20, self.token_set_format, + self.progress.clone(), ); let partitions = merger.merge().await?; + self.progress.stage_complete("merge_partitions").await?; - if self.fragment_mask.is_none() { - self.write_metadata(dest_store, &partitions).await?; - } else { - for &partition_id in &partitions { - self.write_part_metadata(dest_store, partition_id).await?; - } - } + self.write_metadata_with_progress(dest_store, &partitions) + .await?; Ok(()) } } @@ -1240,6 +1306,7 @@ pub fn document_input( mod tests { use super::*; use crate::metrics::NoOpMetricsCollector; + use crate::progress::IndexBuildProgress; use crate::scalar::{IndexReader, IndexWriter}; use arrow_array::{RecordBatch, StringArray, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; @@ -1252,6 +1319,7 @@ mod tests { use snafu::location; use std::any::Any; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; + use tokio::sync::Mutex; fn make_doc_batch(doc: &str, row_id: u64) -> RecordBatch { let schema = Arc::new(Schema::new(vec![ @@ -1502,4 +1570,122 @@ mod tests { Ok(()) } + + #[derive(Debug, Default)] + struct RecordingProgress { + events: Mutex>, + } + + #[async_trait] + impl IndexBuildProgress for RecordingProgress { + async fn stage_start(&self, stage: &str, total: Option, _unit: &str) -> Result<()> { + self.events.lock().await.push(( + "start".to_string(), + stage.to_string(), + total.unwrap_or(0), + )); + Ok(()) + } + + async fn stage_progress(&self, stage: &str, completed: u64) -> Result<()> { + self.events + .lock() + .await + .push(("progress".to_string(), stage.to_string(), completed)); + Ok(()) + } + + async fn stage_complete(&self, stage: &str) -> Result<()> { + self.events + .lock() + .await + .push(("complete".to_string(), stage.to_string(), 0)); + Ok(()) + } + } + + #[tokio::test] + async fn test_builder_reports_progress_stages() -> Result<()> { + let index_dir = TempDir::default(); + let store = Arc::new(LanceIndexStore::new( + ObjectStore::local().into(), + index_dir.obj_path(), + Arc::new(LanceCache::no_cache()), + )); + + let schema = Arc::new(Schema::new(vec![ + Field::new("doc", DataType::Utf8, true), + Field::new(ROW_ID, DataType::UInt64, false), + ])); + let docs = Arc::new(StringArray::from(vec![ + Some("hello world"), + Some("goodbye world"), + ])); + let row_ids = Arc::new(UInt64Array::from(vec![0u64, 1u64])); + let batch = RecordBatch::try_new(schema.clone(), vec![docs, row_ids])?; + let stream = RecordBatchStreamAdapter::new(schema, stream::iter(vec![Ok(batch)])); + let stream = Box::pin(stream); + + let progress = Arc::new(RecordingProgress::default()); + let mut builder = + InvertedIndexBuilder::new(InvertedIndexParams::default().skip_merge(true)) + .with_progress(progress.clone()); + builder.update(stream, store.as_ref()).await?; + + let events = progress.events.lock().await.clone(); + let tags = events + .iter() + .map(|(kind, stage, _)| format!("{kind}:{stage}")) + .collect::>(); + + let tokenize_start = tags + .iter() + .position(|e| e == "start:tokenize_docs") + .expect("missing tokenize_docs start"); + let tokenize_complete = tags + .iter() + .position(|e| e == "complete:tokenize_docs") + .expect("missing tokenize_docs complete"); + let copy_start = tags + .iter() + .position(|e| e == "start:copy_partitions") + .expect("missing copy_partitions start"); + let copy_complete = tags + .iter() + .position(|e| e == "complete:copy_partitions") + .expect("missing copy_partitions complete"); + let metadata_start = tags + .iter() + .position(|e| e == "start:write_metadata") + .expect("missing write_metadata start"); + let metadata_complete = tags + .iter() + .position(|e| e == "complete:write_metadata") + .expect("missing write_metadata complete"); + + assert!(tokenize_start < tokenize_complete); + assert!(tokenize_complete < copy_start); + assert!(copy_start < copy_complete); + assert!(copy_complete < metadata_start); + assert!(metadata_start < metadata_complete); + + assert!( + tags.iter().any(|e| e == "progress:tokenize_docs"), + "expected progress callback for tokenize_docs" + ); + assert!( + tags.iter().any(|e| e == "progress:copy_partitions"), + "expected progress callback for copy_partitions" + ); + assert!( + tags.iter().any(|e| e == "progress:write_metadata"), + "expected progress callback for write_metadata" + ); + assert!( + !tags.iter().any(|e| e == "start:merge_partitions"), + "merge_partitions should not run in skip_merge mode" + ); + + Ok(()) + } } diff --git a/rust/lance-index/src/scalar/inverted/merger.rs b/rust/lance-index/src/scalar/inverted/merger.rs index d1e98ac84ae..21b46d0d453 100644 --- a/rust/lance-index/src/scalar/inverted/merger.rs +++ b/rust/lance-index/src/scalar/inverted/merger.rs @@ -5,7 +5,9 @@ use fst::Streamer; use futures::{stream, StreamExt, TryStreamExt}; use lance_core::{cache::LanceCache, utils::tokio::get_num_compute_intensive_cpus, Error, Result}; use snafu::location; +use std::sync::Arc; +use crate::progress::IndexBuildProgress; use crate::scalar::IndexStore; use super::{ @@ -51,6 +53,7 @@ pub struct SizeBasedMerger<'a> { with_position: Option, target_size: u64, token_set_format: TokenSetFormat, + progress: Arc, builder: Option, next_id: u64, partitions: Vec, @@ -66,6 +69,7 @@ impl<'a> SizeBasedMerger<'a> { input: Vec, target_size: u64, token_set_format: TokenSetFormat, + progress: Arc, ) -> Self { let max_id = input.iter().map(|p| p.id).max().unwrap_or(0); @@ -75,6 +79,7 @@ impl<'a> SizeBasedMerger<'a> { with_position: None, target_size, token_set_format, + progress, builder: None, next_id: max_id + 1, partitions: Vec::new(), @@ -215,6 +220,7 @@ impl<'a> SizeBasedMerger<'a> { impl Merger for SizeBasedMerger<'_> { async fn merge(&mut self) -> Result> { if self.input.len() <= 1 { + let mut completed = 0; for part in self.input.iter() { part.store .copy_index_file(&token_file_path(part.id), self.dest_store) @@ -225,6 +231,10 @@ impl Merger for SizeBasedMerger<'_> { part.store .copy_index_file(&doc_file_path(part.id), self.dest_store) .await?; + completed += 1; + self.progress + .stage_progress("merge_partitions", completed) + .await?; } return Ok(self.input.iter().map(|p| p.id).collect()); @@ -258,6 +268,9 @@ impl Merger for SizeBasedMerger<'_> { while let Some(part) = stream.try_next().await? { idx += 1; self.merge_partition(part, &mut estimated_size).await?; + self.progress + .stage_progress("merge_partitions", idx as u64) + .await?; log::info!( "merged {}/{} partitions in {:?}", idx, @@ -328,6 +341,7 @@ mod tests { ], u64::MAX, token_set_format, + crate::progress::noop_progress(), ); let merged_partitions = merger.merge().await?; assert_eq!(merged_partitions, vec![2]); @@ -387,8 +401,13 @@ mod tests { sources.push(PartitionSource::new(src_store.clone(), id)); } - let mut merger = - SizeBasedMerger::new(dest_store.as_ref(), sources, u64::MAX, token_set_format); + let mut merger = SizeBasedMerger::new( + dest_store.as_ref(), + sources, + u64::MAX, + token_set_format, + crate::progress::noop_progress(), + ); let merged_partitions = merger.merge().await?; assert_eq!(merged_partitions, vec![num_parts as u64]); diff --git a/rust/lance-index/src/scalar/registry.rs b/rust/lance-index/src/scalar/registry.rs index 76b088518e3..452569f92b3 100644 --- a/rust/lance-index/src/scalar/registry.rs +++ b/rust/lance-index/src/scalar/registry.rs @@ -8,6 +8,7 @@ use async_trait::async_trait; use datafusion::execution::SendableRecordBatchStream; use lance_core::{cache::LanceCache, Result}; +use crate::progress::IndexBuildProgress; use crate::registry::IndexPluginRegistry; use crate::{ frag_reuse::FragReuseIndex, @@ -116,6 +117,28 @@ pub trait ScalarIndexPlugin: Send + Sync + std::fmt::Debug { fragment_ids: Option>, ) -> Result; + /// Train a new index with progress callbacks. + /// + /// Plugins can override this for index-specific stages. The default implementation + /// wraps [`Self::train_index`] with a generic scalar training stage. + async fn train_index_with_progress( + &self, + data: SendableRecordBatchStream, + index_store: &dyn IndexStore, + request: Box, + fragment_ids: Option>, + progress: Arc, + ) -> Result { + progress.stage_start("train_scalar", None, "").await?; + let result = self + .train_index(data, index_store, request, fragment_ids) + .await; + if result.is_ok() { + progress.stage_complete("train_scalar").await?; + } + result + } + /// A short name for the index /// /// This is a friendly name for display purposes and also can be used as an alias for diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index 0f308e9c1d9..6d846cff4a5 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -336,6 +336,7 @@ pub(crate) async fn remap_index( &new_store, inverted_index.params().clone(), None, + Arc::new(NoopIndexBuildProgress), ) .await? } else { diff --git a/rust/lance/src/index/create.rs b/rust/lance/src/index/create.rs index 095f9da0980..11c2dbcc354 100644 --- a/rust/lance/src/index/create.rs +++ b/rust/lance/src/index/create.rs @@ -251,6 +251,7 @@ impl<'a> CreateIndexBuilder<'a> { train, self.fragments.clone(), preprocesssed_data, + self.progress.clone(), ) .await? } @@ -272,6 +273,7 @@ impl<'a> CreateIndexBuilder<'a> { train, self.fragments.clone(), None, + self.progress.clone(), ) .await? } @@ -296,6 +298,7 @@ impl<'a> CreateIndexBuilder<'a> { train, self.fragments.clone(), None, + self.progress.clone(), ) .await? } diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index 11a2c22b67a..1c04bfdd96f 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -24,6 +24,7 @@ use lance_index::metrics::{MetricsCollector, NoOpMetricsCollector}; use lance_index::pbold::{ BTreeIndexDetails, BitmapIndexDetails, InvertedIndexDetails, LabelListIndexDetails, }; +use lance_index::progress::IndexBuildProgress; use lance_index::registry::IndexPluginRegistry; use lance_index::scalar::inverted::METADATA_FILE; use lance_index::scalar::registry::{ @@ -250,6 +251,7 @@ impl IndexDetails { } /// Build a Scalar Index (returns details to store in the manifest) +#[allow(clippy::too_many_arguments)] #[instrument(level = "debug", skip_all)] pub(super) async fn build_scalar_index( dataset: &Dataset, @@ -259,6 +261,7 @@ pub(super) async fn build_scalar_index( train: bool, fragment_ids: Option>, preprocessed_data: Option, + progress: Arc, ) -> Result { let field = dataset.schema().field(column).ok_or(Error::InvalidInput { source: format!("No column with name {}", column).into(), @@ -272,6 +275,7 @@ pub(super) async fn build_scalar_index( let training_request = plugin.new_training_request(params.params.as_deref().unwrap_or("{}"), &field)?; + progress.stage_start("load_data", None, "rows").await?; let training_data = match preprocessed_data { Some(preprocessed_data) => preprocessed_data, None => { @@ -286,9 +290,16 @@ pub(super) async fn build_scalar_index( .await? } }; + progress.stage_complete("load_data").await?; plugin - .train_index(training_data, &index_store, training_request, fragment_ids) + .train_index_with_progress( + training_data, + &index_store, + training_request, + fragment_ids, + progress, + ) .await } From 124fb202bbc91be1eccdd33d75e505d777a96707 Mon Sep 17 00:00:00 2001 From: Vivek Date: Tue, 17 Feb 2026 17:37:20 -0800 Subject: [PATCH 2/2] review feedback --- rust/lance-index/benches/geo.rs | 1 + rust/lance-index/src/scalar/bitmap.rs | 1 + rust/lance-index/src/scalar/bloomfilter.rs | 1 + rust/lance-index/src/scalar/btree.rs | 1 + rust/lance-index/src/scalar/inverted.rs | 25 +------- .../src/scalar/inverted/builder.rs | 57 ++++++++++++------- rust/lance-index/src/scalar/json.rs | 9 ++- rust/lance-index/src/scalar/label_list.rs | 3 +- rust/lance-index/src/scalar/lance_format.rs | 24 +++++++- rust/lance-index/src/scalar/ngram.rs | 1 + rust/lance-index/src/scalar/registry.rs | 23 +------- rust/lance-index/src/scalar/rtree.rs | 2 + rust/lance-index/src/scalar/zonemap.rs | 1 + rust/lance/src/index/scalar.rs | 2 +- 14 files changed, 79 insertions(+), 72 deletions(-) diff --git a/rust/lance-index/benches/geo.rs b/rust/lance-index/benches/geo.rs index dea4c4d811a..a3f896ffdcc 100644 --- a/rust/lance-index/benches/geo.rs +++ b/rust/lance-index/benches/geo.rs @@ -78,6 +78,7 @@ async fn build_rtree( store.as_ref(), Box::new(RTreeTrainingRequest::default()), None, + lance_index::progress::noop_progress(), ) .await?; diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs index 4fb9fc3334c..647582cda44 100644 --- a/rust/lance-index/src/scalar/bitmap.rs +++ b/rust/lance-index/src/scalar/bitmap.rs @@ -803,6 +803,7 @@ impl ScalarIndexPlugin for BitmapIndexPlugin { index_store: &dyn IndexStore, _request: Box, fragment_ids: Option>, + _progress: Arc, ) -> Result { if fragment_ids.is_some() { return Err(Error::InvalidInput { diff --git a/rust/lance-index/src/scalar/bloomfilter.rs b/rust/lance-index/src/scalar/bloomfilter.rs index 3057323b5da..d25e1f18c78 100644 --- a/rust/lance-index/src/scalar/bloomfilter.rs +++ b/rust/lance-index/src/scalar/bloomfilter.rs @@ -1110,6 +1110,7 @@ impl ScalarIndexPlugin for BloomFilterIndexPlugin { index_store: &dyn IndexStore, request: Box, fragment_ids: Option>, + _progress: Arc, ) -> Result { if fragment_ids.is_some() { return Err(Error::InvalidInput { diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index bf44a9d5b97..f939ecbffed 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -2543,6 +2543,7 @@ impl ScalarIndexPlugin for BTreeIndexPlugin { index_store: &dyn IndexStore, request: Box, fragment_ids: Option>, + _progress: Arc, ) -> Result { let request = request .as_any() diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs index 500f8214a0b..e7d1913e6f7 100644 --- a/rust/lance-index/src/scalar/inverted.rs +++ b/rust/lance-index/src/scalar/inverted.rs @@ -28,7 +28,7 @@ use lance_core::Error; use snafu::location; use crate::pbold; -use crate::progress::{noop_progress, IndexBuildProgress}; +use crate::progress::IndexBuildProgress; use crate::{ frag_reuse::FragReuseIndex, scalar::{ @@ -176,29 +176,6 @@ impl ScalarIndexPlugin for InvertedIndexPlugin { index_store: &dyn IndexStore, request: Box, fragment_ids: Option>, - ) -> Result { - let request = (request as Box) - .downcast::() - .map_err(|_| Error::InvalidInput { - source: "must provide training request created by new_training_request".into(), - location: location!(), - })?; - Self::train_inverted_index( - data, - index_store, - request.parameters.clone(), - fragment_ids, - noop_progress(), - ) - .await - } - - async fn train_index_with_progress( - &self, - data: SendableRecordBatchStream, - index_store: &dyn IndexStore, - request: Box, - fragment_ids: Option>, progress: Arc, ) -> Result { let request = (request as Box) diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs index 1e7e627b28f..17f9645535e 100644 --- a/rust/lance-index/src/scalar/inverted/builder.rs +++ b/rust/lance-index/src/scalar/inverted/builder.rs @@ -171,15 +171,18 @@ impl InvertedIndexBuilder { let with_position = self.params.with_position; let next_id = self.partitions.iter().map(|id| id + 1).max().unwrap_or(0); let id_alloc = Arc::new(AtomicU64::new(next_id)); + let tokenized_count = Arc::new(AtomicU64::new(0)); let (sender, receiver) = async_channel::bounded(num_workers); let mut index_tasks = Vec::with_capacity(num_workers); for _ in 0..num_workers { let store = self.local_store.clone(); let tokenizer = tokenizer.clone(); - let receiver = receiver.clone(); + let receiver: async_channel::Receiver = receiver.clone(); let id_alloc = id_alloc.clone(); + let progress = self.progress.clone(); let fragment_mask = self.fragment_mask; let token_set_format = self.token_set_format; + let tokenized_count = tokenized_count.clone(); let task = tokio::task::spawn(async move { let mut worker = IndexWorker::new( store, @@ -191,7 +194,14 @@ impl InvertedIndexBuilder { ) .await?; while let Ok(batch) = receiver.recv().await { + let num_rows = batch.num_rows(); worker.process_batch(batch).await?; + let tokenized_count = tokenized_count + .fetch_add(num_rows as u64, std::sync::atomic::Ordering::Relaxed) + + num_rows as u64; + progress + .stage_progress("tokenize_docs", tokenized_count) + .await?; } let partitions = worker.finish().await?; Result::Ok(partitions) @@ -221,9 +231,6 @@ impl InvertedIndexBuilder { while let Some(num_rows) = stream.try_next().await? { total_num_rows += num_rows; if total_num_rows >= last_num_rows + 1_000_000 { - self.progress - .stage_progress("tokenize_docs", total_num_rows as u64) - .await?; log::debug!( "indexed {} documents, elapsed: {:?}, speed: {}rows/s", total_num_rows, @@ -233,11 +240,6 @@ impl InvertedIndexBuilder { last_num_rows = total_num_rows; } } - if total_num_rows > last_num_rows { - self.progress - .stage_progress("tokenize_docs", total_num_rows as u64) - .await?; - } // drop the sender to stop receivers drop(stream); debug_assert_eq!(sender.sender_count(), 1); @@ -1613,17 +1615,13 @@ mod tests { Arc::new(LanceCache::no_cache()), )); - let schema = Arc::new(Schema::new(vec![ - Field::new("doc", DataType::Utf8, true), - Field::new(ROW_ID, DataType::UInt64, false), - ])); - let docs = Arc::new(StringArray::from(vec![ - Some("hello world"), - Some("goodbye world"), - ])); - let row_ids = Arc::new(UInt64Array::from(vec![0u64, 1u64])); - let batch = RecordBatch::try_new(schema.clone(), vec![docs, row_ids])?; - let stream = RecordBatchStreamAdapter::new(schema, stream::iter(vec![Ok(batch)])); + let batch1 = make_doc_batch("hello world", 0); + let batch2 = make_doc_batch("goodbye world", 1); + let total_rows = 2u64; + let stream = RecordBatchStreamAdapter::new( + batch1.schema(), + stream::iter(vec![Ok(batch1), Ok(batch2)]), + ); let stream = Box::pin(stream); let progress = Arc::new(RecordingProgress::default()); @@ -1637,6 +1635,16 @@ mod tests { .iter() .map(|(kind, stage, _)| format!("{kind}:{stage}")) .collect::>(); + let tokenize_progress = events + .iter() + .filter_map(|(kind, stage, completed)| { + if kind == "progress" && stage == "tokenize_docs" { + Some(*completed) + } else { + None + } + }) + .collect::>(); let tokenize_start = tags .iter() @@ -1673,6 +1681,15 @@ mod tests { tags.iter().any(|e| e == "progress:tokenize_docs"), "expected progress callback for tokenize_docs" ); + assert!( + tokenize_progress.len() >= 2, + "expected at least two progress callbacks for tokenize_docs, got {tokenize_progress:?}" + ); + assert_eq!( + tokenize_progress.iter().copied().max().unwrap_or_default(), + total_rows, + "expected tokenize_docs progress to reach all rows" + ); assert!( tags.iter().any(|e| e == "progress:copy_partitions"), "expected progress callback for copy_partitions" diff --git a/rust/lance-index/src/scalar/json.rs b/rust/lance-index/src/scalar/json.rs index 82501444291..b985d903b11 100644 --- a/rust/lance-index/src/scalar/json.rs +++ b/rust/lance-index/src/scalar/json.rs @@ -776,6 +776,7 @@ impl ScalarIndexPlugin for JsonIndexPlugin { index_store: &dyn IndexStore, request: Box, fragment_ids: Option>, + progress: Arc, ) -> Result { let request = (request as Box) .downcast::() @@ -805,7 +806,13 @@ impl ScalarIndexPlugin for JsonIndexPlugin { )?; let target_index = target_plugin - .train_index(converted_stream, index_store, target_request, fragment_ids) + .train_index( + converted_stream, + index_store, + target_request, + fragment_ids, + progress, + ) .await?; let index_details = crate::pb::JsonIndexDetails { diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs index e971c45fa97..d4db6f403e2 100644 --- a/rust/lance-index/src/scalar/label_list.rs +++ b/rust/lance-index/src/scalar/label_list.rs @@ -414,6 +414,7 @@ impl ScalarIndexPlugin for LabelListIndexPlugin { index_store: &dyn IndexStore, request: Box, fragment_ids: Option>, + progress: Arc, ) -> Result { if fragment_ids.is_some() { return Err(Error::InvalidInput { @@ -450,7 +451,7 @@ impl ScalarIndexPlugin for LabelListIndexPlugin { let data = unnest_chunks(data)?; let bitmap_plugin = BitmapIndexPlugin; bitmap_plugin - .train_index(data, index_store, request, fragment_ids) + .train_index(data, index_store, request, fragment_ids, progress) .await?; Ok(CreatedIndex { index_details: prost_types::Any::from_msg(&pbold::LabelListIndexDetails::default()) diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs index cdb3f73db84..280044e9a69 100644 --- a/rust/lance-index/src/scalar/lance_format.rs +++ b/rust/lance-index/src/scalar/lance_format.rs @@ -364,7 +364,13 @@ pub mod tests { ) .unwrap(); btree_plugin - .train_index(data, index_store.as_ref(), request, None) + .train_index( + data, + index_store.as_ref(), + request, + None, + crate::progress::noop_progress(), + ) .await .unwrap(); } @@ -907,7 +913,13 @@ pub mod tests { .new_training_request("{}", &Field::new(VALUE_COLUMN_NAME, DataType::Int32, false)) .unwrap(); BitmapIndexPlugin - .train_index(data, index_store.as_ref(), request, None) + .train_index( + data, + index_store.as_ref(), + request, + None, + crate::progress::noop_progress(), + ) .await .unwrap(); } @@ -1395,7 +1407,13 @@ pub mod tests { ) .unwrap(); LabelListIndexPlugin - .train_index(data, index_store.as_ref(), request, None) + .train_index( + data, + index_store.as_ref(), + request, + None, + crate::progress::noop_progress(), + ) .await .unwrap(); } diff --git a/rust/lance-index/src/scalar/ngram.rs b/rust/lance-index/src/scalar/ngram.rs index 4d4c0bfeef2..f61b79a5ab8 100644 --- a/rust/lance-index/src/scalar/ngram.rs +++ b/rust/lance-index/src/scalar/ngram.rs @@ -1296,6 +1296,7 @@ impl ScalarIndexPlugin for NGramIndexPlugin { index_store: &dyn IndexStore, _request: Box, fragment_ids: Option>, + _progress: Arc, ) -> Result { if fragment_ids.is_some() { return Err(Error::InvalidInput { diff --git a/rust/lance-index/src/scalar/registry.rs b/rust/lance-index/src/scalar/registry.rs index 452569f92b3..4f657e201a9 100644 --- a/rust/lance-index/src/scalar/registry.rs +++ b/rust/lance-index/src/scalar/registry.rs @@ -115,29 +115,8 @@ pub trait ScalarIndexPlugin: Send + Sync + std::fmt::Debug { index_store: &dyn IndexStore, request: Box, fragment_ids: Option>, - ) -> Result; - - /// Train a new index with progress callbacks. - /// - /// Plugins can override this for index-specific stages. The default implementation - /// wraps [`Self::train_index`] with a generic scalar training stage. - async fn train_index_with_progress( - &self, - data: SendableRecordBatchStream, - index_store: &dyn IndexStore, - request: Box, - fragment_ids: Option>, progress: Arc, - ) -> Result { - progress.stage_start("train_scalar", None, "").await?; - let result = self - .train_index(data, index_store, request, fragment_ids) - .await; - if result.is_ok() { - progress.stage_complete("train_scalar").await?; - } - result - } + ) -> Result; /// A short name for the index /// diff --git a/rust/lance-index/src/scalar/rtree.rs b/rust/lance-index/src/scalar/rtree.rs index 3f36ee399ab..4b54afc2f3b 100644 --- a/rust/lance-index/src/scalar/rtree.rs +++ b/rust/lance-index/src/scalar/rtree.rs @@ -882,6 +882,7 @@ impl ScalarIndexPlugin for RTreeIndexPlugin { index_store: &dyn IndexStore, request: Box, fragment_ids: Option>, + _progress: Arc, ) -> Result { if fragment_ids.is_some() { return Err(Error::InvalidInput { @@ -1021,6 +1022,7 @@ mod tests { page_size: Some(page_size), })), None, + crate::progress::noop_progress(), ) .await .unwrap(); diff --git a/rust/lance-index/src/scalar/zonemap.rs b/rust/lance-index/src/scalar/zonemap.rs index b631ba89d48..6f95cd05ab6 100644 --- a/rust/lance-index/src/scalar/zonemap.rs +++ b/rust/lance-index/src/scalar/zonemap.rs @@ -899,6 +899,7 @@ impl ScalarIndexPlugin for ZoneMapIndexPlugin { index_store: &dyn IndexStore, request: Box, fragment_ids: Option>, + _progress: Arc, ) -> Result { if fragment_ids.is_some() { return Err(Error::InvalidInput { diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index 1c04bfdd96f..76908cc5f88 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -293,7 +293,7 @@ pub(super) async fn build_scalar_index( progress.stage_complete("load_data").await?; plugin - .train_index_with_progress( + .train_index( training_data, &index_store, training_request,