From 88a7a37eca6d8fd00fc3d689f65603f0b846aec2 Mon Sep 17 00:00:00 2001
From: xloya
Date: Fri, 5 Sep 2025 12:48:21 +0800
Subject: [PATCH 01/13] support distributed btree index build

---
 java/.gitignore                             |    3 +
 python/python/lance/dataset.py              |   36 +-
 python/python/lance/lance/__init__.pyi      |    4 +-
 python/python/tests/test_scalar_index.py    |  416 +++-
 python/src/dataset.rs                       |  126 +-
 rust/lance-index/src/scalar/bitmap.rs       |    1 +
 rust/lance-index/src/scalar/btree.rs        | 2119 ++++++++++++++++---
 rust/lance-index/src/scalar/inverted.rs     |    4 +-
 rust/lance-index/src/scalar/json.rs         |    3 +-
 rust/lance-index/src/scalar/label_list.rs   |    3 +-
 rust/lance-index/src/scalar/lance_format.rs |    7 +-
 rust/lance-index/src/scalar/ngram.rs        |    1 +
 rust/lance-index/src/scalar/registry.rs     |    1 +
 rust/lance-index/src/scalar/zonemap.rs      |    1 +
 rust/lance/benches/scalar_index.rs          |    1 +
 rust/lance/src/index/scalar.rs              |    5 +-
 16 files changed, 2429 insertions(+), 302 deletions(-)

diff --git a/java/.gitignore b/java/.gitignore
index d9074bd2835..f134c3c1a74 100644
--- a/java/.gitignore
+++ b/java/.gitignore
@@ -1,2 +1,5 @@
 *.iml
 .java-version
+.project
+.settings
+.classpath
\ No newline at end of file
diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index 376476e43af..54bd432201a 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -2731,8 +2731,40 @@
         """
         return self._ds.prewarm_index(name)
 
-    def merge_index_metadata(self, index_uuid: str):
-        return self._ds.merge_index_metadata(index_uuid)
+    def merge_index_metadata(
+        self,
+        index_uuid: str,
+        index_type: Union[
+            Literal["BTREE"],
+            Literal["INVERTED"],
+        ],
+        prefetch_batch: Optional[int] = None,
+    ):
+        """
+        Merge the partition files of an index that has not been committed yet.
+
+        Parameters
+        ----------
+        index_uuid: str
+            The uuid of the index to merge.
+        index_type: Literal["BTREE", "INVERTED"]
+            The type of the index.
+        prefetch_batch: int, optional
+            The number of partition-file batches to prefetch during the merge.
+            Defaults to 1.
+        """
+        index_type = index_type.upper()
+        if index_type not in [
+            "BTREE",
+            "INVERTED",
+        ]:
+            raise NotImplementedError(
+                (
+                    'Only "BTREE" or "INVERTED" are supported for '
+                    f"merge index metadata. Received {index_type}"
+                )
+            )
+        return self._ds.merge_index_metadata(index_uuid, index_type, prefetch_batch)
 
     def session(self) -> Session:
         """
diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi
index c2a72b7b1b5..0bae8e2f1aa 100644
--- a/python/python/lance/lance/__init__.pyi
+++ b/python/python/lance/lance/__init__.pyi
@@ -282,7 +282,9 @@
     ): ...
     def drop_index(self, name: str): ...
     def prewarm_index(self, name: str): ...
-    def merge_index_metadata(self, index_uuid: str): ...
+    def merge_index_metadata(
+        self, index_uuid: str, index_type: str, prefetch_batch: Optional[int] = None
+    ): ...
     def count_fragments(self) -> int: ...
     def num_small_files(self, max_rows_per_group: int) -> int: ...
    def get_fragments(self) -> List[_Fragment]: ...
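The pieces above combine into a three-step distributed build: each worker
creates a per-fragment B-tree index under one shared index UUID, a single
merge_index_metadata() call k-way merges the resulting partition files, and a
CreateIndex commit publishes the merged index. A condensed sketch of that flow,
distilled from the tests added below (the dataset path, column name, and index
name are illustrative; the calls are the actual API exercised by the tests):

    import uuid

    import lance
    from lance.dataset import Index

    ds = lance.dataset("/tmp/demo.lance")  # illustrative path
    index_id = str(uuid.uuid4())
    index_name = "btree_idx"  # illustrative name

    # Step 1: build one partial index per fragment (typically fanned out to workers).
    for fragment in ds.get_fragments():
        ds.create_scalar_index(
            column="id",
            index_type="BTREE",
            name=index_name,
            replace=False,
            fragment_uuid=index_id,
            fragment_ids=[fragment.fragment_id],
        )

    # Step 2: k-way merge the partition page/lookup files into the final index files.
    ds.merge_index_metadata(index_id, index_type="BTREE", prefetch_batch=1)

    # Step 3: publish the merged index with a CreateIndex commit.
    index = Index(
        uuid=index_id,
        name=index_name,
        fields=[ds.schema.get_field_index("id")],
        dataset_version=ds.version,
        fragment_ids={f.fragment_id for f in ds.get_fragments()},
        index_version=0,
    )
    op = lance.LanceOperation.CreateIndex(new_indices=[index], removed_indices=[])
    lance.LanceDataset.commit(ds.uri, op, read_version=ds.version)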
diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index c2370a17a9e..5f92a1e4d11 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -1982,7 +1982,7 @@ def build_distributed_fts_index( ) # Merge the inverted index metadata - dataset.merge_index_metadata(index_id) + dataset.merge_index_metadata(index_id, index_type="INVERTED") # Create Index object for commit field_id = dataset.schema.get_field_index(column) @@ -2856,7 +2856,7 @@ def test_distribute_fts_index_build(tmp_path): print(f"Fragment {fragment_id} index created successfully") # Merge the inverted index metadata - ds.merge_index_metadata(index_id) + ds.merge_index_metadata(index_id, index_type="INVERTED") # Create an Index object using the new dataclass format from lance.dataset import Index @@ -2983,3 +2983,415 @@ def test_backward_compatibility_no_fragment_ids(tmp_path): results = ds.scanner(full_text_query=search_word).to_table() assert results.num_rows > 0 + + +def test_distribute_btree_index_build(tmp_path): + """ + Test distributed B-tree index build similar to test_distribute_fts_index_build. + This test creates B-tree indices on individual fragments and then + commits them as a single index. + """ + # Generate test dataset with multiple fragments + ds = generate_multi_fragment_dataset( + tmp_path, num_fragments=4, rows_per_fragment=10000 + ) + + import uuid + + index_id = str(uuid.uuid4()) + print(f"Using index ID: {index_id}") + index_name = "btree_multiple_fragment_idx" + + fragments = ds.get_fragments() + fragment_ids = [fragment.fragment_id for fragment in fragments] + print(f"Fragment IDs: {fragment_ids}") + + for fragment in ds.get_fragments(): + fragment_id = fragment.fragment_id + print(f"Creating B-tree index for fragment {fragment_id}") + + # Create B-tree scalar index for each fragment + # Use the same index_name for all fragments (like in FTS test) + ds.create_scalar_index( + column="id", # Use integer column for B-tree + index_type="BTREE", + name=index_name, + replace=False, + fragment_uuid=index_id, + fragment_ids=[fragment_id], + ) + + # For fragment-level indexing, we expect the method to return successfully + # but not commit the index yet + print(f"Fragment {fragment_id} B-tree index created successfully") + + # Merge the B-tree index metadata + ds.merge_index_metadata(index_id, index_type="BTREE") + print(ds.uri) + + # Create an Index object using the new dataclass format + from lance.dataset import Index + + # Get the schema field for the indexed column + field_id = ds.schema.get_field_index("id") + + index = Index( + uuid=index_id, + name=index_name, + fields=[field_id], # Use field index instead of field object + dataset_version=ds.version, + fragment_ids=set(fragment_ids), + index_version=0, + ) + + # Create the index operation + create_index_op = lance.LanceOperation.CreateIndex( + new_indices=[index], + removed_indices=[], + ) + + # Commit the index + ds_committed = lance.LanceDataset.commit( + ds.uri, + create_index_op, + read_version=ds.version, + ) + + print("Successfully committed multiple fragment B-tree index") + + # Verify the index was created and is functional + indices = ds_committed.list_indices() + assert len(indices) > 0, "No indices found after commit" + + # Find our index + our_index = None + for idx in indices: + if idx["name"] == index_name: + our_index = idx + break + + assert our_index is not None, f"Index '{index_name}' not found in indices list" + assert our_index["type"] == 
"BTree", ( + f"Expected BTree index, got {our_index['type']}" + ) + + # Test that the index works for searching + # Test exact equality queries + test_id = 100 # Should be in first fragment + results = ds_committed.scanner( + filter=f"id = {test_id}", + columns=["id", "text"], + ).to_table() + + print(f"Search for id = {test_id} returned {results.num_rows} results") + assert results.num_rows > 0, f"No results found for id = {test_id}" + + # Test range queries across fragments + results_range = ds_committed.scanner( + filter="id >= 200 AND id < 800", + columns=["id", "text"], + ).to_table() + + print(f"Range query returned {results_range.num_rows} results") + assert results_range.num_rows > 0, "No results found for range query" + + # Compare with complete index results to ensure consistency + # Create a reference dataset with complete index + reference_ds = generate_multi_fragment_dataset( + tmp_path / "reference", num_fragments=4, rows_per_fragment=10000 + ) + + # Create complete B-tree index for comparison + reference_ds.create_scalar_index( + column="id", + index_type="BTREE", + name="reference_btree_idx", + ) + + # Compare exact query results + reference_results = reference_ds.scanner( + filter=f"id = {test_id}", + columns=["id", "text"], + ).to_table() + + assert results.num_rows == reference_results.num_rows, ( + f"Distributed index returned {results.num_rows} results, " + f"but complete index returned {reference_results.num_rows} results" + ) + + # Compare range query results + reference_range_results = reference_ds.scanner( + filter="id >= 200 AND id < 800", + columns=["id", "text"], + ).to_table() + + assert results_range.num_rows == reference_range_results.num_rows, ( + f"Distributed index range query returned {results_range.num_rows} results, " + f"but complete index returned {reference_range_results.num_rows} results" + ) + + +def test_btree_precise_query_comparison(tmp_path): + """ + Precise comparison test between fragment-level B-tree index and complete + B-tree index. + This test creates identical datasets and compares query results in detail. 
+ """ + # Test configuration + num_fragments = 3 + rows_per_fragment = 10000 + total_rows = num_fragments * rows_per_fragment + + print( + f"Creating datasets with {num_fragments} fragments," + f" {rows_per_fragment} rows each" + ) + + # Create dataset for fragment-level indexing + fragment_ds = generate_multi_fragment_dataset( + tmp_path / "fragment", + num_fragments=num_fragments, + rows_per_fragment=rows_per_fragment, + ) + + # Create dataset for complete indexing (same data structure) + complete_ds = generate_multi_fragment_dataset( + tmp_path / "complete", + num_fragments=num_fragments, + rows_per_fragment=rows_per_fragment, + ) + + import uuid + + # Build fragment-level B-tree index + fragment_index_id = str(uuid.uuid4()) + fragment_index_name = "fragment_btree_precise_test" + + fragments = fragment_ds.get_fragments() + fragment_ids = [fragment.fragment_id for fragment in fragments] + print(f"Fragment IDs: {fragment_ids}") + + # Create fragment-level indices + for fragment in fragments: + fragment_id = fragment.fragment_id + print(f"Creating B-tree index for fragment {fragment_id}") + + fragment_ds.create_scalar_index( + column="id", + index_type="BTREE", + name=fragment_index_name, + replace=False, + fragment_uuid=fragment_index_id, + fragment_ids=[fragment_id], + ) + + # Merge fragment indices + fragment_ds.merge_index_metadata(fragment_index_id, index_type="BTREE") + + # Create Index object for fragment-based index + from lance.dataset import Index + + field_id = fragment_ds.schema.get_field_index("id") + + fragment_index = Index( + uuid=fragment_index_id, + name=fragment_index_name, + fields=[field_id], + dataset_version=fragment_ds.version, + fragment_ids=set(fragment_ids), + index_version=0, + ) + + # Commit fragment-based index + create_fragment_index_op = lance.LanceOperation.CreateIndex( + new_indices=[fragment_index], + removed_indices=[], + ) + + fragment_ds_committed = lance.LanceDataset.commit( + fragment_ds.uri, + create_fragment_index_op, + read_version=fragment_ds.version, + ) + + # Build complete B-tree index + complete_index_name = "complete_btree_precise_test" + complete_ds.create_scalar_index( + column="id", + index_type="BTREE", + name=complete_index_name, + ) + + print("Both indices created successfully") + + # Detailed query comparison tests + test_cases = [ + # Test 1: Boundary values at fragment edges + {"name": "First value", "filter": "id = 0"}, + {"name": "Fragment 0 last value", "filter": f"id = {rows_per_fragment - 1}"}, + {"name": "Fragment 1 first value", "filter": f"id = {rows_per_fragment}"}, + { + "name": "Fragment 1 last value", + "filter": f"id = {2 * rows_per_fragment - 1}", + }, + {"name": "Fragment 2 first value", "filter": f"id = {2 * rows_per_fragment}"}, + {"name": "Last value", "filter": f"id = {total_rows - 1}"}, + # Test 2: Values in the middle of fragments + {"name": "Fragment 0 middle", "filter": f"id = {rows_per_fragment // 2}"}, + { + "name": "Fragment 1 middle", + "filter": f"id = {rows_per_fragment + rows_per_fragment // 2}", + }, + { + "name": "Fragment 2 middle", + "filter": f"id = {2 * rows_per_fragment + rows_per_fragment // 2}", + }, + # Test 3: Range queries within single fragments + {"name": "Range within fragment 0", "filter": "id >= 10 AND id < 20"}, + { + "name": "Range within fragment 1", + "filter": f"id >= {rows_per_fragment + 10}" + f" AND id < {rows_per_fragment + 20}", + }, + { + "name": "Range within fragment 2", + "filter": f"id >= {2 * rows_per_fragment + 10}" + f" AND id < {2 * rows_per_fragment + 20}", + }, + # 
Test 4: Range queries spanning multiple fragments + { + "name": "Cross fragment 0-1", + "filter": f"id >= {rows_per_fragment - 5} AND id < {rows_per_fragment + 5}", + }, + { + "name": "Cross fragment 1-2", + "filter": f"id >= {2 * rows_per_fragment - 5}" + f" AND id < {2 * rows_per_fragment + 5}", + }, + { + "name": "Cross all fragments", + "filter": f"id >= {rows_per_fragment // 2} AND" + f" id < {2 * rows_per_fragment + rows_per_fragment // 2}", + }, + # Test 5: Edge cases + {"name": "Non-existent small value", "filter": "id = -1"}, + {"name": "Non-existent large value", "filter": f"id = {total_rows + 100}"}, + {"name": "Large range", "filter": f"id >= 0 AND id < {total_rows}"}, + # Test 6: Comparison operators + {"name": "Less than boundary", "filter": f"id < {rows_per_fragment}"}, + { + "name": "Greater than boundary", + "filter": f"id > {2 * rows_per_fragment - 1}", + }, + {"name": "Less than or equal", "filter": f"id <= {rows_per_fragment + 50}"}, + {"name": "Greater than or equal", "filter": f"id >= {rows_per_fragment + 50}"}, + ] + + print(f"\nRunning {len(test_cases)} detailed comparison tests:") + + for i, test_case in enumerate(test_cases, 1): + test_name = test_case["name"] + filter_expr = test_case["filter"] + + print(f" {i:2d}. Testing {test_name}: {filter_expr}") + + # Query fragment-based index + fragment_results = fragment_ds_committed.scanner( + filter=filter_expr, + columns=["id", "text"], + ).to_table() + + # Query complete index + complete_results = complete_ds.scanner( + filter=filter_expr, + columns=["id", "text"], + ).to_table() + + # Compare row counts + assert fragment_results.num_rows == complete_results.num_rows, ( + f"Test '{test_name}' failed: Fragment index " + f"returned {fragment_results.num_rows} rows, " + f"but complete index returned {complete_results.num_rows}" + f" rows for filter: {filter_expr}" + ) + + # Compare actual results if there are any + if fragment_results.num_rows > 0: + # Sort both results by id for comparison + fragment_ids = sorted(fragment_results.column("id").to_pylist()) + complete_ids = sorted(complete_results.column("id").to_pylist()) + + assert fragment_ids == complete_ids, ( + f"Test '{test_name}' failed: Fragment index" + f" returned different IDs than complete index. " + f"Fragment IDs:" + f" {fragment_ids[:10]}{'...' if len(fragment_ids) > 10 else ''}, " + f"Complete IDs:" + f" {complete_ids[:10]}{'...' if len(complete_ids) > 10 else ''}" + ) + + print(f" āœ“ Passed ({fragment_results.num_rows} rows)") + + print(f"\nāœ… All {len(test_cases)} precision tests passed!") + print( + "Fragment-level B-tree index produces identical results" + " to complete B-tree index." + ) + + +def test_btree_fragment_ids_parameter_validation(tmp_path): + """ + Test validation of fragment_ids parameter for B-tree indices. 
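+    A valid fragment id should succeed; a non-existent fragment id is allowed to
+    fail, as long as the error is surfaced cleanly.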
+ """ + ds = generate_multi_fragment_dataset( + tmp_path, num_fragments=2, rows_per_fragment=10000 + ) + + # Test with valid fragment IDs + fragments = ds.get_fragments() + valid_fragment_id = fragments[0].fragment_id + + # This should work without errors + ds.create_scalar_index( + column="id", + index_type="BTREE", + fragment_ids=[valid_fragment_id], + ) + + # Test with invalid fragment ID (should handle gracefully) + try: + ds.create_scalar_index( + column="id", + index_type="BTREE", + fragment_ids=[999999], # Non-existent fragment ID + ) + except Exception as e: + # It's acceptable for this to fail with an appropriate error + print(f"Expected error for invalid fragment ID: {e}") + + +def test_btree_backward_compatibility_no_fragment_ids(tmp_path): + """ + Test that B-tree indexing remains backward compatible + when fragment_ids is not provided. + """ + ds = generate_multi_fragment_dataset( + tmp_path, num_fragments=2, rows_per_fragment=10000 + ) + + # This should work exactly as before (full dataset indexing) + ds.create_scalar_index( + column="id", + index_type="BTREE", + name="full_dataset_btree_idx", + ) + + # Verify the index was created + indices = ds.list_indices() + assert len(indices) == 1 + assert indices[0]["name"] == "full_dataset_btree_idx" + assert indices[0]["type"] == "BTree" + + # Test that the index works + results = ds.scanner(filter="id = 50").to_table() + assert results.num_rows > 0 diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 83da7f26dc9..f3d0fd10e83 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1670,47 +1670,111 @@ impl Dataset { .infer_error() } - #[pyo3(signature = (index_uuid))] - fn merge_index_metadata(&self, index_uuid: &str) -> PyResult<()> { + #[pyo3(signature = (index_uuid, index_type, prefetch_batch))] + fn merge_index_metadata( + &self, + index_uuid: &str, + index_type: &str, + prefetch_batch: Option, + ) -> PyResult<()> { RT.block_on(None, async { + let index_type = index_type.to_uppercase(); + let idx_type = match index_type.as_str() { + "BTREE" => IndexType::BTree, + "INVERTED" => IndexType::Inverted, + _ => { + return Err(Error::InvalidInput { + source: format!( + "Index type {} is not supported.", + index_type + ).into(), + location: location!(), + }); + } + }; + let store = LanceIndexStore::from_dataset_for_new(self.ds.as_ref(), index_uuid)?; let index_dir = self.ds.indices_dir().child(index_uuid); + if idx_type == IndexType::Inverted { + // List all partition metadata files in the index directory + let mut part_metadata_files = Vec::new(); + let mut list_stream = self.ds.object_store().list(Some(index_dir.clone())); + + while let Some(item) = list_stream.next().await { + match item { + Ok(meta) => { + let file_name = meta.location.filename().unwrap_or_default(); + // Filter files matching the pattern part_*_metadata.lance + if file_name.starts_with("part_") && file_name.ends_with("_metadata.lance") + { + part_metadata_files.push(file_name.to_string()); + } + } + Err(_) => continue, + } + } + + if part_metadata_files.is_empty() { + return Err(Error::InvalidInput { + source: format!( + "No partition metadata files found in index directory: {}", + index_dir + ) + .into(), + location: location!(), + }); + } - // List all partition metadata files in the index directory - let mut part_metadata_files = Vec::new(); - let mut list_stream = self.ds.object_store().list(Some(index_dir.clone())); - - while let Some(item) = list_stream.next().await { - match item { - Ok(meta) => { - let file_name = 
meta.location.filename().unwrap_or_default();
-                            // Filter files matching the pattern part_*_metadata.lance
-                            if file_name.starts_with("part_") && file_name.ends_with("_metadata.lance")
-                            {
-                                part_metadata_files.push(file_name.to_string());
-                            }
-                        }
-                        Err(_) => continue,
-                    }
-                }
-
-                if part_metadata_files.is_empty() {
-                    return Err(Error::InvalidInput {
-                        source: format!(
-                            "No partition metadata files found in index directory: {}",
-                            index_dir
-                        )
-                        .into(),
-                        location: location!(),
-                    });
-                }
-
-                // Call merge_metadata_files function for inverted index
-                lance_index::scalar::inverted::builder::merge_metadata_files(
-                    Arc::new(store),
-                    &part_metadata_files,
-                )
-                .await
+                // Call merge_metadata_files function for inverted index
+                lance_index::scalar::inverted::builder::merge_metadata_files(
+                    Arc::new(store),
+                    &part_metadata_files,
+                )
+                .await
+            } else {
+                // List all partition page / lookup files in the index directory
+                let mut part_page_files = Vec::new();
+                let mut part_lookup_files = Vec::new();
+                let mut list_stream = self.ds.object_store().list(Some(index_dir.clone()));
+
+                while let Some(item) = list_stream.next().await {
+                    match item {
+                        Ok(meta) => {
+                            let file_name = meta.location.filename().unwrap_or_default();
+                            // Filter files matching part_*_page_data.lance / part_*_page_lookup.lance
+                            if file_name.starts_with("part_") && file_name.ends_with("_page_data.lance")
+                            {
+                                part_page_files.push(file_name.to_string());
+                            }
+                            if file_name.starts_with("part_") && file_name.ends_with("_page_lookup.lance")
+                            {
+                                part_lookup_files.push(file_name.to_string());
+                            }
+                        }
+                        Err(_) => continue,
+                    }
+                }
+
+                if part_page_files.is_empty() || part_lookup_files.is_empty() {
+                    return Err(Error::InvalidInput {
+                        source: format!(
+                            "No partition page/lookup files found in index directory: {} (page_files: {}, lookup_files: {})",
+                            index_dir, part_page_files.len(), part_lookup_files.len()
+                        )
+                        .into(),
+                        location: location!(),
+                    });
+                }
+
+                // Call merge_metadata_files function for btree index
+                lance_index::scalar::btree::merge_metadata_files(
+                    Arc::new(store),
+                    &part_page_files,
+                    &part_lookup_files,
+                    prefetch_batch,
+                ).await
+            }
+        })?
.map_err(|err| PyValueError::new_err(err.to_string())) } diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs index 1b5d0d530bd..09dde5297b4 100644 --- a/rust/lance-index/src/scalar/bitmap.rs +++ b/rust/lance-index/src/scalar/bitmap.rs @@ -528,6 +528,7 @@ impl ScalarIndexPlugin for BitmapIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, _request: Box, + _fragment_ids: Option>, ) -> Result { Self::train_bitmap_index(data, index_store).await?; Ok(CreatedIndex { diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index b2760fe214d..d48efeebb18 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -4,10 +4,10 @@ use std::{ any::Any, cmp::Ordering, - collections::{BTreeMap, BinaryHeap, HashMap}, + collections::{BTreeMap, BinaryHeap, HashMap, VecDeque}, fmt::{Debug, Display}, ops::Bound, - sync::Arc, + sync::{Arc, LazyLock}, }; use super::{ @@ -38,7 +38,7 @@ use deepsize::DeepSizeOf; use futures::{ future::BoxFuture, stream::{self}, - FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, + Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, }; use lance_core::{ cache::{CacheKey, LanceCache}, @@ -54,16 +54,37 @@ use lance_datafusion::{ chunker::chunk_concat_stream, exec::{execute_plan, LanceExecutionOptions, OneShotExec}, }; -use log::debug; +use log::{debug, warn}; use roaring::RoaringBitmap; use serde::{Deserialize, Serialize, Serializer}; use snafu::location; +use tokio::runtime::{Builder, Runtime}; use tracing::info; const BTREE_LOOKUP_NAME: &str = "page_lookup.lance"; const BTREE_PAGES_NAME: &str = "page_data.lance"; pub const DEFAULT_BTREE_BATCH_SIZE: u64 = 4096; const BATCH_SIZE_META_KEY: &str = "batch_size"; + +/// Global thread pool for B-tree prefetch operations +static BTREE_PREFETCH_RUNTIME: LazyLock = LazyLock::new(|| { + Builder::new_multi_thread() + .worker_threads(get_num_compute_intensive_cpus()) + .max_blocking_threads(get_num_compute_intensive_cpus()) + .thread_name("lance-btree-prefetch") + .enable_time() + .build() + .expect("Failed to create B-tree prefetch runtime") +}); + +/// Spawn a prefetch task on the B-tree thread pool +fn spawn_btree_prefetch(future: F) -> tokio::task::JoinHandle +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + BTREE_PREFETCH_RUNTIME.spawn(future) +} const BTREE_INDEX_VERSION: u32 = 0; pub(crate) const BTREE_VALUES_COLUMN: &str = "values"; pub(crate) const BTREE_IDS_COLUMN: &str = "ids"; @@ -1231,6 +1252,7 @@ impl ScalarIndex for BTreeIndex { self.sub_index.as_ref(), dest_store, DEFAULT_BTREE_BATCH_SIZE, + None, ) .await?; @@ -1366,10 +1388,33 @@ pub async fn train_btree_index( sub_index_trainer: &dyn BTreeSubIndex, index_store: &dyn IndexStore, batch_size: u64, + fragment_ids: Option>, ) -> Result<()> { - let mut sub_index_file = index_store - .new_index_file(BTREE_PAGES_NAME, sub_index_trainer.schema().clone()) - .await?; + let fragment_mask = fragment_ids.as_ref().and_then(|frag_ids| { + if !frag_ids.is_empty() { + // Create a mask with fragment_id in high 32 bits for distributed indexing + // This mask is used to filter partitions belonging to specific fragments + // If multiple fragments processed, use first fragment_id <<32 as mask + Some((frag_ids[0] as u64) << 32) + } else { + None + } + }); + + let mut sub_index_file; + if fragment_mask.is_none() { + sub_index_file = index_store + .new_index_file(BTREE_PAGES_NAME, sub_index_trainer.schema().clone()) + .await?; 
+    } else {
+        sub_index_file = index_store
+            .new_index_file(
+                part_page_data_file_path(fragment_mask.unwrap()).as_str(),
+                sub_index_trainer.schema().clone(),
+            )
+            .await?;
+    }
+
     let mut encoded_batches = Vec::new();
     let mut batch_idx = 0;
@@ -1393,170 +1438,1053 @@
     file_schema
         .metadata
         .insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string());
-    let mut btree_index_file = index_store
-        .new_index_file(BTREE_LOOKUP_NAME, Arc::new(file_schema))
-        .await?;
+    let mut btree_index_file;
+    if fragment_mask.is_none() {
+        btree_index_file = index_store
+            .new_index_file(BTREE_LOOKUP_NAME, Arc::new(file_schema))
+            .await?;
+    } else {
+        btree_index_file = index_store
+            .new_index_file(
+                part_lookup_file_path(fragment_mask.unwrap()).as_str(),
+                Arc::new(file_schema),
+            )
+            .await?;
+    }
     btree_index_file.write_record_batch(record_batch).await?;
     btree_index_file.finish().await?;
     Ok(())
 }
 
-/// A stream that reads the original training data back out of the index
-///
-/// This is used for updating the index
-struct IndexReaderStream {
-    reader: Arc<dyn IndexReader>,
-    batch_size: u64,
-    num_batches: u32,
-    batch_idx: u32,
+/// Extract partition ID from partition file name
+/// Expected format: "part_{partition_id}_{suffix}.lance"
+fn extract_partition_id(filename: &str) -> Result<u64> {
+    if !filename.starts_with("part_") {
+        return Err(Error::Internal {
+            message: format!("Invalid partition file name format: {}", filename),
+            location: location!(),
+        });
+    }
+
+    let parts: Vec<&str> = filename.split('_').collect();
+    if parts.len() < 3 {
+        return Err(Error::Internal {
+            message: format!("Invalid partition file name format: {}", filename),
+            location: location!(),
+        });
+    }
+
+    parts[1].parse::<u64>().map_err(|_| Error::Internal {
+        message: format!("Failed to parse partition ID from filename: {}", filename),
+        location: location!(),
+    })
 }
 
-impl IndexReaderStream {
-    async fn new(reader: Arc<dyn IndexReader>, batch_size: u64) -> Self {
-        let num_batches = reader.num_batches(batch_size).await;
-        Self {
-            reader,
-            batch_size,
-            num_batches,
-            batch_idx: 0,
+/// Merge multiple partition page / lookup files into the final index files
+///
+/// In a distributed environment, each worker node writes partition page / lookup
+/// files for the partitions it processes, and this function merges them into the
+/// final `page_data.lance` / `page_lookup.lance` pair.
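+///
+/// For example, a worker indexing fragment 1 writes `part_4294967296_page_data.lance`
+/// and `part_4294967296_page_lookup.lance` (the partition id is `fragment_id << 32`,
+/// and 1 << 32 = 4294967296). After the merge, readers only ever see the standard
+/// `page_data.lance` / `page_lookup.lance` pair; the partition files are deleted.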
+pub async fn merge_metadata_files( + store: Arc, + part_page_files: &[String], + part_lookup_files: &[String], + prefetch_batch: Option, +) -> Result<()> { + if part_lookup_files.is_empty() || part_page_files.is_empty() { + return Err(Error::Internal { + message: "No partition files provided for merging".to_string(), + location: location!(), + }); + } + + // Step 1: Create lookup map for page files by partition ID + let mut page_files_map = HashMap::new(); + for page_file in part_page_files { + let partition_id = extract_partition_id(page_file)?; + page_files_map.insert(partition_id, page_file); + } + + // Step 2: Validate that all lookup files have corresponding page files + for lookup_file in part_lookup_files { + let partition_id = extract_partition_id(lookup_file)?; + if !page_files_map.contains_key(&partition_id) { + return Err(Error::Internal { + message: format!( + "No corresponding page file found for lookup file: {} (partition_id: {})", + lookup_file, partition_id + ), + location: location!(), + }); } } + + // Step 3: Extract metadata from lookup files + let first_lookup_reader = store.open_index_file(&part_lookup_files[0]).await?; + let batch_size = first_lookup_reader + .schema() + .metadata + .get(BATCH_SIZE_META_KEY) + .map(|bs| bs.parse().unwrap_or(DEFAULT_BTREE_BATCH_SIZE)) + .unwrap_or(DEFAULT_BTREE_BATCH_SIZE); + + // Get the value type from lookup schema (min column) + let lookup_batch = first_lookup_reader.read_range(0..1, None).await?; + let value_type = lookup_batch.column(0).data_type().clone(); + + // Get page schema first + let partition_id = extract_partition_id(part_lookup_files[0].as_str())?; + let page_file = page_files_map.get(&partition_id).unwrap(); + let page_reader = store.open_index_file(page_file).await?; + let page_schema = page_reader.schema().clone(); + + let arrow_schema = Arc::new(Schema::from(&page_schema)); + let mut page_file = store + .new_index_file(BTREE_PAGES_NAME, arrow_schema.clone()) + .await?; + + let mut prefetch_config = PrefetchConfig::default(); + if prefetch_batch.is_some() { + prefetch_config = prefetch_config.with_prefetch_batch(prefetch_batch.unwrap()); + } + + let lookup_entries = merge_page( + part_lookup_files, + &page_files_map, + &store, + batch_size, + &mut page_file, + arrow_schema.clone(), + prefetch_config, + ) + .await?; + + page_file.finish().await?; + + // Step 4: Generate new lookup file based on reorganized pages + // Add batch_size to schema metadata + let mut metadata = HashMap::new(); + metadata.insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string()); + + let lookup_schema_with_metadata = Arc::new(Schema::new_with_metadata( + vec![ + Field::new("min", value_type.clone(), true), + Field::new("max", value_type, true), + Field::new("null_count", DataType::UInt32, false), + Field::new("page_idx", DataType::UInt32, false), + ], + metadata, + )); + + let lookup_batch = RecordBatch::try_new( + lookup_schema_with_metadata.clone(), + vec![ + ScalarValue::iter_to_array(lookup_entries.iter().map(|(min, _, _, _)| min.clone()))?, + ScalarValue::iter_to_array(lookup_entries.iter().map(|(_, max, _, _)| max.clone()))?, + Arc::new(UInt32Array::from_iter_values( + lookup_entries + .iter() + .map(|(_, _, null_count, _)| *null_count), + )), + Arc::new(UInt32Array::from_iter_values( + lookup_entries.iter().map(|(_, _, _, page_idx)| *page_idx), + )), + ], + )?; + + let mut lookup_file = store + .new_index_file(BTREE_LOOKUP_NAME, lookup_schema_with_metadata) + .await?; + lookup_file.write_record_batch(lookup_batch).await?; + 
lookup_file.finish().await?; + + // After successfully writing the merged files, delete all partition files + // Only perform deletion after files are successfully written, ensuring debug information is not lost in case of failure + cleanup_partition_files(&store, part_lookup_files, part_page_files).await; + + Ok(()) } -impl Stream for IndexReaderStream { - type Item = BoxFuture<'static, Result>; +/// Clean up partition files after successful merge +/// +/// This function safely deletes partition lookup and page files after a successful merge operation. +/// File deletion failures are logged but do not affect the overall success of the merge operation. +async fn cleanup_partition_files( + store: &Arc, + part_lookup_files: &[String], + part_page_files: &[String], +) { + // Clean up partition lookup files + for file_name in part_lookup_files { + cleanup_single_file( + store, + file_name, + "part_", + "_page_lookup.lance", + "partition lookup", + ) + .await; + } - fn poll_next( - self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.get_mut(); - if this.batch_idx >= this.num_batches { - return std::task::Poll::Ready(None); - } - let batch_num = this.batch_idx; - this.batch_idx += 1; - let reader_copy = this.reader.clone(); - let batch_size = this.batch_size; - let read_task = async move { - reader_copy - .read_record_batch(batch_num as u64, batch_size) - .await + // Clean up partition page files + for file_name in part_page_files { + cleanup_single_file( + store, + file_name, + "part_", + "_page_data.lance", + "partition page", + ) + .await; + } +} + +/// Helper function to clean up a single partition file +/// +/// Performs safety checks on the filename pattern before attempting deletion. +async fn cleanup_single_file( + store: &Arc, + file_name: &str, + expected_prefix: &str, + expected_suffix: &str, + file_type: &str, +) { + // Ensure we only delete files that match the expected pattern (safety check) + if file_name.starts_with(expected_prefix) && file_name.ends_with(expected_suffix) { + match store.delete_index_file(file_name).await { + Ok(()) => { + debug!("Successfully deleted {} file: {}", file_type, file_name); + } + Err(e) => { + // File deletion failures should not affect the overall success of the function + // Log the error but continue processing other files + warn!( + "Failed to delete {} file '{}': {}. 
\ + This does not affect the merge operation, but may leave \ + partition files that should be cleaned up manually.", + file_type, file_name, e + ); + } } - .boxed(); - std::task::Poll::Ready(Some(read_task)) + } else { + // If the filename doesn't match the expected format, log a warning but don't attempt deletion + warn!( + "Skipping deletion of file '{}' as it does not match the expected \ + {} file pattern ({}*{})", + file_name, file_type, expected_prefix, expected_suffix + ); } } -/// Parameters for a btree index -#[derive(Debug, Serialize, Deserialize)] -pub struct BTreeParameters { - /// The number of rows to include in each zone - pub zone_size: Option, +/// Prefetch configuration for partition iterators +#[derive(Debug, Clone)] +pub struct PrefetchConfig { + /// Number of batches to prefetch ahead (0 means no prefetching) + pub prefetch_batches: usize, } -struct BTreeTrainingRequest { - parameters: BTreeParameters, - criteria: TrainingCriteria, +/// Buffer entry for prefetch queue +#[derive(Debug)] +struct BufferEntry { + batch: RecordBatch, + start_row: usize, + end_row: usize, } -impl BTreeTrainingRequest { - pub fn new(parameters: BTreeParameters) -> Self { +/// Running prefetch task information +#[derive(Debug)] +struct RunningPrefetchTask { + /// Task handle + handle: tokio::task::JoinHandle<()>, + /// Range being prefetched + range: std::ops::Range, +} + +/// Check if two ranges overlap +fn ranges_overlap(range1: &std::ops::Range, range2: &std::ops::Range) -> bool { + range1.start < range2.end && range2.start < range1.end +} + +/// Prefetch state for a partition using task-based prefetching +struct PartitionPrefetchState { + /// Queue of prefetched data + buffer: Arc>>, + /// Reader for this partition + reader: Arc, + /// Total rows in this partition + total_rows: usize, + /// Queue of running prefetch tasks with their ranges + running_tasks: Arc>>, + /// Next position to schedule for prefetch + next_prefetch_position: Arc>, +} + +/// Manager for coordinating task-based prefetch across multiple partitions +pub struct PrefetchManager { + /// Prefetch state per partition + partition_states: HashMap, + /// Prefetch configuration + config: PrefetchConfig, +} + +impl PrefetchManager { + /// Create a new prefetch manager + pub fn new(config: PrefetchConfig) -> Self { Self { - parameters, - // BTree indexes need data sorted by the value column - criteria: TrainingCriteria::new(TrainingOrdering::Values).with_row_id(), + partition_states: HashMap::new(), + config, } } -} -impl TrainingRequest for BTreeTrainingRequest { - fn as_any(&self) -> &dyn std::any::Any { - self - } + /// Initialize a partition for task-based prefetching + pub fn initialize_partition(&mut self, partition_id: u64, reader: Arc) { + let total_rows = reader.num_rows(); + let buffer = Arc::new(tokio::sync::Mutex::new(VecDeque::new())); + let running_tasks = Arc::new(tokio::sync::Mutex::new(VecDeque::new())); + let next_prefetch_position = Arc::new(tokio::sync::Mutex::new(0)); - fn criteria(&self) -> &TrainingCriteria { - &self.criteria + let state = PartitionPrefetchState { + buffer, + reader, + total_rows, + running_tasks, + next_prefetch_position, + }; + + self.partition_states.insert(partition_id, state); + debug!( + "Initialized partition {} for task-based prefetching", + partition_id + ); } -} -#[derive(Debug, Default)] -pub struct BTreeIndexPlugin; + /// Submit a prefetch task for a partition to the thread pool + pub async fn submit_prefetch_task(&self, partition_id: u64, batch_size: usize) -> Result<()> { + 
if self.config.prefetch_batches == 0 { + return Ok(()); + } -#[async_trait] -impl ScalarIndexPlugin for BTreeIndexPlugin { - fn new_training_request( - &self, - params: &str, - field: &Field, - ) -> Result> { - if field.data_type().is_nested() { - return Err(Error::InvalidInput { - source: "A btree index can only be created on a non-nested field.".into(), - location: location!(), - }); + let Some(state) = self.partition_states.get(&partition_id) else { + return Ok(()); + }; + + let reader = state.reader.clone(); + let buffer = state.buffer.clone(); + let running_tasks = state.running_tasks.clone(); + let next_prefetch_position = state.next_prefetch_position.clone(); + let total_rows = state.total_rows; + let effective_batch_size = self.config.prefetch_batches * batch_size; + + const MAX_BUFFER_SIZE: usize = 4; + const MAX_RUNNING_TASKS: usize = 2; + + // Clean up completed tasks and check limits + { + let mut tasks_guard = running_tasks.lock().await; + + // Remove completed tasks from the front + while let Some(task) = tasks_guard.front() { + if task.handle.is_finished() { + tasks_guard.pop_front(); + } else { + break; + } + } + + // Check if we have too many running tasks + if tasks_guard.len() >= MAX_RUNNING_TASKS { + debug!( + "Skipping prefetch for partition {} - too many running tasks ({})", + partition_id, + tasks_guard.len() + ); + return Ok(()); + } + + // Check if any running task already covers to the end of file + for task in tasks_guard.iter() { + if task.range.end >= total_rows { + debug!( + "Skipping prefetch for partition {} - task already covers to EOF (range {}..{})", + partition_id, task.range.start, task.range.end + ); + return Ok(()); + } + } } - let params = serde_json::from_str::(params)?; - Ok(Box::new(BTreeTrainingRequest::new(params))) - } + // Check if buffer is full + { + let buffer_guard = buffer.lock().await; + if buffer_guard.len() >= MAX_BUFFER_SIZE { + debug!( + "Skipping prefetch for partition {} - buffer full", + partition_id + ); + return Ok(()); + } + } - fn provides_exact_answer(&self) -> bool { - true - } + // Determine the next range to prefetch + let next_range = { + let mut pos_guard = next_prefetch_position.lock().await; + let start_pos = *pos_guard; - fn version(&self) -> u32 { - 0 - } + if start_pos >= total_rows { + debug!( + "Skipping prefetch for partition {} - no more data to prefetch", + partition_id + ); + return Ok(()); + } - fn new_query_parser( - &self, - index_name: String, - _index_details: &prost_types::Any, - ) -> Option> { - Some(Box::new(SargableQueryParser::new(index_name, false))) - } + let end_pos = std::cmp::min(start_pos + effective_batch_size, total_rows); + *pos_guard = end_pos; // Update next prefetch position + start_pos..end_pos + }; - async fn train_index( - &self, - data: SendableRecordBatchStream, - index_store: &dyn IndexStore, - request: Box, - ) -> Result { - let request = request - .as_any() - .downcast_ref::() - .unwrap(); - let value_type = data - .schema() - .field_with_name(VALUE_COLUMN_NAME)? 
- .data_type() - .clone(); - let flat_index_trainer = FlatIndexMetadata::new(value_type); - train_btree_index( - data, - &flat_index_trainer, - index_store, - request - .parameters - .zone_size - .unwrap_or(DEFAULT_BTREE_BATCH_SIZE), - ) - .await?; - Ok(CreatedIndex { - index_details: prost_types::Any::from_msg(&pb::BTreeIndexDetails::default()).unwrap(), - index_version: BTREE_INDEX_VERSION, - }) + // Check if this range is already being prefetched + { + let tasks_guard = running_tasks.lock().await; + + // Check for range overlap + for task in tasks_guard.iter() { + if ranges_overlap(&task.range, &next_range) { + debug!( + "Skipping prefetch for partition {} - range {}..{} overlaps with running task {}..{}", + partition_id, next_range.start, next_range.end, task.range.start, task.range.end + ); + return Ok(()); + } + } + } + + // All checks passed, create the actual prefetch task (only this part is async) + let range_clone = next_range.clone(); + let running_tasks_for_cleanup = running_tasks.clone(); + + let prefetch_task = spawn_btree_prefetch(async move { + // Perform the actual read + match reader.read_range(range_clone.clone(), None).await { + Ok(batch) => { + let entry = BufferEntry { + batch, + start_row: range_clone.start, + end_row: range_clone.end, + }; + + // Add to buffer + { + let mut buffer_guard = buffer.lock().await; + buffer_guard.push_back(entry); + } + + debug!( + "Prefetched {} rows ({}..{}) for partition {}", + range_clone.end - range_clone.start, + range_clone.start, + range_clone.end, + partition_id + ); + } + Err(err) => { + warn!( + "Prefetch task failed for partition {} range {}..{}: {}", + partition_id, range_clone.start, range_clone.end, err + ); + } + } + + // Remove this task from running tasks when completed + { + let mut tasks_guard = running_tasks_for_cleanup.lock().await; + tasks_guard.retain(|task| !task.handle.is_finished()); + } + }); + + // Add the task to running tasks + { + let mut tasks_guard = running_tasks.lock().await; + tasks_guard.push_back(RunningPrefetchTask { + handle: prefetch_task, + range: next_range.clone(), + }); + } + + debug!( + "Submitted prefetch task for partition {} range {}..{}", + partition_id, next_range.start, next_range.end + ); + + Ok(()) } - async fn load_index( + /// Get data from buffer or fallback to direct read + pub async fn get_data_with_fallback( &self, - index_store: Arc, - _index_details: &prost_types::Any, - frag_reuse_index: Option>, - cache: LanceCache, - ) -> Result> { - Ok(BTreeIndex::load(index_store, frag_reuse_index, cache).await? 
as Arc) + partition_id: u64, + start_row: usize, + end_row: usize, + ) -> Result { + if let Some(state) = self.partition_states.get(&partition_id) { + // First try to get from buffer + { + let mut buffer_guard = state.buffer.lock().await; + + // Remove outdated entries from the front + while let Some(entry) = buffer_guard.front() { + if entry.end_row <= start_row { + buffer_guard.pop_front(); + } else { + break; + } + } + + // Check if we have suitable data in buffer + if let Some(entry) = buffer_guard.front() { + if entry.start_row <= start_row && entry.end_row >= end_row { + // Found matching data, extract it + let entry = buffer_guard.pop_front().unwrap(); + drop(buffer_guard); + + let slice_start = start_row - entry.start_row; + let slice_len = end_row - start_row; + + debug!( + "Using buffered data for partition {} ({}..{})", + partition_id, start_row, end_row + ); + + return Ok(entry.batch.slice(slice_start, slice_len)); + } + } + } + + // Fallback to direct read + debug!( + "Direct read fallback for partition {} ({}..{})", + partition_id, start_row, end_row + ); + + state.reader.read_range(start_row..end_row, None).await + } else { + Err(Error::Internal { + message: format!("Partition {} not found in prefetch manager", partition_id), + location: location!(), + }) + } + } +} + +impl Default for PrefetchConfig { + fn default() -> Self { + Self { + prefetch_batches: 1, + } + } +} + +impl PrefetchConfig { + /// Set the prefetch batch count + pub fn with_prefetch_batch(&self, batch_count: usize) -> Self { + Self { + prefetch_batches: batch_count, + } + } +} + +/// Simplified partition iterator with immediate loading since all partitions need to be accessed +struct PartitionIterator { + reader: Arc, + current_batch: Option, + current_position: usize, + rows_read: usize, + partition_id: u64, + batch_size: u64, +} + +impl PartitionIterator { + async fn new( + store: Arc, + page_file_name: String, + partition_id: u64, + batch_size: u64, + ) -> Result { + let reader = store.open_index_file(&page_file_name).await?; + Ok(Self { + reader, + current_batch: None, + current_position: 0, + rows_read: 0, + partition_id, + batch_size, + }) + } + + /// Get the next element, working with the prefetch manager + async fn next( + &mut self, + prefetch_manager: &PrefetchManager, + ) -> Result> { + // Load new batch if current one is exhausted + if self.needs_new_batch() { + if self.rows_read >= self.reader.num_rows() { + return Ok(None); + } + self.load_next_batch(prefetch_manager).await?; + + // Submit next prefetch task + if let Err(err) = prefetch_manager + .submit_prefetch_task(self.partition_id, self.batch_size as usize) + .await + { + warn!( + "Failed to submit prefetch task for partition {}: {}", + self.partition_id, err + ); + } + } else { + // Check if we've read half of the current batch, submit next prefetch task + let batch_half = self.current_batch.as_ref().unwrap().num_rows() / 2; + if self.current_position == batch_half && batch_half > 0 { + if let Err(err) = prefetch_manager + .submit_prefetch_task(self.partition_id, self.batch_size as usize) + .await + { + warn!( + "Failed to submit prefetch task for partition {}: {}", + self.partition_id, err + ); + } + } + } + + // Extract next value from current batch + if let Some(batch) = &self.current_batch { + let value = ScalarValue::try_from_array(batch.column(0), self.current_position)?; + let row_id = ScalarValue::try_from_array(batch.column(1), self.current_position)?; + self.current_position += 1; + self.rows_read += 1; + Ok(Some((value, 
row_id))) + } else { + Ok(None) + } + } + + /// Check if we need to load a new batch + fn needs_new_batch(&self) -> bool { + self.current_batch.is_none() + || self.current_position >= self.current_batch.as_ref().unwrap().num_rows() + } + + async fn load_next_batch(&mut self, prefetch_manager: &PrefetchManager) -> Result<()> { + let remaining_rows = self.reader.num_rows() - self.rows_read; + if remaining_rows == 0 { + self.current_batch = None; + return Ok(()); + } + + let rows_to_read = std::cmp::min(self.batch_size as usize, remaining_rows); + let end_row = self.rows_read + rows_to_read; + + // Use the new fallback mechanism - try buffer first, then direct read + let batch = prefetch_manager + .get_data_with_fallback(self.partition_id, self.rows_read, end_row) + .await?; + + self.current_batch = Some(batch); + self.current_position = 0; + + Ok(()) + } + + fn get_reader(&self) -> Arc { + self.reader.clone() + } +} + +/// Heap elements, used for priority queues in multi-way merging +#[derive(Debug)] +struct HeapElement { + value: ScalarValue, + row_id: ScalarValue, + partition_id: u64, +} + +impl PartialEq for HeapElement { + fn eq(&self, other: &Self) -> bool { + self.value.eq(&other.value) + } +} + +impl Eq for HeapElement {} + +impl PartialOrd for HeapElement { + fn partial_cmp(&self, other: &Self) -> Option { + // Note: BinaryHeap is a maximum heap, we need a minimum heap, + // so reverse the comparison result + other.value.partial_cmp(&self.value) + } +} + +impl Ord for HeapElement { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap_or(Ordering::Equal) + } +} + +async fn merge_page( + part_lookup_files: &[String], + page_files_map: &HashMap, + store: &Arc, + batch_size: u64, + page_file: &mut Box, + arrow_schema: Arc, + prefetch_config: PrefetchConfig, +) -> Result> { + let mut lookup_entries = Vec::new(); + let mut page_idx = 0u32; + + debug!( + "Starting multi-way merge with {} partitions using prefetch manager", + part_lookup_files.len() + ); + + // Create prefetch manager + let mut prefetch_manager = PrefetchManager::new(prefetch_config.clone()); + + // Directly create iterators and read first element + let mut partition_map = HashMap::new(); + let mut heap = BinaryHeap::new(); + + debug!("Initializing {} partitions", part_lookup_files.len()); + + // Initialize all partitions + for lookup_file in part_lookup_files { + let partition_id = extract_partition_id(lookup_file)?; + let page_file_name = page_files_map + .get(&partition_id) + .ok_or_else(|| Error::Internal { + message: format!("Page file not found for partition ID: {}", partition_id), + location: location!(), + })? 
+ .to_string(); + + let mut iterator = + PartitionIterator::new(store.clone(), page_file_name, partition_id, batch_size).await?; + + // Initialize partition in prefetch manager + let reader = iterator.get_reader(); + prefetch_manager.initialize_partition(partition_id, reader); + + // Submit initial prefetch task + if let Err(err) = prefetch_manager + .submit_prefetch_task(partition_id, batch_size as usize) + .await + { + warn!( + "Failed to submit prefetch task for partition {}: {}", + partition_id, err + ); + } + + let first_element = iterator.next(&prefetch_manager).await?; + + if let Some((value, row_id)) = first_element { + // Put the first element into the heap + heap.push(HeapElement { + value, + row_id, + partition_id, + }); + } + + partition_map.insert(partition_id, iterator); + } + + debug!( + "Initialized {} partitions, heap size: {}", + partition_map.len(), + heap.len() + ); + + let mut current_batch_rows = Vec::with_capacity(batch_size as usize); + let mut total_merged = 0usize; + + // Multi-way merge main loop + while let Some(min_element) = heap.pop() { + // Add current minimum element to batch + current_batch_rows.push((min_element.value, min_element.row_id)); + total_merged += 1; + + // Read next element from corresponding partition + if let Some(iterator) = partition_map.get_mut(&min_element.partition_id) { + if let Some((next_value, next_row_id)) = iterator.next(&prefetch_manager).await? { + heap.push(HeapElement { + value: next_value, + row_id: next_row_id, + partition_id: min_element.partition_id, + }); + } + } + + // Write when batch reaches specified size + if current_batch_rows.len() >= batch_size as usize { + write_batch_and_lookup_entry( + &mut current_batch_rows, + page_file, + &arrow_schema, + &mut lookup_entries, + &mut page_idx, + ) + .await?; + } + } + + // Write the remaining data + if !current_batch_rows.is_empty() { + write_batch_and_lookup_entry( + &mut current_batch_rows, + page_file, + &arrow_schema, + &mut lookup_entries, + &mut page_idx, + ) + .await?; + } + + debug!( + "Completed multi-way merge: merged {} rows into {} lookup entries", + total_merged, + lookup_entries.len() + ); + Ok(lookup_entries) +} + +/// Helper function to prepare batch data in parallel +async fn prepare_batch_data( + batch_rows: Vec<(ScalarValue, ScalarValue)>, + arrow_schema: Arc, + page_idx: u32, +) -> Result<(RecordBatch, (ScalarValue, ScalarValue, u32, u32))> { + if batch_rows.is_empty() { + return Err(Error::Internal { + message: "Cannot prepare empty batch".to_string(), + location: location!(), + }); + } + + // Parallelize data preparation + let (values, row_ids): (Vec<_>, Vec<_>) = batch_rows.into_iter().unzip(); + + // Convert to arrays in parallel using rayon or manually spawn tasks + let values_array = ScalarValue::iter_to_array(values.into_iter())?; + let row_ids_array = ScalarValue::iter_to_array(row_ids.into_iter())?; + + let batch = RecordBatch::try_new(arrow_schema, vec![values_array, row_ids_array])?; + + // Calculate min/max/null_count for lookup entry + let min_val = ScalarValue::try_from_array(batch.column(0), 0)?; + let max_val = ScalarValue::try_from_array(batch.column(0), batch.num_rows() - 1)?; + let null_count = batch.column(0).null_count() as u32; + + let lookup_entry = (min_val, max_val, null_count, page_idx); + + Ok((batch, lookup_entry)) +} + +/// Helper function to write a batch and create lookup entry +async fn write_batch_and_lookup_entry( + batch_rows: &mut Vec<(ScalarValue, ScalarValue)>, + page_file: &mut Box, + arrow_schema: &Arc, + 
lookup_entries: &mut Vec<(ScalarValue, ScalarValue, u32, u32)>, + page_idx: &mut u32, +) -> Result<()> { + if batch_rows.is_empty() { + return Ok(()); + } + + // Take ownership of the batch data + let batch_data = std::mem::take(batch_rows); + let current_page_idx = *page_idx; + + // Prepare batch data + let (batch, lookup_entry) = + prepare_batch_data(batch_data, arrow_schema.clone(), current_page_idx).await?; + + lookup_entries.push(lookup_entry); + page_file.write_record_batch(batch).await?; + *page_idx += 1; + + Ok(()) +} + +pub(crate) fn part_page_data_file_path(partition_id: u64) -> String { + format!("part_{}_{}", partition_id, BTREE_PAGES_NAME) +} + +pub(crate) fn part_lookup_file_path(partition_id: u64) -> String { + format!("part_{}_{}", partition_id, BTREE_LOOKUP_NAME) +} + +/// A stream that reads the original training data back out of the index +/// +/// This is used for updating the index +struct IndexReaderStream { + reader: Arc, + batch_size: u64, + num_batches: u32, + batch_idx: u32, +} + +impl IndexReaderStream { + async fn new(reader: Arc, batch_size: u64) -> Self { + let num_batches = reader.num_batches(batch_size).await; + Self { + reader, + batch_size, + num_batches, + batch_idx: 0, + } + } +} + +impl Stream for IndexReaderStream { + type Item = BoxFuture<'static, Result>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + if this.batch_idx >= this.num_batches { + return std::task::Poll::Ready(None); + } + let batch_num = this.batch_idx; + this.batch_idx += 1; + let reader_copy = this.reader.clone(); + let batch_size = this.batch_size; + let read_task = async move { + reader_copy + .read_record_batch(batch_num as u64, batch_size) + .await + } + .boxed(); + std::task::Poll::Ready(Some(read_task)) + } +} + +/// Parameters for a btree index +#[derive(Debug, Serialize, Deserialize)] +pub struct BTreeParameters { + /// The number of rows to include in each zone + pub zone_size: Option, +} + +struct BTreeTrainingRequest { + parameters: BTreeParameters, + criteria: TrainingCriteria, +} + +impl BTreeTrainingRequest { + pub fn new(parameters: BTreeParameters) -> Self { + Self { + parameters, + // BTree indexes need data sorted by the value column + criteria: TrainingCriteria::new(TrainingOrdering::Values).with_row_id(), + } + } +} + +impl TrainingRequest for BTreeTrainingRequest { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn criteria(&self) -> &TrainingCriteria { + &self.criteria + } +} + +#[derive(Debug, Default)] +pub struct BTreeIndexPlugin; + +#[async_trait] +impl ScalarIndexPlugin for BTreeIndexPlugin { + fn new_training_request( + &self, + params: &str, + field: &Field, + ) -> Result> { + if field.data_type().is_nested() { + return Err(Error::InvalidInput { + source: "A btree index can only be created on a non-nested field.".into(), + location: location!(), + }); + } + + let params = serde_json::from_str::(params)?; + Ok(Box::new(BTreeTrainingRequest::new(params))) + } + + fn provides_exact_answer(&self) -> bool { + true + } + + fn version(&self) -> u32 { + 0 + } + + fn new_query_parser( + &self, + index_name: String, + _index_details: &prost_types::Any, + ) -> Option> { + Some(Box::new(SargableQueryParser::new(index_name, false))) + } + + async fn train_index( + &self, + data: SendableRecordBatchStream, + index_store: &dyn IndexStore, + request: Box, + fragment_ids: Option>, + ) -> Result { + let request = request + .as_any() + .downcast_ref::() + .unwrap(); + 
let value_type = data + .schema() + .field_with_name(VALUE_COLUMN_NAME)? + .data_type() + .clone(); + let flat_index_trainer = FlatIndexMetadata::new(value_type); + train_btree_index( + data, + &flat_index_trainer, + index_store, + request + .parameters + .zone_size + .unwrap_or(DEFAULT_BTREE_BATCH_SIZE), + fragment_ids, + ) + .await?; + Ok(CreatedIndex { + index_details: prost_types::Any::from_msg(&pb::BTreeIndexDetails::default()).unwrap(), + index_version: BTREE_INDEX_VERSION, + }) + } + + async fn load_index( + &self, + index_store: Arc, + _index_details: &prost_types::Any, + frag_reuse_index: Option>, + cache: LanceCache, + ) -> Result> { + Ok(BTreeIndex::load(index_store, frag_reuse_index, cache).await? as Arc) } } @@ -1594,184 +2522,861 @@ mod tests { }, }; - use super::{train_btree_index, OrderableScalarValue}; + use super::{ + part_lookup_file_path, part_page_data_file_path, train_btree_index, OrderableScalarValue, + DEFAULT_BTREE_BATCH_SIZE, + }; + + #[test] + fn test_scalar_value_size() { + let size_of_i32 = OrderableScalarValue(ScalarValue::Int32(Some(0))).deep_size_of(); + let size_of_many_i32 = OrderableScalarValue(ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::from_iter_primitive::( + vec![Some(vec![Some(0); 128])], + 128, + ), + ))) + .deep_size_of(); + + // deep_size_of should account for the rust type overhead + assert!(size_of_i32 > 4); + assert!(size_of_many_i32 > 128 * 4); + } + + #[tokio::test] + async fn test_null_ids() { + let tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + // Generate 50,000 rows of random data with 80% nulls + let stream = gen_batch() + .col( + "value", + array::rand::().with_nulls(&[true, false, false, false, false]), + ) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(5000), BatchCount::from(10)); + let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + + train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 5000, None) + .await + .unwrap(); + + let index = BTreeIndex::load(test_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + assert_eq!(index.page_lookup.null_pages.len(), 10); + + let remap_dir = Arc::new(tempdir().unwrap()); + let remap_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(remap_dir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + // Remap with a no-op mapping. 
The remapped index should be identical to the original + index + .remap(&HashMap::default(), remap_store.as_ref()) + .await + .unwrap(); + + let remap_index = BTreeIndex::load(remap_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + assert_eq!(remap_index.page_lookup, index.page_lookup); + + let original_pages = test_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); + let remapped_pages = remap_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); + + assert_eq!(original_pages.num_rows(), remapped_pages.num_rows()); + + let original_data = original_pages + .read_record_batch(0, original_pages.num_rows() as u64) + .await + .unwrap(); + let remapped_data = remapped_pages + .read_record_batch(0, remapped_pages.num_rows() as u64) + .await + .unwrap(); + + assert_eq!(original_data, remapped_data); + } + + #[tokio::test] + async fn test_nan_ordering() { + let tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let values = vec![ + 0.0, + 1.0, + 2.0, + 3.0, + f64::NAN, + f64::NEG_INFINITY, + f64::INFINITY, + ]; + + // This is a bit overkill but we've had bugs in the past where DF's sort + // didn't agree with Arrow's sort so we do an end-to-end test here + // and use DF to sort the data like we would in a real dataset. + let data = gen_batch() + .col("value", array::cycle::(values.clone())) + .col("_rowid", array::step::()) + .into_df_exec(RowCount::from(10), BatchCount::from(100)); + let schema = data.schema(); + let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); + let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); + let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); + let stream = break_stream(stream, 64); + let stream = stream.map_err(DataFusionError::from); + let stream = + Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; + + let sub_index_trainer = FlatIndexMetadata::new(DataType::Float64); + + train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64, None) + .await + .unwrap(); + + let index = BTreeIndex::load(test_store, None, LanceCache::no_cache()) + .await + .unwrap(); + + for (idx, value) in values.into_iter().enumerate() { + let query = SargableQuery::Equals(ScalarValue::Float64(Some(value))); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + assert_eq!( + result, + SearchResult::Exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7))) + ); + } + } + + #[tokio::test] + async fn test_page_cache() { + let tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let data = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_exec(RowCount::from(1000), BatchCount::from(10)); + let schema = data.schema(); + let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); + let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); + let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); + let stream = break_stream(stream, 64); + let stream = stream.map_err(DataFusionError::from); + let stream = + Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as 
SendableRecordBatchStream; + let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + + train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64, None) + .await + .unwrap(); + + let index = BTreeIndex::load( + test_store, + None, + LanceCache::with_capacity(100 * 1024 * 1024), + ) + .await + .unwrap(); + + let query = SargableQuery::Equals(ScalarValue::Float32(Some(0.0))); + let metrics = LocalMetricsCollector::default(); + let query1 = index.search(&query, &metrics); + let query2 = index.search(&query, &metrics); + tokio::join!(query1, query2).0.unwrap(); + assert_eq!(metrics.parts_loaded.load(Ordering::Relaxed), 1); + } + + /// Test that fragment-based btree index construction produces exactly the same results as building a complete index + #[tokio::test] + async fn test_fragment_btree_index_consistency() { + // Setup stores for both indexes + let full_tmpdir = Arc::new(tempdir().unwrap()); + let full_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(full_tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let fragment_tmpdir = Arc::new(tempdir().unwrap()); + let fragment_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(fragment_tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let sub_index_trainer = FlatIndexMetadata::new(DataType::Int32); + + // Method 1: Build complete index directly using the same data + // Create deterministic data for comparison - use 2 * DEFAULT_BTREE_BATCH_SIZE for testing + let total_count = (2 * DEFAULT_BTREE_BATCH_SIZE) as u64; + let full_data_gen = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(total_count / 2), BatchCount::from(2)); + let full_data_source = Box::pin(RecordBatchStreamAdapter::new( + full_data_gen.schema(), + full_data_gen, + )); + + train_btree_index( + full_data_source, + &sub_index_trainer, + full_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + None, + ) + .await + .unwrap(); + + // Method 2: Build fragment-based index using the same data split into fragments + // Create fragment 1 index - first half of the data (0 to DEFAULT_BTREE_BATCH_SIZE-1) + let half_count = DEFAULT_BTREE_BATCH_SIZE; + let fragment1_gen = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(half_count), BatchCount::from(1)); + let fragment1_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment1_gen.schema(), + fragment1_gen, + )); + + train_btree_index( + fragment1_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![1]), // fragment_id = 1 + ) + .await + .unwrap(); + + // Create fragment 2 index - second half of the data (DEFAULT_BTREE_BATCH_SIZE to 2*DEFAULT_BTREE_BATCH_SIZE-1) + let start_val = DEFAULT_BTREE_BATCH_SIZE as i32; + let end_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let values_second_half: Vec = (start_val..end_val).collect(); + let row_ids_second_half: Vec = (start_val as u64..end_val as u64).collect(); + let fragment2_gen = gen_batch() + .col("value", array::cycle::(values_second_half)) + .col("_rowid", array::cycle::(row_ids_second_half)) + .into_df_stream(RowCount::from(half_count), BatchCount::from(1)); + let fragment2_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment2_gen.schema(), + fragment2_gen, + )); + + train_btree_index( + fragment2_data_source, + &sub_index_trainer, + 
fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![2]), // fragment_id = 2 + ) + .await + .unwrap(); + + // Merge the fragment files + let part_page_files = vec![ + part_page_data_file_path(1 << 32), + part_page_data_file_path(2 << 32), + ]; + + let part_lookup_files = vec![ + part_lookup_file_path(1 << 32), + part_lookup_file_path(2 << 32), + ]; + + super::merge_metadata_files( + fragment_store.clone(), + &part_page_files, + &part_lookup_files, + Option::from(1usize), + ) + .await + .unwrap(); + + // Load both indexes + let full_index = BTreeIndex::load(full_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + let merged_index = BTreeIndex::load(fragment_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + // Test queries one by one to identify the exact problem + + // Test 1: Query for value 0 (should be in first page) + let query_0 = SargableQuery::Equals(ScalarValue::Int32(Some(0))); + let full_result_0 = full_index + .search(&query_0, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_0 = merged_index + .search(&query_0, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!(full_result_0, merged_result_0, "Query for value 0 failed"); + + // Test 2: Query for value in middle of first batch (should be in first page) + let mid_first_batch = (DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_mid_first = SargableQuery::Equals(ScalarValue::Int32(Some(mid_first_batch))); + let full_result_mid_first = full_index + .search(&query_mid_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_mid_first = merged_index + .search(&query_mid_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_mid_first, merged_result_mid_first, + "Query for value {} failed", + mid_first_batch + ); + + // Test 3: Query for first value in second batch (should be in second page) + let first_second_batch = DEFAULT_BTREE_BATCH_SIZE as i32; + let query_first_second = + SargableQuery::Equals(ScalarValue::Int32(Some(first_second_batch))); + let full_result_first_second = full_index + .search(&query_first_second, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_first_second = merged_index + .search(&query_first_second, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_first_second, merged_result_first_second, + "Query for value {} failed", + first_second_batch + ); - #[test] - fn test_scalar_value_size() { - let size_of_i32 = OrderableScalarValue(ScalarValue::Int32(Some(0))).deep_size_of(); - let size_of_many_i32 = OrderableScalarValue(ScalarValue::FixedSizeList(Arc::new( - FixedSizeListArray::from_iter_primitive::( - vec![Some(vec![Some(0); 128])], - 128, - ), - ))) - .deep_size_of(); + // Test 4: Query for value in middle of second batch (should be in second page) + let mid_second_batch = (DEFAULT_BTREE_BATCH_SIZE + DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_mid_second = SargableQuery::Equals(ScalarValue::Int32(Some(mid_second_batch))); - // deep_size_of should account for the rust type overhead - assert!(size_of_i32 > 4); - assert!(size_of_many_i32 > 128 * 4); + let full_result_mid_second = full_index + .search(&query_mid_second, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_mid_second = merged_index + .search(&query_mid_second, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_mid_second, merged_result_mid_second, + "Query for value {} failed", + mid_second_batch + ); } #[tokio::test] - async fn test_null_ids() { - let tmpdir 
= Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( + async fn test_fragment_btree_index_boundary_queries() { + // Setup stores for both indexes + let full_tmpdir = Arc::new(tempdir().unwrap()); + let full_store = Arc::new(LanceIndexStore::new( Arc::new(ObjectStore::local()), - Path::from_filesystem_path(tmpdir.path()).unwrap(), + Path::from_filesystem_path(full_tmpdir.path()).unwrap(), Arc::new(LanceCache::no_cache()), )); - // Generate 50,000 rows of random data with 80% nulls - let stream = gen_batch() - .col( - "value", - array::rand::().with_nulls(&[true, false, false, false, false]), - ) + let fragment_tmpdir = Arc::new(tempdir().unwrap()); + let fragment_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(fragment_tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let sub_index_trainer = FlatIndexMetadata::new(DataType::Int32); + + // Use 3 * DEFAULT_BTREE_BATCH_SIZE for more comprehensive boundary testing + let total_count = (3 * DEFAULT_BTREE_BATCH_SIZE) as u64; + + // Method 1: Build complete index directly + let full_data_gen = gen_batch() + .col("value", array::step::()) .col("_rowid", array::step::()) - .into_df_stream(RowCount::from(5000), BatchCount::from(10)); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + .into_df_stream(RowCount::from(total_count / 3), BatchCount::from(3)); + let full_data_source = Box::pin(RecordBatchStreamAdapter::new( + full_data_gen.schema(), + full_data_gen, + )); + + train_btree_index( + full_data_source, + &sub_index_trainer, + full_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + None, + ) + .await + .unwrap(); + + // Method 2: Build fragment-based index using 3 fragments + // Fragment 1: 0 to DEFAULT_BTREE_BATCH_SIZE-1 + let fragment_size = DEFAULT_BTREE_BATCH_SIZE; + let fragment1_gen = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment1_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment1_gen.schema(), + fragment1_gen, + )); + + train_btree_index( + fragment1_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![1]), + ) + .await + .unwrap(); + + // Fragment 2: DEFAULT_BTREE_BATCH_SIZE to 2*DEFAULT_BTREE_BATCH_SIZE-1 + let start_val2 = DEFAULT_BTREE_BATCH_SIZE as i32; + let end_val2 = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let values_fragment2: Vec = (start_val2..end_val2).collect(); + let row_ids_fragment2: Vec = (start_val2 as u64..end_val2 as u64).collect(); + let fragment2_gen = gen_batch() + .col("value", array::cycle::(values_fragment2)) + .col("_rowid", array::cycle::(row_ids_fragment2)) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment2_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment2_gen.schema(), + fragment2_gen, + )); + + train_btree_index( + fragment2_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![2]), + ) + .await + .unwrap(); - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 5000) + // Fragment 3: 2*DEFAULT_BTREE_BATCH_SIZE to 3*DEFAULT_BTREE_BATCH_SIZE-1 + let start_val3 = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let end_val3 = (3 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let values_fragment3: Vec = (start_val3..end_val3).collect(); + let row_ids_fragment3: Vec = (start_val3 as u64..end_val3 as u64).collect(); + let 
fragment3_gen = gen_batch() + .col("value", array::cycle::(values_fragment3)) + .col("_rowid", array::cycle::(row_ids_fragment3)) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment3_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment3_gen.schema(), + fragment3_gen, + )); + + train_btree_index( + fragment3_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![3]), + ) + .await + .unwrap(); + + // Merge all fragment files + let part_page_files = vec![ + part_page_data_file_path(1 << 32), + part_page_data_file_path(2 << 32), + part_page_data_file_path(3 << 32), + ]; + + let part_lookup_files = vec![ + part_lookup_file_path(1 << 32), + part_lookup_file_path(2 << 32), + part_lookup_file_path(3 << 32), + ]; + + super::merge_metadata_files( + fragment_store.clone(), + &part_page_files, + &part_lookup_files, + Option::from(1usize), + ) + .await + .unwrap(); + + // Load both indexes + let full_index = BTreeIndex::load(full_store.clone(), None, LanceCache::no_cache()) .await .unwrap(); - let index = BTreeIndex::load(test_store.clone(), None, LanceCache::no_cache()) + let merged_index = BTreeIndex::load(fragment_store.clone(), None, LanceCache::no_cache()) .await .unwrap(); - assert_eq!(index.page_lookup.null_pages.len(), 10); + // === Boundary Value Tests === - let remap_dir = Arc::new(tempdir().unwrap()); - let remap_store = Arc::new(LanceIndexStore::new( - Arc::new(ObjectStore::local()), - Path::from_filesystem_path(remap_dir.path()).unwrap(), - Arc::new(LanceCache::no_cache()), - )); + // Test 1: Query minimum value (boundary: data start) + let query_min = SargableQuery::Equals(ScalarValue::Int32(Some(0))); + let full_result_min = full_index + .search(&query_min, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_min = merged_index + .search(&query_min, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_min, merged_result_min, + "Query for minimum value 0 failed" + ); - // Remap with a no-op mapping. 
The remapped index should be identical to the original - index - .remap(&HashMap::default(), remap_store.as_ref()) + // Test 2: Query maximum value (boundary: data end) + let max_val = (3 * DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val))); + let full_result_max = full_index + .search(&query_max, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_max = merged_index + .search(&query_max, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_max, merged_result_max, + "Query for maximum value {} failed", + max_val + ); - let remap_index = BTreeIndex::load(remap_store.clone(), None, LanceCache::no_cache()) + // Test 3: Query fragment boundary value (last value of first fragment) + let fragment1_last = (DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_frag1_last = SargableQuery::Equals(ScalarValue::Int32(Some(fragment1_last))); + let full_result_frag1_last = full_index + .search(&query_frag1_last, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_frag1_last = merged_index + .search(&query_frag1_last, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag1_last, merged_result_frag1_last, + "Query for fragment 1 last value {} failed", + fragment1_last + ); - assert_eq!(remap_index.page_lookup, index.page_lookup); + // Test 4: Query fragment boundary value (first value of second fragment) + let fragment2_first = DEFAULT_BTREE_BATCH_SIZE as i32; + let query_frag2_first = SargableQuery::Equals(ScalarValue::Int32(Some(fragment2_first))); + let full_result_frag2_first = full_index + .search(&query_frag2_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag2_first = merged_index + .search(&query_frag2_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag2_first, merged_result_frag2_first, + "Query for fragment 2 first value {} failed", + fragment2_first + ); - let original_pages = test_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); - let remapped_pages = remap_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); + // Test 5: Query fragment boundary value (last value of second fragment) + let fragment2_last = (2 * DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_frag2_last = SargableQuery::Equals(ScalarValue::Int32(Some(fragment2_last))); + let full_result_frag2_last = full_index + .search(&query_frag2_last, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag2_last = merged_index + .search(&query_frag2_last, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag2_last, merged_result_frag2_last, + "Query for fragment 2 last value {} failed", + fragment2_last + ); - assert_eq!(original_pages.num_rows(), remapped_pages.num_rows()); + // Test 6: Query fragment boundary value (first value of third fragment) + let fragment3_first = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_frag3_first = SargableQuery::Equals(ScalarValue::Int32(Some(fragment3_first))); + let full_result_frag3_first = full_index + .search(&query_frag3_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag3_first = merged_index + .search(&query_frag3_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag3_first, merged_result_frag3_first, + "Query for fragment 3 first value {} failed", + fragment3_first + ); - let original_data = original_pages - .read_record_batch(0, original_pages.num_rows() as u64) + // === Non-existent Value Tests === + + // Test 7: 
Query value below minimum + let query_below_min = SargableQuery::Equals(ScalarValue::Int32(Some(-1))); + let full_result_below = full_index + .search(&query_below_min, &NoOpMetricsCollector) .await .unwrap(); - let remapped_data = remapped_pages - .read_record_batch(0, remapped_pages.num_rows() as u64) + let merged_result_below = merged_index + .search(&query_below_min, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_below, merged_result_below, + "Query for value below minimum (-1) failed" + ); + + // Test 8: Query value above maximum + let query_above_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val + 1))); + let full_result_above = full_index + .search(&query_above_max, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_above = merged_index + .search(&query_above_max, &NoOpMetricsCollector) .await .unwrap(); + assert_eq!( + full_result_above, + merged_result_above, + "Query for value above maximum ({}) failed", + max_val + 1 + ); - assert_eq!(original_data, remapped_data); - } + // === Range Query Tests === - #[tokio::test] - async fn test_nan_ordering() { - let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( - Arc::new(ObjectStore::local()), - Path::from_filesystem_path(tmpdir.path()).unwrap(), - Arc::new(LanceCache::no_cache()), - )); + // Test 9: Cross-fragment range query (from first fragment to second fragment) + let range_start = (DEFAULT_BTREE_BATCH_SIZE - 100) as i32; + let range_end = (DEFAULT_BTREE_BATCH_SIZE + 100) as i32; + let query_cross_frag = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(range_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(range_end))), + ); + let full_result_cross = full_index + .search(&query_cross_frag, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_cross = merged_index + .search(&query_cross_frag, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_cross, merged_result_cross, + "Cross-fragment range query [{}, {}] failed", + range_start, range_end + ); - let values = vec![ - 0.0, - 1.0, - 2.0, - 3.0, - f64::NAN, - f64::NEG_INFINITY, - f64::INFINITY, - ]; + // Test 10: Range query within single fragment + let single_frag_start = 100i32; + let single_frag_end = 200i32; + let query_single_frag = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(single_frag_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(single_frag_end))), + ); + let full_result_single = full_index + .search(&query_single_frag, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_single = merged_index + .search(&query_single_frag, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_single, merged_result_single, + "Single fragment range query [{}, {}] failed", + single_frag_start, single_frag_end + ); - // This is a bit overkill but we've had bugs in the past where DF's sort - // didn't agree with Arrow's sort so we do an end-to-end test here - // and use DF to sort the data like we would in a real dataset. 
- let data = gen_batch() - .col("value", array::cycle::(values.clone())) - .col("_rowid", array::step::()) - .into_df_exec(RowCount::from(10), BatchCount::from(100)); - let schema = data.schema(); - let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); - let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); - let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); - let stream = break_stream(stream, 64); - let stream = stream.map_err(DataFusionError::from); - let stream = - Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; + // Test 11: Large range query spanning all fragments + let large_range_start = 100i32; + let large_range_end = (3 * DEFAULT_BTREE_BATCH_SIZE - 100) as i32; + let query_large_range = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(large_range_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(large_range_end))), + ); + let full_result_large = full_index + .search(&query_large_range, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_large = merged_index + .search(&query_large_range, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_large, merged_result_large, + "Large range query [{}, {}] failed", + large_range_start, large_range_end + ); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float64); + // === Range Boundary Query Tests === - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64) + // Test 12: Less than query (implemented using range query, from minimum to specified value) + let lt_val = (DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_lt = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(0))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(lt_val))), + ); + let full_result_lt = full_index + .search(&query_lt, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_lt = merged_index + .search(&query_lt, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_lt, merged_result_lt, + "Less than query (<{}) failed", + lt_val + ); - let index = BTreeIndex::load(test_store, None, LanceCache::no_cache()) + // Test 13: Greater than query (implemented using range query, from specified value to maximum) + let gt_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let max_range_val = (3 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_gt = SargableQuery::Range( + std::collections::Bound::Excluded(ScalarValue::Int32(Some(gt_val))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(max_range_val))), + ); + let full_result_gt = full_index + .search(&query_gt, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_gt = merged_index + .search(&query_gt, &NoOpMetricsCollector) .await .unwrap(); + assert_eq!( + full_result_gt, merged_result_gt, + "Greater than query (>{}) failed", + gt_val + ); - for (idx, value) in values.into_iter().enumerate() { - let query = SargableQuery::Equals(ScalarValue::Float64(Some(value))); - let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::Exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7))) - ); - } + // Test 14: Less than or equal query (implemented using range query, including boundary value) + let lte_val = (DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_lte = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(0))), + 
std::collections::Bound::Included(ScalarValue::Int32(Some(lte_val))), + ); + let full_result_lte = full_index + .search(&query_lte, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_lte = merged_index + .search(&query_lte, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_lte, merged_result_lte, + "Less than or equal query (<={}) failed", + lte_val + ); + + // Test 15: Greater than or equal query (implemented using range query, including boundary value) + let gte_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_gte = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(gte_val))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(max_range_val))), + ); + let full_result_gte = full_index + .search(&query_gte, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_gte = merged_index + .search(&query_gte, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_gte, merged_result_gte, + "Greater than or equal query (>={}) failed", + gte_val + ); + } + + #[test] + fn test_extract_partition_id() { + // Test valid partition file names + assert_eq!( + super::extract_partition_id("part_123_page_data.lance").unwrap(), + 123 + ); + assert_eq!( + super::extract_partition_id("part_456_page_lookup.lance").unwrap(), + 456 + ); + assert_eq!( + super::extract_partition_id("part_4294967296_page_data.lance").unwrap(), + 4294967296 + ); + + // Test invalid file names + assert!(super::extract_partition_id("invalid_filename.lance").is_err()); + assert!(super::extract_partition_id("part_abc_page_data.lance").is_err()); + assert!(super::extract_partition_id("part_123").is_err()); + assert!(super::extract_partition_id("part_").is_err()); } #[tokio::test] - async fn test_page_cache() { + async fn test_cleanup_partition_files() { + use crate::scalar::lance_format::LanceIndexStore; + use lance_core::cache::LanceCache; + use lance_io::object_store::ObjectStore; + use object_store::path::Path; + use std::sync::Arc; + use tempfile::tempdir; + + // Create a test store let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( + let test_store: Arc = Arc::new(LanceIndexStore::new( Arc::new(ObjectStore::local()), Path::from_filesystem_path(tmpdir.path()).unwrap(), Arc::new(LanceCache::no_cache()), )); - let data = gen_batch() - .col("value", array::step::()) - .col("_rowid", array::step::()) - .into_df_exec(RowCount::from(1000), BatchCount::from(10)); - let schema = data.schema(); - let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); - let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); - let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); - let stream = break_stream(stream, 64); - let stream = stream.map_err(DataFusionError::from); - let stream = - Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + // Test files with different patterns + let lookup_files = vec![ + "part_123_page_lookup.lance".to_string(), + "invalid_lookup_file.lance".to_string(), + "part_456_page_lookup.lance".to_string(), + ]; - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64) - .await - .unwrap(); + let page_files = vec![ + "part_123_page_data.lance".to_string(), + "invalid_page_file.lance".to_string(), + "part_456_page_data.lance".to_string(), + ]; - let index = BTreeIndex::load( - test_store, - None, - 
LanceCache::with_capacity(100 * 1024 * 1024), - ) - .await - .unwrap(); + // The cleanup function should handle both valid and invalid file patterns gracefully + // This test mainly verifies that the function doesn't panic and handles edge cases + super::cleanup_partition_files(&test_store, &lookup_files, &page_files).await; - let query = SargableQuery::Equals(ScalarValue::Float32(Some(0.0))); - let metrics = LocalMetricsCollector::default(); - let query1 = index.search(&query, &metrics); - let query2 = index.search(&query, &metrics); - tokio::join!(query1, query2).0.unwrap(); - assert_eq!(metrics.parts_loaded.load(Ordering::Relaxed), 1); + // If we get here without panicking, the cleanup function handled all cases correctly + assert!(true); } } diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs index 4edb1cb6a0a..a4506020782 100644 --- a/rust/lance-index/src/scalar/inverted.rs +++ b/rust/lance-index/src/scalar/inverted.rs @@ -163,6 +163,7 @@ impl ScalarIndexPlugin for InvertedIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result { let request = (request as Box) .downcast::() @@ -170,7 +171,8 @@ impl ScalarIndexPlugin for InvertedIndexPlugin { source: "must provide training request created by new_training_request".into(), location: location!(), })?; - Self::train_inverted_index(data, index_store, request.parameters.clone(), None).await + Self::train_inverted_index(data, index_store, request.parameters.clone(), fragment_ids) + .await } /// Load an index from storage diff --git a/rust/lance-index/src/scalar/json.rs b/rust/lance-index/src/scalar/json.rs index 0b8a43efbe7..e36feaacfc7 100644 --- a/rust/lance-index/src/scalar/json.rs +++ b/rust/lance-index/src/scalar/json.rs @@ -768,6 +768,7 @@ impl ScalarIndexPlugin for JsonIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result { let request = (request as Box) .downcast::() @@ -797,7 +798,7 @@ impl ScalarIndexPlugin for JsonIndexPlugin { )?; let target_index = target_plugin - .train_index(converted_stream, index_store, target_request) + .train_index(converted_stream, index_store, target_request, fragment_ids) .await?; let index_details = crate::pb::JsonIndexDetails { diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs index 542aa2bc97a..64e932c47c5 100644 --- a/rust/lance-index/src/scalar/label_list.rs +++ b/rust/lance-index/src/scalar/label_list.rs @@ -398,6 +398,7 @@ impl ScalarIndexPlugin for LabelListIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result { let schema = data.schema(); let field = schema @@ -427,7 +428,7 @@ impl ScalarIndexPlugin for LabelListIndexPlugin { let data = unnest_chunks(data)?; let bitmap_plugin = BitmapIndexPlugin; bitmap_plugin - .train_index(data, index_store, request) + .train_index(data, index_store, request, fragment_ids) .await?; Ok(CreatedIndex { index_details: prost_types::Any::from_msg(&pb::LabelListIndexDetails::default()) diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs index d8a95de1eeb..4df502ead09 100644 --- a/rust/lance-index/src/scalar/lance_format.rs +++ b/rust/lance-index/src/scalar/lance_format.rs @@ -368,7 +368,7 @@ pub mod tests { ) .unwrap(); btree_plugin - .train_index(data, index_store.as_ref(), request) + .train_index(data, 
index_store.as_ref(), request, None) .await .unwrap(); } @@ -866,6 +866,7 @@ pub mod tests { &sub_index_trainer, index_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, + None, ) .await .unwrap(); @@ -911,7 +912,7 @@ pub mod tests { .new_training_request("{}", &Field::new(VALUE_COLUMN_NAME, DataType::Int32, false)) .unwrap(); BitmapIndexPlugin - .train_index(data, index_store.as_ref(), request) + .train_index(data, index_store.as_ref(), request, None) .await .unwrap(); } @@ -1399,7 +1400,7 @@ pub mod tests { ) .unwrap(); LabelListIndexPlugin - .train_index(data, index_store.as_ref(), request) + .train_index(data, index_store.as_ref(), request, None) .await .unwrap(); } diff --git a/rust/lance-index/src/scalar/ngram.rs b/rust/lance-index/src/scalar/ngram.rs index ff559dd9292..586b0a4da9a 100644 --- a/rust/lance-index/src/scalar/ngram.rs +++ b/rust/lance-index/src/scalar/ngram.rs @@ -1285,6 +1285,7 @@ impl ScalarIndexPlugin for NGramIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, _request: Box, + _fragment_ids: Option>, ) -> Result { Self::train_ngram_index(data, index_store).await?; Ok(CreatedIndex { diff --git a/rust/lance-index/src/scalar/registry.rs b/rust/lance-index/src/scalar/registry.rs index 022da729f0c..3880aad4dbe 100644 --- a/rust/lance-index/src/scalar/registry.rs +++ b/rust/lance-index/src/scalar/registry.rs @@ -119,6 +119,7 @@ pub trait ScalarIndexPlugin: Send + Sync + std::fmt::Debug { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result; /// Returns true if the index returns an exact answer (e.g. not AtMost) diff --git a/rust/lance-index/src/scalar/zonemap.rs b/rust/lance-index/src/scalar/zonemap.rs index 748ab003863..c9097cbb8cc 100644 --- a/rust/lance-index/src/scalar/zonemap.rs +++ b/rust/lance-index/src/scalar/zonemap.rs @@ -961,6 +961,7 @@ impl ScalarIndexPlugin for ZoneMapIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + _fragment_ids: Option>, ) -> Result { let request = (request as Box) .downcast::() diff --git a/rust/lance/benches/scalar_index.rs b/rust/lance/benches/scalar_index.rs index 0742ff7f878..58b94f56318 100644 --- a/rust/lance/benches/scalar_index.rs +++ b/rust/lance/benches/scalar_index.rs @@ -71,6 +71,7 @@ impl BenchmarkFixture { &sub_index_trainer, index_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, + None, ) .await .unwrap(); diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index cdebc399547..ccae72a4865 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -284,12 +284,11 @@ pub(super) async fn build_scalar_index( training_request.criteria(), None, train, - fragment_ids, + fragment_ids.clone(), ) .await?; - plugin - .train_index(training_data, &index_store, training_request) + .train_index(training_data, &index_store, training_request, fragment_ids) .await } From 1125dc19770e65ec56426d2fdb5f22967905d7d8 Mon Sep 17 00:00:00 2001 From: xloya Date: Fri, 5 Sep 2025 12:48:21 +0800 Subject: [PATCH 02/13] support btree distributely --- java/.gitignore | 3 + python/python/lance/dataset.py | 36 +- python/python/lance/lance/__init__.pyi | 4 +- python/python/tests/test_scalar_index.py | 416 +++- python/src/dataset.rs | 126 +- rust/lance-index/src/scalar/bitmap.rs | 1 + rust/lance-index/src/scalar/btree.rs | 2185 ++++++++++++++++--- rust/lance-index/src/scalar/inverted.rs | 4 +- rust/lance-index/src/scalar/json.rs | 3 +- rust/lance-index/src/scalar/label_list.rs | 3 +- 
 rust/lance-index/src/scalar/lance_format.rs | 7 +-
 rust/lance-index/src/scalar/ngram.rs | 1 +
 rust/lance-index/src/scalar/registry.rs | 1 +
 rust/lance-index/src/scalar/zonemap.rs | 1 +
 rust/lance/benches/scalar_index.rs | 1 +
 rust/lance/src/index/scalar.rs | 5 +-
 16 files changed, 2462 insertions(+), 335 deletions(-)

diff --git a/java/.gitignore b/java/.gitignore
index d9074bd2835..f134c3c1a74 100644
--- a/java/.gitignore
+++ b/java/.gitignore
@@ -1,2 +1,5 @@
 *.iml
 .java-version
+.project
+.settings
+.classpath
\ No newline at end of file
diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index 376476e43af..54bd432201a 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -2731,8 +2731,38 @@ def prewarm_index(self, name: str):
         """
         return self._ds.prewarm_index(name)
 
-    def merge_index_metadata(self, index_uuid: str):
-        return self._ds.merge_index_metadata(index_uuid)
+    def merge_index_metadata(
+        self,
+        index_uuid: str,
+        index_type: Union[
+            Literal["BTREE"],
+            Literal["INVERTED"],
+        ],
+        prefetch_batch: Optional[int] = None,
+    ):
+        """
+        Merge a partially built index that has not yet been committed.
+
+        Parameters
+        ----------
+        index_uuid: str
+            The uuid of the index to merge.
+        index_type: Literal["BTREE", "INVERTED"]
+            The type of the index.
+        prefetch_batch: int, optional
+            The number of sub-page file batches to prefetch while merging.
+            Defaults to 1.
+        """
+        index_type = index_type.upper()
+        if index_type not in [
+            "BTREE",
+            "INVERTED",
+        ]:
+            raise NotImplementedError(
+                'Only "BTREE" or "INVERTED" are supported for '
+                f"merge index metadata. Received {index_type}"
+            )
+        return self._ds.merge_index_metadata(index_uuid, index_type, prefetch_batch)
 
     def session(self) -> Session:
         """
diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi
index c2a72b7b1b5..0bae8e2f1aa 100644
--- a/python/python/lance/lance/__init__.pyi
+++ b/python/python/lance/lance/__init__.pyi
@@ -282,7 +282,9 @@ class _Dataset:
     ): ...
     def drop_index(self, name: str): ...
     def prewarm_index(self, name: str): ...
-    def merge_index_metadata(self, index_uuid: str): ...
+    def merge_index_metadata(
+        self, index_uuid: str, index_type: str, prefetch_batch: Optional[int] = None
+    ): ...
     def count_fragments(self) -> int: ...
     def num_small_files(self, max_rows_per_group: int) -> int: ...
     def get_fragments(self) -> List[_Fragment]: ...
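For reference, the distributed build flow this API enables looks roughly like the
following sketch. This is hypothetical driver code mirroring the tests in this
patch; it assumes `ds` is an already-open `lance.LanceDataset` with an integer
`id` column, and `btree_distributed_idx` is just an illustrative name:

    import uuid

    import lance
    from lance.dataset import Index

    index_id = str(uuid.uuid4())
    index_name = "btree_distributed_idx"
    fragment_ids = [f.fragment_id for f in ds.get_fragments()]

    # Fan-out: each worker trains an uncommitted per-fragment index shard.
    for fragment_id in fragment_ids:
        ds.create_scalar_index(
            column="id",
            index_type="BTREE",
            name=index_name,
            replace=False,
            fragment_uuid=index_id,
            fragment_ids=[fragment_id],
        )

    # Merge: combine the part_*_page_data / part_*_page_lookup files.
    ds.merge_index_metadata(index_id, index_type="BTREE", prefetch_batch=1)

    # Commit: publish the merged index in a single CreateIndex operation.
    index = Index(
        uuid=index_id,
        name=index_name,
        fields=[ds.schema.get_field_index("id")],
        dataset_version=ds.version,
        fragment_ids=set(fragment_ids),
        index_version=0,
    )
    lance.LanceDataset.commit(
        ds.uri,
        lance.LanceOperation.CreateIndex(new_indices=[index], removed_indices=[]),
        read_version=ds.version,
    )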
diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index c2370a17a9e..5f92a1e4d11 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -1982,7 +1982,7 @@ def build_distributed_fts_index( ) # Merge the inverted index metadata - dataset.merge_index_metadata(index_id) + dataset.merge_index_metadata(index_id, index_type="INVERTED") # Create Index object for commit field_id = dataset.schema.get_field_index(column) @@ -2856,7 +2856,7 @@ def test_distribute_fts_index_build(tmp_path): print(f"Fragment {fragment_id} index created successfully") # Merge the inverted index metadata - ds.merge_index_metadata(index_id) + ds.merge_index_metadata(index_id, index_type="INVERTED") # Create an Index object using the new dataclass format from lance.dataset import Index @@ -2983,3 +2983,415 @@ def test_backward_compatibility_no_fragment_ids(tmp_path): results = ds.scanner(full_text_query=search_word).to_table() assert results.num_rows > 0 + + +def test_distribute_btree_index_build(tmp_path): + """ + Test distributed B-tree index build similar to test_distribute_fts_index_build. + This test creates B-tree indices on individual fragments and then + commits them as a single index. + """ + # Generate test dataset with multiple fragments + ds = generate_multi_fragment_dataset( + tmp_path, num_fragments=4, rows_per_fragment=10000 + ) + + import uuid + + index_id = str(uuid.uuid4()) + print(f"Using index ID: {index_id}") + index_name = "btree_multiple_fragment_idx" + + fragments = ds.get_fragments() + fragment_ids = [fragment.fragment_id for fragment in fragments] + print(f"Fragment IDs: {fragment_ids}") + + for fragment in ds.get_fragments(): + fragment_id = fragment.fragment_id + print(f"Creating B-tree index for fragment {fragment_id}") + + # Create B-tree scalar index for each fragment + # Use the same index_name for all fragments (like in FTS test) + ds.create_scalar_index( + column="id", # Use integer column for B-tree + index_type="BTREE", + name=index_name, + replace=False, + fragment_uuid=index_id, + fragment_ids=[fragment_id], + ) + + # For fragment-level indexing, we expect the method to return successfully + # but not commit the index yet + print(f"Fragment {fragment_id} B-tree index created successfully") + + # Merge the B-tree index metadata + ds.merge_index_metadata(index_id, index_type="BTREE") + print(ds.uri) + + # Create an Index object using the new dataclass format + from lance.dataset import Index + + # Get the schema field for the indexed column + field_id = ds.schema.get_field_index("id") + + index = Index( + uuid=index_id, + name=index_name, + fields=[field_id], # Use field index instead of field object + dataset_version=ds.version, + fragment_ids=set(fragment_ids), + index_version=0, + ) + + # Create the index operation + create_index_op = lance.LanceOperation.CreateIndex( + new_indices=[index], + removed_indices=[], + ) + + # Commit the index + ds_committed = lance.LanceDataset.commit( + ds.uri, + create_index_op, + read_version=ds.version, + ) + + print("Successfully committed multiple fragment B-tree index") + + # Verify the index was created and is functional + indices = ds_committed.list_indices() + assert len(indices) > 0, "No indices found after commit" + + # Find our index + our_index = None + for idx in indices: + if idx["name"] == index_name: + our_index = idx + break + + assert our_index is not None, f"Index '{index_name}' not found in indices list" + assert our_index["type"] == 
"BTree", ( + f"Expected BTree index, got {our_index['type']}" + ) + + # Test that the index works for searching + # Test exact equality queries + test_id = 100 # Should be in first fragment + results = ds_committed.scanner( + filter=f"id = {test_id}", + columns=["id", "text"], + ).to_table() + + print(f"Search for id = {test_id} returned {results.num_rows} results") + assert results.num_rows > 0, f"No results found for id = {test_id}" + + # Test range queries across fragments + results_range = ds_committed.scanner( + filter="id >= 200 AND id < 800", + columns=["id", "text"], + ).to_table() + + print(f"Range query returned {results_range.num_rows} results") + assert results_range.num_rows > 0, "No results found for range query" + + # Compare with complete index results to ensure consistency + # Create a reference dataset with complete index + reference_ds = generate_multi_fragment_dataset( + tmp_path / "reference", num_fragments=4, rows_per_fragment=10000 + ) + + # Create complete B-tree index for comparison + reference_ds.create_scalar_index( + column="id", + index_type="BTREE", + name="reference_btree_idx", + ) + + # Compare exact query results + reference_results = reference_ds.scanner( + filter=f"id = {test_id}", + columns=["id", "text"], + ).to_table() + + assert results.num_rows == reference_results.num_rows, ( + f"Distributed index returned {results.num_rows} results, " + f"but complete index returned {reference_results.num_rows} results" + ) + + # Compare range query results + reference_range_results = reference_ds.scanner( + filter="id >= 200 AND id < 800", + columns=["id", "text"], + ).to_table() + + assert results_range.num_rows == reference_range_results.num_rows, ( + f"Distributed index range query returned {results_range.num_rows} results, " + f"but complete index returned {reference_range_results.num_rows} results" + ) + + +def test_btree_precise_query_comparison(tmp_path): + """ + Precise comparison test between fragment-level B-tree index and complete + B-tree index. + This test creates identical datasets and compares query results in detail. 
+ """ + # Test configuration + num_fragments = 3 + rows_per_fragment = 10000 + total_rows = num_fragments * rows_per_fragment + + print( + f"Creating datasets with {num_fragments} fragments," + f" {rows_per_fragment} rows each" + ) + + # Create dataset for fragment-level indexing + fragment_ds = generate_multi_fragment_dataset( + tmp_path / "fragment", + num_fragments=num_fragments, + rows_per_fragment=rows_per_fragment, + ) + + # Create dataset for complete indexing (same data structure) + complete_ds = generate_multi_fragment_dataset( + tmp_path / "complete", + num_fragments=num_fragments, + rows_per_fragment=rows_per_fragment, + ) + + import uuid + + # Build fragment-level B-tree index + fragment_index_id = str(uuid.uuid4()) + fragment_index_name = "fragment_btree_precise_test" + + fragments = fragment_ds.get_fragments() + fragment_ids = [fragment.fragment_id for fragment in fragments] + print(f"Fragment IDs: {fragment_ids}") + + # Create fragment-level indices + for fragment in fragments: + fragment_id = fragment.fragment_id + print(f"Creating B-tree index for fragment {fragment_id}") + + fragment_ds.create_scalar_index( + column="id", + index_type="BTREE", + name=fragment_index_name, + replace=False, + fragment_uuid=fragment_index_id, + fragment_ids=[fragment_id], + ) + + # Merge fragment indices + fragment_ds.merge_index_metadata(fragment_index_id, index_type="BTREE") + + # Create Index object for fragment-based index + from lance.dataset import Index + + field_id = fragment_ds.schema.get_field_index("id") + + fragment_index = Index( + uuid=fragment_index_id, + name=fragment_index_name, + fields=[field_id], + dataset_version=fragment_ds.version, + fragment_ids=set(fragment_ids), + index_version=0, + ) + + # Commit fragment-based index + create_fragment_index_op = lance.LanceOperation.CreateIndex( + new_indices=[fragment_index], + removed_indices=[], + ) + + fragment_ds_committed = lance.LanceDataset.commit( + fragment_ds.uri, + create_fragment_index_op, + read_version=fragment_ds.version, + ) + + # Build complete B-tree index + complete_index_name = "complete_btree_precise_test" + complete_ds.create_scalar_index( + column="id", + index_type="BTREE", + name=complete_index_name, + ) + + print("Both indices created successfully") + + # Detailed query comparison tests + test_cases = [ + # Test 1: Boundary values at fragment edges + {"name": "First value", "filter": "id = 0"}, + {"name": "Fragment 0 last value", "filter": f"id = {rows_per_fragment - 1}"}, + {"name": "Fragment 1 first value", "filter": f"id = {rows_per_fragment}"}, + { + "name": "Fragment 1 last value", + "filter": f"id = {2 * rows_per_fragment - 1}", + }, + {"name": "Fragment 2 first value", "filter": f"id = {2 * rows_per_fragment}"}, + {"name": "Last value", "filter": f"id = {total_rows - 1}"}, + # Test 2: Values in the middle of fragments + {"name": "Fragment 0 middle", "filter": f"id = {rows_per_fragment // 2}"}, + { + "name": "Fragment 1 middle", + "filter": f"id = {rows_per_fragment + rows_per_fragment // 2}", + }, + { + "name": "Fragment 2 middle", + "filter": f"id = {2 * rows_per_fragment + rows_per_fragment // 2}", + }, + # Test 3: Range queries within single fragments + {"name": "Range within fragment 0", "filter": "id >= 10 AND id < 20"}, + { + "name": "Range within fragment 1", + "filter": f"id >= {rows_per_fragment + 10}" + f" AND id < {rows_per_fragment + 20}", + }, + { + "name": "Range within fragment 2", + "filter": f"id >= {2 * rows_per_fragment + 10}" + f" AND id < {2 * rows_per_fragment + 20}", + }, + # 
Test 4: Range queries spanning multiple fragments + { + "name": "Cross fragment 0-1", + "filter": f"id >= {rows_per_fragment - 5} AND id < {rows_per_fragment + 5}", + }, + { + "name": "Cross fragment 1-2", + "filter": f"id >= {2 * rows_per_fragment - 5}" + f" AND id < {2 * rows_per_fragment + 5}", + }, + { + "name": "Cross all fragments", + "filter": f"id >= {rows_per_fragment // 2} AND" + f" id < {2 * rows_per_fragment + rows_per_fragment // 2}", + }, + # Test 5: Edge cases + {"name": "Non-existent small value", "filter": "id = -1"}, + {"name": "Non-existent large value", "filter": f"id = {total_rows + 100}"}, + {"name": "Large range", "filter": f"id >= 0 AND id < {total_rows}"}, + # Test 6: Comparison operators + {"name": "Less than boundary", "filter": f"id < {rows_per_fragment}"}, + { + "name": "Greater than boundary", + "filter": f"id > {2 * rows_per_fragment - 1}", + }, + {"name": "Less than or equal", "filter": f"id <= {rows_per_fragment + 50}"}, + {"name": "Greater than or equal", "filter": f"id >= {rows_per_fragment + 50}"}, + ] + + print(f"\nRunning {len(test_cases)} detailed comparison tests:") + + for i, test_case in enumerate(test_cases, 1): + test_name = test_case["name"] + filter_expr = test_case["filter"] + + print(f" {i:2d}. Testing {test_name}: {filter_expr}") + + # Query fragment-based index + fragment_results = fragment_ds_committed.scanner( + filter=filter_expr, + columns=["id", "text"], + ).to_table() + + # Query complete index + complete_results = complete_ds.scanner( + filter=filter_expr, + columns=["id", "text"], + ).to_table() + + # Compare row counts + assert fragment_results.num_rows == complete_results.num_rows, ( + f"Test '{test_name}' failed: Fragment index " + f"returned {fragment_results.num_rows} rows, " + f"but complete index returned {complete_results.num_rows}" + f" rows for filter: {filter_expr}" + ) + + # Compare actual results if there are any + if fragment_results.num_rows > 0: + # Sort both results by id for comparison + fragment_ids = sorted(fragment_results.column("id").to_pylist()) + complete_ids = sorted(complete_results.column("id").to_pylist()) + + assert fragment_ids == complete_ids, ( + f"Test '{test_name}' failed: Fragment index" + f" returned different IDs than complete index. " + f"Fragment IDs:" + f" {fragment_ids[:10]}{'...' if len(fragment_ids) > 10 else ''}, " + f"Complete IDs:" + f" {complete_ids[:10]}{'...' if len(complete_ids) > 10 else ''}" + ) + + print(f" āœ“ Passed ({fragment_results.num_rows} rows)") + + print(f"\nāœ… All {len(test_cases)} precision tests passed!") + print( + "Fragment-level B-tree index produces identical results" + " to complete B-tree index." + ) + + +def test_btree_fragment_ids_parameter_validation(tmp_path): + """ + Test validation of fragment_ids parameter for B-tree indices. 
+ """ + ds = generate_multi_fragment_dataset( + tmp_path, num_fragments=2, rows_per_fragment=10000 + ) + + # Test with valid fragment IDs + fragments = ds.get_fragments() + valid_fragment_id = fragments[0].fragment_id + + # This should work without errors + ds.create_scalar_index( + column="id", + index_type="BTREE", + fragment_ids=[valid_fragment_id], + ) + + # Test with invalid fragment ID (should handle gracefully) + try: + ds.create_scalar_index( + column="id", + index_type="BTREE", + fragment_ids=[999999], # Non-existent fragment ID + ) + except Exception as e: + # It's acceptable for this to fail with an appropriate error + print(f"Expected error for invalid fragment ID: {e}") + + +def test_btree_backward_compatibility_no_fragment_ids(tmp_path): + """ + Test that B-tree indexing remains backward compatible + when fragment_ids is not provided. + """ + ds = generate_multi_fragment_dataset( + tmp_path, num_fragments=2, rows_per_fragment=10000 + ) + + # This should work exactly as before (full dataset indexing) + ds.create_scalar_index( + column="id", + index_type="BTREE", + name="full_dataset_btree_idx", + ) + + # Verify the index was created + indices = ds.list_indices() + assert len(indices) == 1 + assert indices[0]["name"] == "full_dataset_btree_idx" + assert indices[0]["type"] == "BTree" + + # Test that the index works + results = ds.scanner(filter="id = 50").to_table() + assert results.num_rows > 0 diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 83da7f26dc9..f3d0fd10e83 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1670,47 +1670,111 @@ impl Dataset { .infer_error() } - #[pyo3(signature = (index_uuid))] - fn merge_index_metadata(&self, index_uuid: &str) -> PyResult<()> { + #[pyo3(signature = (index_uuid, index_type, prefetch_batch))] + fn merge_index_metadata( + &self, + index_uuid: &str, + index_type: &str, + prefetch_batch: Option, + ) -> PyResult<()> { RT.block_on(None, async { + let index_type = index_type.to_uppercase(); + let idx_type = match index_type.as_str() { + "BTREE" => IndexType::BTree, + "INVERTED" => IndexType::Inverted, + _ => { + return Err(Error::InvalidInput { + source: format!( + "Index type {} is not supported.", + index_type + ).into(), + location: location!(), + }); + } + }; + let store = LanceIndexStore::from_dataset_for_new(self.ds.as_ref(), index_uuid)?; let index_dir = self.ds.indices_dir().child(index_uuid); + if idx_type == IndexType::Inverted { + // List all partition metadata files in the index directory + let mut part_metadata_files = Vec::new(); + let mut list_stream = self.ds.object_store().list(Some(index_dir.clone())); + + while let Some(item) = list_stream.next().await { + match item { + Ok(meta) => { + let file_name = meta.location.filename().unwrap_or_default(); + // Filter files matching the pattern part_*_metadata.lance + if file_name.starts_with("part_") && file_name.ends_with("_metadata.lance") + { + part_metadata_files.push(file_name.to_string()); + } + } + Err(_) => continue, + } + } + + if part_metadata_files.is_empty() { + return Err(Error::InvalidInput { + source: format!( + "No partition metadata files found in index directory: {}", + index_dir + ) + .into(), + location: location!(), + }); + } - // List all partition metadata files in the index directory - let mut part_metadata_files = Vec::new(); - let mut list_stream = self.ds.object_store().list(Some(index_dir.clone())); - - while let Some(item) = list_stream.next().await { - match item { - Ok(meta) => { - let file_name = 
meta.location.filename().unwrap_or_default(); - // Filter files matching the pattern part_*_metadata.lance - if file_name.starts_with("part_") && file_name.ends_with("_metadata.lance") - { - part_metadata_files.push(file_name.to_string()); + // Call merge_metadata_files function for inverted index + lance_index::scalar::inverted::builder::merge_metadata_files( + Arc::new(store), + &part_metadata_files, + ) + .await + } else { + // List all partition page / lookup files in the index directory + let mut part_page_files = Vec::new(); + let mut part_lookup_files = Vec::new(); + let mut list_stream = self.ds.object_store().list(Some(index_dir.clone())); + + while let Some(item) = list_stream.next().await { + match item { + Ok(meta) => { + let file_name = meta.location.filename().unwrap_or_default(); + // Filter files matching the pattern part_*_metadata.lance + if file_name.starts_with("part_") && file_name.ends_with("_page_data.lance") + { + part_page_files.push(file_name.to_string()); + } + if file_name.starts_with("part_") && file_name.ends_with("_page_lookup.lance") + { + part_lookup_files.push(file_name.to_string()); + } } + Err(_) => continue, } - Err(_) => continue, } - } + if part_page_files.is_empty() || part_lookup_files.is_empty() { + return Err(Error::InvalidInput { + source: format!( + "No partition metadata files found in index directory: {} (page_files: {}, lookup_files: {})", + index_dir, part_page_files.len(), part_lookup_files.len() + ) + .into(), + location: location!(), + }); + } - if part_metadata_files.is_empty() { - return Err(Error::InvalidInput { - source: format!( - "No partition metadata files found in index directory: {}", - index_dir - ) - .into(), - location: location!(), - }); + // Call merge_metadata_files function for btree index + lance_index::scalar::btree::merge_metadata_files( + Arc::new(store), + &part_page_files, + &part_lookup_files, + prefetch_batch, + ).await } - // Call merge_metadata_files function for inverted index - lance_index::scalar::inverted::builder::merge_metadata_files( - Arc::new(store), - &part_metadata_files, - ) - .await + })? 
        .map_err(|err| PyValueError::new_err(err.to_string()))
     }
 
diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs
index 1b5d0d530bd..09dde5297b4 100644
--- a/rust/lance-index/src/scalar/bitmap.rs
+++ b/rust/lance-index/src/scalar/bitmap.rs
@@ -528,6 +528,7 @@ impl ScalarIndexPlugin for BitmapIndexPlugin {
         data: SendableRecordBatchStream,
         index_store: &dyn IndexStore,
         _request: Box<dyn TrainingRequest>,
+        _fragment_ids: Option<Vec<u32>>,
     ) -> Result<CreatedIndex> {
         Self::train_bitmap_index(data, index_store).await?;
         Ok(CreatedIndex {
diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs
index b2760fe214d..9c3e6685703 100644
--- a/rust/lance-index/src/scalar/btree.rs
+++ b/rust/lance-index/src/scalar/btree.rs
@@ -4,10 +4,10 @@ use std::{
     any::Any,
     cmp::Ordering,
-    collections::{BTreeMap, BinaryHeap, HashMap},
+    collections::{BTreeMap, BinaryHeap, HashMap, VecDeque},
     fmt::{Debug, Display},
     ops::Bound,
-    sync::Arc,
+    sync::{Arc, LazyLock},
 };
 
 use super::{
@@ -38,7 +38,7 @@ use deepsize::DeepSizeOf;
 use futures::{
     future::BoxFuture,
     stream::{self},
-    FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
+    Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
 };
 use lance_core::{
     cache::{CacheKey, LanceCache},
@@ -54,16 +54,37 @@ use lance_datafusion::{
     chunker::chunk_concat_stream,
     exec::{execute_plan, LanceExecutionOptions, OneShotExec},
 };
-use log::debug;
+use log::{debug, warn};
 use roaring::RoaringBitmap;
 use serde::{Deserialize, Serialize, Serializer};
 use snafu::location;
+use tokio::runtime::{Builder, Runtime};
 use tracing::info;
 
 const BTREE_LOOKUP_NAME: &str = "page_lookup.lance";
 const BTREE_PAGES_NAME: &str = "page_data.lance";
 pub const DEFAULT_BTREE_BATCH_SIZE: u64 = 4096;
 const BATCH_SIZE_META_KEY: &str = "batch_size";
+
+/// Global thread pool for B-tree prefetch operations
+static BTREE_PREFETCH_RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
+    Builder::new_multi_thread()
+        .worker_threads(get_num_compute_intensive_cpus())
+        .max_blocking_threads(get_num_compute_intensive_cpus())
+        .thread_name("lance-btree-prefetch")
+        .enable_time()
+        .build()
+        .expect("Failed to create B-tree prefetch runtime")
+});
+
+/// Spawn a prefetch task on the B-tree thread pool
+fn spawn_btree_prefetch<F>(future: F) -> tokio::task::JoinHandle<F::Output>
+where
+    F: Future + Send + 'static,
+    F::Output: Send + 'static,
+{
+    BTREE_PREFETCH_RUNTIME.spawn(future)
+}
 const BTREE_INDEX_VERSION: u32 = 0;
 pub(crate) const BTREE_VALUES_COLUMN: &str = "values";
 pub(crate) const BTREE_IDS_COLUMN: &str = "ids";
@@ -1231,6 +1252,7 @@ impl ScalarIndex for BTreeIndex {
             self.sub_index.as_ref(),
             dest_store,
             DEFAULT_BTREE_BATCH_SIZE,
+            None,
         )
         .await?;
 
@@ -1366,10 +1388,33 @@ pub async fn train_btree_index(
     sub_index_trainer: &dyn BTreeSubIndex,
     index_store: &dyn IndexStore,
     batch_size: u64,
+    fragment_ids: Option<Vec<u32>>,
 ) -> Result<()> {
-    let mut sub_index_file = index_store
-        .new_index_file(BTREE_PAGES_NAME, sub_index_trainer.schema().clone())
-        .await?;
+    let fragment_mask = fragment_ids.as_ref().and_then(|frag_ids| {
+        if !frag_ids.is_empty() {
+            // Create a mask with the fragment_id in the high 32 bits for distributed indexing.
+            // The mask is used to name the partition files belonging to specific fragments.
+            // If multiple fragments are processed, the first fragment_id << 32 is used as the mask.
+            Some((frag_ids[0] as u64) << 32)
+        } else {
+            None
+        }
+    });
+
+    let mut sub_index_file;
+    if fragment_mask.is_none() {
+        sub_index_file = index_store
+            .new_index_file(BTREE_PAGES_NAME, sub_index_trainer.schema().clone())
+            .await?;
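+        // Worked example for the branch below (fragment id illustrative): fragment_ids =
+        // Some(vec![2]) gives fragment_mask = Some(2u64 << 32) = Some(8_589_934_592), so the
+        // page data is written to part_8589934592_page_data.lance instead of the default file.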
+    } else {
+        sub_index_file = index_store
+            .new_index_file(
+                part_page_data_file_path(fragment_mask.unwrap()).as_str(),
+                sub_index_trainer.schema().clone(),
+            )
+            .await?;
+    }
+
     let mut encoded_batches = Vec::new();
     let mut batch_idx = 0;
@@ -1393,385 +1438,1945 @@
     file_schema
         .metadata
         .insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string());
-    let mut btree_index_file = index_store
-        .new_index_file(BTREE_LOOKUP_NAME, Arc::new(file_schema))
-        .await?;
+    let mut btree_index_file;
+    if fragment_mask.is_none() {
+        btree_index_file = index_store
+            .new_index_file(BTREE_LOOKUP_NAME, Arc::new(file_schema))
+            .await?;
+    } else {
+        btree_index_file = index_store
+            .new_index_file(
+                part_lookup_file_path(fragment_mask.unwrap()).as_str(),
+                Arc::new(file_schema),
+            )
+            .await?;
+    }
     btree_index_file.write_record_batch(record_batch).await?;
     btree_index_file.finish().await?;
     Ok(())
 }
 
-/// A stream that reads the original training data back out of the index
-///
-/// This is used for updating the index
-struct IndexReaderStream {
-    reader: Arc<dyn IndexReader>,
-    batch_size: u64,
-    num_batches: u32,
-    batch_idx: u32,
-}
-
-impl IndexReaderStream {
-    async fn new(reader: Arc<dyn IndexReader>, batch_size: u64) -> Self {
-        let num_batches = reader.num_batches(batch_size).await;
-        Self {
-            reader,
-            batch_size,
-            num_batches,
-            batch_idx: 0,
-        }
-    }
-}
+/// Extract the partition ID from a partition file name
+/// Expected format: "part_{partition_id}_{suffix}.lance"
+fn extract_partition_id(filename: &str) -> Result<u64> {
+    if !filename.starts_with("part_") {
+        return Err(Error::Internal {
+            message: format!("Invalid partition file name format: {}", filename),
+            location: location!(),
+        });
+    }
+
+    let parts: Vec<&str> = filename.split('_').collect();
+    if parts.len() < 3 {
+        return Err(Error::Internal {
+            message: format!("Invalid partition file name format: {}", filename),
+            location: location!(),
+        });
+    }
+
+    parts[1].parse::<u64>().map_err(|_| Error::Internal {
+        message: format!("Failed to parse partition ID from filename: {}", filename),
+        location: location!(),
+    })
+}
+
+/// Merge multiple partition page / lookup files into the complete index files
+///
+/// In a distributed environment, each worker node writes partition page / lookup files for the
+/// partitions it processes; this function merges those files into the final page and lookup files.
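+///
+/// A minimal calling sketch (the store and the two partition file lists are assumed to
+/// already exist; `Some(1)` requests a one-batch prefetch):
+/// `merge_metadata_files(store, &part_page_files, &part_lookup_files, Some(1)).await?`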
+pub async fn merge_metadata_files(
+    store: Arc<dyn IndexStore>,
+    part_page_files: &[String],
+    part_lookup_files: &[String],
+    prefetch_batch: Option<usize>,
+) -> Result<()> {
+    if part_lookup_files.is_empty() || part_page_files.is_empty() {
+        return Err(Error::Internal {
+            message: "No partition files provided for merging".to_string(),
+            location: location!(),
+        });
+    }
+
+    // Step 1: Create a lookup map for page files keyed by partition ID
+    let mut page_files_map = HashMap::new();
+    for page_file in part_page_files {
+        let partition_id = extract_partition_id(page_file)?;
+        page_files_map.insert(partition_id, page_file);
+    }
+
+    // Step 2: Validate that every lookup file has a corresponding page file
+    for lookup_file in part_lookup_files {
+        let partition_id = extract_partition_id(lookup_file)?;
+        if !page_files_map.contains_key(&partition_id) {
+            return Err(Error::Internal {
+                message: format!(
+                    "No corresponding page file found for lookup file: {} (partition_id: {})",
+                    lookup_file, partition_id
+                ),
+                location: location!(),
+            });
+        }
+    }
+
+    // Step 3: Extract metadata from the lookup files
+    let first_lookup_reader = store.open_index_file(&part_lookup_files[0]).await?;
+    let batch_size = first_lookup_reader
+        .schema()
+        .metadata
+        .get(BATCH_SIZE_META_KEY)
+        .map(|bs| bs.parse().unwrap_or(DEFAULT_BTREE_BATCH_SIZE))
+        .unwrap_or(DEFAULT_BTREE_BATCH_SIZE);
+
+    // Get the value type from the lookup schema (min column)
+    let lookup_batch = first_lookup_reader.read_range(0..1, None).await?;
+    let value_type = lookup_batch.column(0).data_type().clone();
+
+    // Get the page schema first
+    let partition_id = extract_partition_id(part_lookup_files[0].as_str())?;
+    let first_page_file = page_files_map.get(&partition_id).unwrap();
+    let page_reader = store.open_index_file(first_page_file).await?;
+    let page_schema = page_reader.schema().clone();
+
+    let arrow_schema = Arc::new(Schema::from(&page_schema));
+    let mut page_file = store
+        .new_index_file(BTREE_PAGES_NAME, arrow_schema.clone())
+        .await?;
+
+    let mut prefetch_config = PrefetchConfig::default();
+    if let Some(batch_count) = prefetch_batch {
+        prefetch_config = prefetch_config.with_prefetch_batch(batch_count);
+    }
+
+    let lookup_entries = merge_page(
+        part_lookup_files,
+        &page_files_map,
+        &store,
+        batch_size,
+        &mut page_file,
+        arrow_schema.clone(),
+        prefetch_config,
+    )
+    .await?;
+
+    page_file.finish().await?;
+
+    // Step 4: Generate the new lookup file based on the reorganized pages.
+    // Add batch_size to the schema metadata.
+    let mut metadata = HashMap::new();
+    metadata.insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string());
+
+    let lookup_schema_with_metadata = Arc::new(Schema::new_with_metadata(
+        vec![
+            Field::new("min", value_type.clone(), true),
+            Field::new("max", value_type, true),
+            Field::new("null_count", DataType::UInt32, false),
+            Field::new("page_idx", DataType::UInt32, false),
+        ],
+        metadata,
+    ));
+
+    let lookup_batch = RecordBatch::try_new(
+        lookup_schema_with_metadata.clone(),
+        vec![
+            ScalarValue::iter_to_array(lookup_entries.iter().map(|(min, _, _, _)| min.clone()))?,
+            ScalarValue::iter_to_array(lookup_entries.iter().map(|(_, max, _, _)| max.clone()))?,
+            Arc::new(UInt32Array::from_iter_values(
+                lookup_entries
+                    .iter()
+                    .map(|(_, _, null_count, _)| *null_count),
+            )),
+            Arc::new(UInt32Array::from_iter_values(
+                lookup_entries.iter().map(|(_, _, _, page_idx)| *page_idx),
+            )),
+        ],
+    )?;
+
+    let mut lookup_file = store
+        .new_index_file(BTREE_LOOKUP_NAME, lookup_schema_with_metadata)
+        .await?;
+    lookup_file.write_record_batch(lookup_batch).await?;
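+
+    // Each row of the lookup file written above summarizes one merged page:
+    // (min, max, null_count, page_idx). For example (illustrative, 4096-row batch
+    // size), the first page of a dense Int32 index could be (Int32(0), Int32(4095), 0, 0).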
+    lookup_file.finish().await?;
+
+    // After successfully writing the merged files, delete all partition files.
+    // Deletion only happens after the merged files are safely persisted, so debugging
+    // information is not lost if the merge fails.
+    cleanup_partition_files(&store, part_lookup_files, part_page_files).await;
+
+    Ok(())
+}
 
-impl Stream for IndexReaderStream {
-    type Item = BoxFuture<'static, Result<RecordBatch>>;
-
-    fn poll_next(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
-        let this = self.get_mut();
-        if this.batch_idx >= this.num_batches {
-            return std::task::Poll::Ready(None);
-        }
-        let batch_num = this.batch_idx;
-        this.batch_idx += 1;
-        let reader_copy = this.reader.clone();
-        let batch_size = this.batch_size;
-        let read_task = async move {
-            reader_copy
-                .read_record_batch(batch_num as u64, batch_size)
-                .await
-        }
-        .boxed();
-        std::task::Poll::Ready(Some(read_task))
-    }
-}
-
-/// Parameters for a btree index
-#[derive(Debug, Serialize, Deserialize)]
-pub struct BTreeParameters {
-    /// The number of rows to include in each zone
-    pub zone_size: Option<u64>,
-}
+/// Clean up partition files after a successful merge
+///
+/// This function safely deletes the partition lookup and page files after a successful merge.
+/// File deletion failures are logged but do not affect the overall success of the merge.
+async fn cleanup_partition_files(
+    store: &Arc<dyn IndexStore>,
+    part_lookup_files: &[String],
+    part_page_files: &[String],
+) {
+    // Clean up partition lookup files
+    for file_name in part_lookup_files {
+        cleanup_single_file(
+            store,
+            file_name,
+            "part_",
+            "_page_lookup.lance",
+            "partition lookup",
+        )
+        .await;
+    }
+
+    // Clean up partition page files
+    for file_name in part_page_files {
+        cleanup_single_file(
+            store,
+            file_name,
+            "part_",
+            "_page_data.lance",
+            "partition page",
+        )
+        .await;
+    }
+}
+
+/// Helper function to clean up a single partition file
+///
+/// Performs safety checks on the filename pattern before attempting deletion.
+async fn cleanup_single_file(
+    store: &Arc<dyn IndexStore>,
+    file_name: &str,
+    expected_prefix: &str,
+    expected_suffix: &str,
+    file_type: &str,
+) {
+    // Ensure we only delete files that match the expected pattern (safety check)
+    if file_name.starts_with(expected_prefix) && file_name.ends_with(expected_suffix) {
+        match store.delete_index_file(file_name).await {
+            Ok(()) => {
+                debug!("Successfully deleted {} file: {}", file_type, file_name);
+            }
+            Err(e) => {
+                // File deletion failures should not affect the overall success of the merge;
+                // log the error and continue with the remaining files.
+                warn!(
+                    "Failed to delete {} file '{}': {}. \
+                     This does not affect the merge operation, but may leave \
+                     partition files that should be cleaned up manually.",
+                    file_type, file_name, e
+                );
+            }
+        }
+    } else {
+        // If the filename doesn't match the expected format, log a warning but don't attempt deletion
+        warn!(
+            "Skipping deletion of file '{}' as it does not match the expected \
+             {} file pattern ({}*{})",
+            file_name, file_type, expected_prefix, expected_suffix
+        );
+    }
+}
 
-struct BTreeTrainingRequest {
-    parameters: BTreeParameters,
-    criteria: TrainingCriteria,
-}
-
-impl BTreeTrainingRequest {
-    pub fn new(parameters: BTreeParameters) -> Self {
-        Self {
-            parameters,
-            // BTree indexes need data sorted by the value column
-            criteria: TrainingCriteria::new(TrainingOrdering::Values).with_row_id(),
-        }
-    }
-}
-
-impl TrainingRequest for BTreeTrainingRequest {
-    fn as_any(&self) -> &dyn std::any::Any {
-        self
-    }
-
-    fn criteria(&self) -> &TrainingCriteria {
-        &self.criteria
-    }
-}
-
-#[derive(Debug, Default)]
-pub struct BTreeIndexPlugin;
-
-#[async_trait]
-impl ScalarIndexPlugin for BTreeIndexPlugin {
-    fn new_training_request(
-        &self,
-        params: &str,
-        field: &Field,
-    ) -> Result<Box<dyn TrainingRequest>> {
-        if field.data_type().is_nested() {
-            return Err(Error::InvalidInput {
-                source: "A btree index can only be created on a non-nested field.".into(),
-                location: location!(),
-            });
-        }
-
-        let params = serde_json::from_str::<BTreeParameters>(params)?;
-        Ok(Box::new(BTreeTrainingRequest::new(params)))
-    }
-
-    fn provides_exact_answer(&self) -> bool {
-        true
-    }
-
-    fn version(&self) -> u32 {
-        0
-    }
-
-    fn new_query_parser(
-        &self,
-        index_name: String,
-        _index_details: &prost_types::Any,
-    ) -> Option<Box<dyn ScalarQueryParser>> {
-        Some(Box::new(SargableQueryParser::new(index_name, false)))
-    }
+/// Prefetch configuration for partition iterators
+#[derive(Debug, Clone)]
+pub struct PrefetchConfig {
+    /// Number of batches to prefetch ahead (0 means no prefetching)
+    pub prefetch_batches: usize,
+}
+
+impl Default for PrefetchConfig {
+    fn default() -> Self {
+        Self {
+            prefetch_batches: 1,
+        }
+    }
+}
+
+impl PrefetchConfig {
+    /// Set the prefetch batch count
+    pub fn with_prefetch_batch(&self, batch_count: usize) -> Self {
+        Self {
+            prefetch_batches: batch_count,
+        }
+    }
+}
+
+/// Buffer entry for the prefetch queue
+#[derive(Debug)]
+struct BufferEntry {
+    batch: RecordBatch,
+    start_row: usize,
+    end_row: usize,
+}
+
+/// Running prefetch task information
+#[derive(Debug)]
+struct RunningPrefetchTask {
+    /// Task handle
+    handle: tokio::task::JoinHandle<()>,
+    /// Range being prefetched
+    range: std::ops::Range<usize>,
+}
+
+/// Check if two ranges overlap
+fn ranges_overlap(range1: &std::ops::Range<usize>, range2: &std::ops::Range<usize>) -> bool {
+    range1.start < range2.end && range2.start < range1.end
+}
+
+/// Prefetch state for a partition using task-based prefetching
+struct PartitionPrefetchState {
+    /// Queue of prefetched data
+    buffer: Arc<tokio::sync::Mutex<VecDeque<BufferEntry>>>,
+    /// Reader for this partition
+    reader: Arc<dyn IndexReader>,
+    /// Total rows in this partition
+    total_rows: usize,
+    /// Queue of running prefetch tasks with their ranges
+    running_tasks: Arc<tokio::sync::Mutex<VecDeque<RunningPrefetchTask>>>,
+    /// Next position to schedule for prefetch
+    next_prefetch_position: Arc<tokio::sync::Mutex<usize>>,
+}
+
+/// Manager for coordinating task-based prefetch across multiple partitions
+pub struct PrefetchManager {
+    /// Prefetch state per partition
+    partition_states: HashMap<u64, PartitionPrefetchState>,
+    /// Prefetch configuration
+    config: PrefetchConfig,
+}
+
+impl PrefetchManager {
+    /// Create a new prefetch manager
+    pub fn new(config: PrefetchConfig) -> Self {
+        Self {
+            partition_states: HashMap::new(),
+            config,
+        }
+    }
+
+    /// Initialize a partition for task-based prefetching
+    pub fn
initialize_partition(&mut self, partition_id: u64, reader: Arc) { + let total_rows = reader.num_rows(); + let buffer = Arc::new(tokio::sync::Mutex::new(VecDeque::new())); + let running_tasks = Arc::new(tokio::sync::Mutex::new(VecDeque::new())); + let next_prefetch_position = Arc::new(tokio::sync::Mutex::new(0)); - async fn train_index( - &self, - data: SendableRecordBatchStream, - index_store: &dyn IndexStore, - request: Box, - ) -> Result { - let request = request - .as_any() - .downcast_ref::() - .unwrap(); - let value_type = data - .schema() - .field_with_name(VALUE_COLUMN_NAME)? - .data_type() - .clone(); - let flat_index_trainer = FlatIndexMetadata::new(value_type); - train_btree_index( - data, - &flat_index_trainer, - index_store, - request - .parameters - .zone_size - .unwrap_or(DEFAULT_BTREE_BATCH_SIZE), - ) - .await?; - Ok(CreatedIndex { - index_details: prost_types::Any::from_msg(&pb::BTreeIndexDetails::default()).unwrap(), - index_version: BTREE_INDEX_VERSION, - }) - } + let state = PartitionPrefetchState { + buffer, + reader, + total_rows, + running_tasks, + next_prefetch_position, + }; - async fn load_index( - &self, - index_store: Arc, - _index_details: &prost_types::Any, - frag_reuse_index: Option>, - cache: LanceCache, - ) -> Result> { - Ok(BTreeIndex::load(index_store, frag_reuse_index, cache).await? as Arc) + self.partition_states.insert(partition_id, state); + debug!( + "Initialized partition {} for task-based prefetching", + partition_id + ); } -} -#[cfg(test)] -mod tests { - use std::sync::atomic::Ordering; - use std::{collections::HashMap, sync::Arc}; + /// Submit a prefetch task for a partition to the thread pool + pub async fn submit_prefetch_task(&self, partition_id: u64, batch_size: usize) -> Result<()> { + if self.config.prefetch_batches == 0 { + return Ok(()); + } - use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type}; - use arrow_array::FixedSizeListArray; - use arrow_schema::DataType; - use datafusion::{ - execution::{SendableRecordBatchStream, TaskContext}, - physical_plan::{sorts::sort::SortExec, stream::RecordBatchStreamAdapter, ExecutionPlan}, - }; - use datafusion_common::{DataFusionError, ScalarValue}; - use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; - use deepsize::DeepSizeOf; - use futures::TryStreamExt; - use lance_core::{cache::LanceCache, utils::mask::RowIdTreeMap}; - use lance_datafusion::{chunker::break_stream, datagen::DatafusionDatagenExt}; - use lance_datagen::{array, gen_batch, ArrayGeneratorExt, BatchCount, RowCount}; - use lance_io::object_store::ObjectStore; - use object_store::path::Path; - use tempfile::tempdir; + let Some(state) = self.partition_states.get(&partition_id) else { + return Ok(()); + }; - use crate::metrics::LocalMetricsCollector; - use crate::{ - metrics::NoOpMetricsCollector, - scalar::{ - btree::{BTreeIndex, BTREE_PAGES_NAME}, - flat::FlatIndexMetadata, - lance_format::LanceIndexStore, - IndexStore, SargableQuery, ScalarIndex, SearchResult, - }, - }; + let reader = state.reader.clone(); + let buffer = state.buffer.clone(); + let running_tasks = state.running_tasks.clone(); + let next_prefetch_position = state.next_prefetch_position.clone(); + let total_rows = state.total_rows; + let effective_batch_size = self.config.prefetch_batches * batch_size; - use super::{train_btree_index, OrderableScalarValue}; + const MAX_BUFFER_SIZE: usize = 4; + const MAX_RUNNING_TASKS: usize = 2; - #[test] - fn test_scalar_value_size() { - let size_of_i32 = 
OrderableScalarValue(ScalarValue::Int32(Some(0))).deep_size_of(); - let size_of_many_i32 = OrderableScalarValue(ScalarValue::FixedSizeList(Arc::new( - FixedSizeListArray::from_iter_primitive::( - vec![Some(vec![Some(0); 128])], - 128, - ), - ))) - .deep_size_of(); + // Clean up completed tasks and check limits + { + let mut tasks_guard = running_tasks.lock().await; - // deep_size_of should account for the rust type overhead - assert!(size_of_i32 > 4); - assert!(size_of_many_i32 > 128 * 4); + // Remove completed tasks from the front + while let Some(task) = tasks_guard.front() { + if task.handle.is_finished() { + tasks_guard.pop_front(); + } else { + break; + } + } + + // Check if we have too many running tasks + if tasks_guard.len() >= MAX_RUNNING_TASKS { + debug!( + "Skipping prefetch for partition {} - too many running tasks ({})", + partition_id, + tasks_guard.len() + ); + return Ok(()); + } + + // Check if any running task already covers to the end of file + for task in tasks_guard.iter() { + if task.range.end >= total_rows { + debug!( + "Skipping prefetch for partition {} - task already covers to EOF (range {}..{})", + partition_id, task.range.start, task.range.end + ); + return Ok(()); + } + } + } + + // Check if buffer is full + { + let buffer_guard = buffer.lock().await; + if buffer_guard.len() >= MAX_BUFFER_SIZE { + debug!( + "Skipping prefetch for partition {} - buffer full", + partition_id + ); + return Ok(()); + } + } + + // Determine the next range to prefetch + let next_range = { + let mut pos_guard = next_prefetch_position.lock().await; + let start_pos = *pos_guard; + + if start_pos >= total_rows { + debug!( + "Skipping prefetch for partition {} - no more data to prefetch", + partition_id + ); + return Ok(()); + } + + let end_pos = std::cmp::min(start_pos + effective_batch_size, total_rows); + *pos_guard = end_pos; // Update next prefetch position + start_pos..end_pos + }; + + // Check if this range is already being prefetched + { + let tasks_guard = running_tasks.lock().await; + + // Check for range overlap + for task in tasks_guard.iter() { + if ranges_overlap(&task.range, &next_range) { + debug!( + "Skipping prefetch for partition {} - range {}..{} overlaps with running task {}..{}", + partition_id, next_range.start, next_range.end, task.range.start, task.range.end + ); + return Ok(()); + } + } + } + + // All checks passed, create the actual prefetch task (only this part is async) + let range_clone = next_range.clone(); + let running_tasks_for_cleanup = running_tasks.clone(); + + let prefetch_task = spawn_btree_prefetch(async move { + // Perform the actual read + match reader.read_range(range_clone.clone(), None).await { + Ok(batch) => { + let entry = BufferEntry { + batch, + start_row: range_clone.start, + end_row: range_clone.end, + }; + + // Add to buffer + { + let mut buffer_guard = buffer.lock().await; + buffer_guard.push_back(entry); + } + + debug!( + "Prefetched {} rows ({}..{}) for partition {}", + range_clone.end - range_clone.start, + range_clone.start, + range_clone.end, + partition_id + ); + } + Err(err) => { + warn!( + "Prefetch task failed for partition {} range {}..{}: {}", + partition_id, range_clone.start, range_clone.end, err + ); + } + } + + // Remove this task from running tasks when completed + { + let mut tasks_guard = running_tasks_for_cleanup.lock().await; + tasks_guard.retain(|task| !task.handle.is_finished()); + } + }); + + // Add the task to running tasks + { + let mut tasks_guard = running_tasks.lock().await; + 
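+            // Bookkeeping sketch: finished handles were already pruned from the front of
+            // this queue earlier in this function; the new task is pushed below together
+            // with its row range so later submissions can detect overlap via ranges_overlap().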
+            tasks_guard.push_back(RunningPrefetchTask {
+                handle: prefetch_task,
+                range: next_range.clone(),
+            });
+        }
+
+        debug!(
+            "Submitted prefetch task for partition {} range {}..{}",
+            partition_id, next_range.start, next_range.end
+        );
+
+        Ok(())
+    }
+
+    /// Get data from the buffer, or fall back to a direct read
+    pub async fn get_data_with_fallback(
+        &self,
+        partition_id: u64,
+        start_row: usize,
+        end_row: usize,
+    ) -> Result<RecordBatch> {
+        if let Some(state) = self.partition_states.get(&partition_id) {
+            // First try to get the data from the buffer
+            {
+                let mut buffer_guard = state.buffer.lock().await;
+
+                // Remove outdated entries from the front
+                while let Some(entry) = buffer_guard.front() {
+                    if entry.end_row <= start_row {
+                        buffer_guard.pop_front();
+                    } else {
+                        break;
+                    }
+                }
+
+                // Check if we have suitable data in the buffer
+                if let Some(entry) = buffer_guard.front() {
+                    if entry.start_row <= start_row && entry.end_row >= end_row {
+                        // Found matching data, extract it
+                        let entry = buffer_guard.pop_front().unwrap();
+                        drop(buffer_guard);
+
+                        let slice_start = start_row - entry.start_row;
+                        let slice_len = end_row - start_row;
+
+                        debug!(
+                            "Using buffered data for partition {} ({}..{})",
+                            partition_id, start_row, end_row
+                        );
+
+                        return Ok(entry.batch.slice(slice_start, slice_len));
+                    }
+                }
+            }
+
+            // Fall back to a direct read
+            debug!(
+                "Direct read fallback for partition {} ({}..{})",
+                partition_id, start_row, end_row
+            );
+
+            state.reader.read_range(start_row..end_row, None).await
+        } else {
+            Err(Error::Internal {
+                message: format!("Partition {} not found in prefetch manager", partition_id),
+                location: location!(),
+            })
+        }
+    }
+}
+
+/// Partition iterator that loads batches eagerly, since every partition is visited
+/// during the multi-way merge
+struct PartitionIterator {
+    reader: Arc<dyn IndexReader>,
+    current_batch: Option<RecordBatch>,
+    current_position: usize,
+    rows_read: usize,
+    partition_id: u64,
+    batch_size: u64,
+}
+
+impl PartitionIterator {
+    async fn new(
+        store: Arc<dyn IndexStore>,
+        page_file_name: String,
+        partition_id: u64,
+        batch_size: u64,
+    ) -> Result<Self> {
+        let reader = store.open_index_file(&page_file_name).await?;
+        Ok(Self {
+            reader,
+            current_batch: None,
+            current_position: 0,
+            rows_read: 0,
+            partition_id,
+            batch_size,
+        })
+    }
+
+    /// Get the next element, working with the prefetch manager
+    async fn next(
+        &mut self,
+        prefetch_manager: &PrefetchManager,
+    ) -> Result<Option<(ScalarValue, ScalarValue)>> {
+        // Load a new batch if the current one is exhausted
+        if self.needs_new_batch() {
+            if self.rows_read >= self.reader.num_rows() {
+                return Ok(None);
+            }
+            self.load_next_batch(prefetch_manager).await?;
+
+            // Submit the next prefetch task
+            if let Err(err) = prefetch_manager
+                .submit_prefetch_task(self.partition_id, self.batch_size as usize)
+                .await
+            {
+                warn!(
+                    "Failed to submit prefetch task for partition {}: {}",
+                    self.partition_id, err
+                );
+            }
+        } else {
+            // Once we have read half of the current batch, submit the next prefetch task
+            let batch_half = self.current_batch.as_ref().unwrap().num_rows() / 2;
+            if self.current_position == batch_half && batch_half > 0 {
+                if let Err(err) = prefetch_manager
+                    .submit_prefetch_task(self.partition_id, self.batch_size as usize)
+                    .await
+                {
+                    warn!(
+                        "Failed to submit prefetch task for partition {}: {}",
+                        self.partition_id, err
+                    );
+                }
+            }
+        }
+
+        // Extract the next value from the current batch
+        if let Some(batch) = &self.current_batch {
+            let value = ScalarValue::try_from_array(batch.column(0), self.current_position)?;
+            let row_id = ScalarValue::try_from_array(batch.column(1), self.current_position)?;
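+            // The column layout assumed here matches the page schema written during
+            // training: column 0 holds the indexed values and column 1 the row ids
+            // (the BTREE_VALUES_COLUMN / BTREE_IDS_COLUMN pair).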
+            self.current_position += 1;
+            self.rows_read += 1;
+            Ok(Some((value, row_id)))
+        } else {
+            Ok(None)
+        }
+    }
+
+    /// Check if we need to load a new batch
+    fn needs_new_batch(&self) -> bool {
+        self.current_batch.is_none()
+            || self.current_position >= self.current_batch.as_ref().unwrap().num_rows()
+    }
+
+    async fn load_next_batch(&mut self, prefetch_manager: &PrefetchManager) -> Result<()> {
+        let remaining_rows = self.reader.num_rows() - self.rows_read;
+        if remaining_rows == 0 {
+            self.current_batch = None;
+            return Ok(());
+        }
+
+        let rows_to_read = std::cmp::min(self.batch_size as usize, remaining_rows);
+        let end_row = self.rows_read + rows_to_read;
+
+        // Try the prefetch buffer first, then fall back to a direct read
+        let batch = prefetch_manager
+            .get_data_with_fallback(self.partition_id, self.rows_read, end_row)
+            .await?;
+
+        self.current_batch = Some(batch);
+        self.current_position = 0;
+
+        Ok(())
+    }
+
+    fn get_reader(&self) -> Arc<dyn IndexReader> {
+        self.reader.clone()
+    }
+}
+
+/// Heap element, used by the priority queue in the multi-way merge
+#[derive(Debug)]
+struct HeapElement {
+    value: ScalarValue,
+    row_id: ScalarValue,
+    partition_id: u64,
+}
+
+impl PartialEq for HeapElement {
+    fn eq(&self, other: &Self) -> bool {
+        self.value.eq(&other.value)
+    }
+}
+
+impl Eq for HeapElement {}
+
+impl PartialOrd for HeapElement {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        // Note: BinaryHeap is a max-heap and we need a min-heap,
+        // so the comparison result is reversed
+        other.value.partial_cmp(&self.value)
+    }
+}
+
+impl Ord for HeapElement {
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.partial_cmp(other).unwrap_or(Ordering::Equal)
+    }
+}
+
+async fn merge_page(
+    part_lookup_files: &[String],
+    page_files_map: &HashMap<u64, &String>,
+    store: &Arc<dyn IndexStore>,
+    batch_size: u64,
+    page_file: &mut Box<dyn IndexWriter>,
+    arrow_schema: Arc<Schema>,
+    prefetch_config: PrefetchConfig,
+) -> Result<Vec<(ScalarValue, ScalarValue, u32, u32)>> {
+    let mut lookup_entries = Vec::new();
+    let mut page_idx = 0u32;
+
+    debug!(
+        "Starting multi-way merge with {} partitions using prefetch manager",
+        part_lookup_files.len()
+    );
+
+    // Create the prefetch manager
+    let mut prefetch_manager = PrefetchManager::new(prefetch_config.clone());
+
+    // Create the iterators directly and read the first element of each partition
+    let mut partition_map = HashMap::new();
+    let mut heap = BinaryHeap::new();
+
+    debug!("Initializing {} partitions", part_lookup_files.len());
+
+    // Initialize all partitions
+    for lookup_file in part_lookup_files {
+        let partition_id = extract_partition_id(lookup_file)?;
+        let page_file_name = page_files_map
+            .get(&partition_id)
+            .ok_or_else(|| Error::Internal {
+                message: format!("Page file not found for partition ID: {}", partition_id),
+                location: location!(),
+            })?
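+        // Merge-order sketch: because HeapElement reverses its comparison, BinaryHeap::pop()
+        // in the loop below yields the smallest value first (e.g. pushing 3, 1, 2 pops
+        // 1, 2, 3), which keeps the merged page file globally sorted.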
+            .to_string();
+
+        let mut iterator =
+            PartitionIterator::new(store.clone(), page_file_name, partition_id, batch_size).await?;
+
+        // Initialize the partition in the prefetch manager
+        let reader = iterator.get_reader();
+        prefetch_manager.initialize_partition(partition_id, reader);
+
+        // Submit the initial prefetch task
+        if let Err(err) = prefetch_manager
+            .submit_prefetch_task(partition_id, batch_size as usize)
+            .await
+        {
+            warn!(
+                "Failed to submit prefetch task for partition {}: {}",
+                partition_id, err
+            );
+        }
+
+        let first_element = iterator.next(&prefetch_manager).await?;
+
+        if let Some((value, row_id)) = first_element {
+            // Put the first element into the heap
+            heap.push(HeapElement {
+                value,
+                row_id,
+                partition_id,
+            });
+        }
+
+        partition_map.insert(partition_id, iterator);
+    }
+
+    debug!(
+        "Initialized {} partitions, heap size: {}",
+        partition_map.len(),
+        heap.len()
+    );
+
+    let mut current_batch_rows = Vec::with_capacity(batch_size as usize);
+    let mut total_merged = 0usize;
+
+    // Multi-way merge main loop
+    while let Some(min_element) = heap.pop() {
+        // Add the current minimum element to the batch
+        current_batch_rows.push((min_element.value, min_element.row_id));
+        total_merged += 1;
+
+        // Read the next element from the corresponding partition
+        if let Some(iterator) = partition_map.get_mut(&min_element.partition_id) {
+            if let Some((next_value, next_row_id)) = iterator.next(&prefetch_manager).await? {
+                heap.push(HeapElement {
+                    value: next_value,
+                    row_id: next_row_id,
+                    partition_id: min_element.partition_id,
+                });
+            }
+        }
+
+        // Write when the batch reaches the specified size
+        if current_batch_rows.len() >= batch_size as usize {
+            write_batch_and_lookup_entry(
+                &mut current_batch_rows,
+                page_file,
+                &arrow_schema,
+                &mut lookup_entries,
+                &mut page_idx,
+            )
+            .await?;
+        }
+    }
+
+    // Write the remaining data
+    if !current_batch_rows.is_empty() {
+        write_batch_and_lookup_entry(
+            &mut current_batch_rows,
+            page_file,
+            &arrow_schema,
+            &mut lookup_entries,
+            &mut page_idx,
+        )
+        .await?;
+    }
+
+    debug!(
+        "Completed multi-way merge: merged {} rows into {} lookup entries",
+        total_merged,
+        lookup_entries.len()
+    );
+    Ok(lookup_entries)
+}
+
+/// Helper function to prepare a page batch and its lookup entry
+async fn prepare_batch_data(
+    batch_rows: Vec<(ScalarValue, ScalarValue)>,
+    arrow_schema: Arc<Schema>,
+    page_idx: u32,
+) -> Result<(RecordBatch, (ScalarValue, ScalarValue, u32, u32))> {
+    if batch_rows.is_empty() {
+        return Err(Error::Internal {
+            message: "Cannot prepare empty batch".to_string(),
+            location: location!(),
+        });
+    }
+
+    // Split the merged rows into value and row id columns
+    let (values, row_ids): (Vec<_>, Vec<_>) = batch_rows.into_iter().unzip();
+
+    // Convert the scalar values into Arrow arrays
+    let values_array = ScalarValue::iter_to_array(values.into_iter())?;
+    let row_ids_array = ScalarValue::iter_to_array(row_ids.into_iter())?;
+
+    let batch = RecordBatch::try_new(arrow_schema, vec![values_array, row_ids_array])?;
+
+    // Calculate min/max/null_count for the lookup entry
+    let min_val = ScalarValue::try_from_array(batch.column(0), 0)?;
+    let max_val = ScalarValue::try_from_array(batch.column(0), batch.num_rows() - 1)?;
+    let null_count = batch.column(0).null_count() as u32;
+
+    let lookup_entry = (min_val, max_val, null_count, page_idx);
+
+    Ok((batch, lookup_entry))
+}
+
+/// Helper function to write a batch and create its lookup entry
+async fn write_batch_and_lookup_entry(
+    batch_rows: &mut Vec<(ScalarValue, ScalarValue)>,
+    page_file: &mut Box<dyn IndexWriter>,
+    arrow_schema: &Arc<Schema>,
+
lookup_entries: &mut Vec<(ScalarValue, ScalarValue, u32, u32)>, + page_idx: &mut u32, +) -> Result<()> { + if batch_rows.is_empty() { + return Ok(()); + } + + // Take ownership of the batch data + let batch_data = std::mem::take(batch_rows); + let current_page_idx = *page_idx; + + // Prepare batch data + let (batch, lookup_entry) = + prepare_batch_data(batch_data, arrow_schema.clone(), current_page_idx).await?; + + lookup_entries.push(lookup_entry); + page_file.write_record_batch(batch).await?; + *page_idx += 1; + + Ok(()) +} + +pub(crate) fn part_page_data_file_path(partition_id: u64) -> String { + format!("part_{}_{}", partition_id, BTREE_PAGES_NAME) +} + +pub(crate) fn part_lookup_file_path(partition_id: u64) -> String { + format!("part_{}_{}", partition_id, BTREE_LOOKUP_NAME) +} + +/// A stream that reads the original training data back out of the index +/// +/// This is used for updating the index +struct IndexReaderStream { + reader: Arc, + batch_size: u64, + num_batches: u32, + batch_idx: u32, +} + +impl IndexReaderStream { + async fn new(reader: Arc, batch_size: u64) -> Self { + let num_batches = reader.num_batches(batch_size).await; + Self { + reader, + batch_size, + num_batches, + batch_idx: 0, + } + } +} + +impl Stream for IndexReaderStream { + type Item = BoxFuture<'static, Result>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + if this.batch_idx >= this.num_batches { + return std::task::Poll::Ready(None); + } + let batch_num = this.batch_idx; + this.batch_idx += 1; + let reader_copy = this.reader.clone(); + let batch_size = this.batch_size; + let read_task = async move { + reader_copy + .read_record_batch(batch_num as u64, batch_size) + .await + } + .boxed(); + std::task::Poll::Ready(Some(read_task)) + } +} + +/// Parameters for a btree index +#[derive(Debug, Serialize, Deserialize)] +pub struct BTreeParameters { + /// The number of rows to include in each zone + pub zone_size: Option, +} + +struct BTreeTrainingRequest { + parameters: BTreeParameters, + criteria: TrainingCriteria, +} + +impl BTreeTrainingRequest { + pub fn new(parameters: BTreeParameters) -> Self { + Self { + parameters, + // BTree indexes need data sorted by the value column + criteria: TrainingCriteria::new(TrainingOrdering::Values).with_row_id(), + } + } +} + +impl TrainingRequest for BTreeTrainingRequest { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn criteria(&self) -> &TrainingCriteria { + &self.criteria + } +} + +#[derive(Debug, Default)] +pub struct BTreeIndexPlugin; + +#[async_trait] +impl ScalarIndexPlugin for BTreeIndexPlugin { + fn new_training_request( + &self, + params: &str, + field: &Field, + ) -> Result> { + if field.data_type().is_nested() { + return Err(Error::InvalidInput { + source: "A btree index can only be created on a non-nested field.".into(), + location: location!(), + }); + } + + let params = serde_json::from_str::(params)?; + Ok(Box::new(BTreeTrainingRequest::new(params))) + } + + fn provides_exact_answer(&self) -> bool { + true + } + + fn version(&self) -> u32 { + 0 + } + + fn new_query_parser( + &self, + index_name: String, + _index_details: &prost_types::Any, + ) -> Option> { + Some(Box::new(SargableQueryParser::new(index_name, false))) + } + + async fn train_index( + &self, + data: SendableRecordBatchStream, + index_store: &dyn IndexStore, + request: Box, + fragment_ids: Option>, + ) -> Result { + let request = request + .as_any() + .downcast_ref::() + .unwrap(); + 
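         // The request carries the BTreeParameters parsed by new_training_request; for
         // example (illustrative), the params string {"zone_size": 4096} yields
         // zone_size = Some(4096), which becomes the per-page row count passed below
         // (falling back to DEFAULT_BTREE_BATCH_SIZE when unset).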
let value_type = data + .schema() + .field_with_name(VALUE_COLUMN_NAME)? + .data_type() + .clone(); + let flat_index_trainer = FlatIndexMetadata::new(value_type); + train_btree_index( + data, + &flat_index_trainer, + index_store, + request + .parameters + .zone_size + .unwrap_or(DEFAULT_BTREE_BATCH_SIZE), + fragment_ids, + ) + .await?; + Ok(CreatedIndex { + index_details: prost_types::Any::from_msg(&pb::BTreeIndexDetails::default()).unwrap(), + index_version: BTREE_INDEX_VERSION, + }) + } + + async fn load_index( + &self, + index_store: Arc, + _index_details: &prost_types::Any, + frag_reuse_index: Option>, + cache: LanceCache, + ) -> Result> { + Ok(BTreeIndex::load(index_store, frag_reuse_index, cache).await? as Arc) + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::Ordering; + use std::{collections::HashMap, sync::Arc}; + + use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type}; + use arrow_array::FixedSizeListArray; + use arrow_schema::DataType; + use datafusion::{ + execution::{SendableRecordBatchStream, TaskContext}, + physical_plan::{sorts::sort::SortExec, stream::RecordBatchStreamAdapter, ExecutionPlan}, + }; + use datafusion_common::{DataFusionError, ScalarValue}; + use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; + use deepsize::DeepSizeOf; + use futures::TryStreamExt; + use lance_core::{cache::LanceCache, utils::mask::RowIdTreeMap}; + use lance_datafusion::{chunker::break_stream, datagen::DatafusionDatagenExt}; + use lance_datagen::{array, gen_batch, ArrayGeneratorExt, BatchCount, RowCount}; + use lance_io::object_store::ObjectStore; + use object_store::path::Path; + use tempfile::tempdir; + + use crate::metrics::LocalMetricsCollector; + use crate::{ + metrics::NoOpMetricsCollector, + scalar::{ + btree::{BTreeIndex, BTREE_PAGES_NAME}, + flat::FlatIndexMetadata, + lance_format::LanceIndexStore, + IndexStore, SargableQuery, ScalarIndex, SearchResult, + }, + }; + + use super::{ + part_lookup_file_path, part_page_data_file_path, train_btree_index, OrderableScalarValue, + DEFAULT_BTREE_BATCH_SIZE, + }; + + #[test] + fn test_scalar_value_size() { + let size_of_i32 = OrderableScalarValue(ScalarValue::Int32(Some(0))).deep_size_of(); + let size_of_many_i32 = OrderableScalarValue(ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::from_iter_primitive::( + vec![Some(vec![Some(0); 128])], + 128, + ), + ))) + .deep_size_of(); + + // deep_size_of should account for the rust type overhead + assert!(size_of_i32 > 4); + assert!(size_of_many_i32 > 128 * 4); + } + + #[tokio::test] + async fn test_null_ids() { + let tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + // Generate 50,000 rows of random data with 80% nulls + let stream = gen_batch() + .col( + "value", + array::rand::().with_nulls(&[true, false, false, false, false]), + ) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(5000), BatchCount::from(10)); + let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + + train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 5000, None) + .await + .unwrap(); + + let index = BTreeIndex::load(test_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + assert_eq!(index.page_lookup.null_pages.len(), 10); + + let remap_dir = Arc::new(tempdir().unwrap()); + let remap_store = Arc::new(LanceIndexStore::new( + 
Arc::new(ObjectStore::local()), + Path::from_filesystem_path(remap_dir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + // Remap with a no-op mapping. The remapped index should be identical to the original + index + .remap(&HashMap::default(), remap_store.as_ref()) + .await + .unwrap(); + + let remap_index = BTreeIndex::load(remap_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + assert_eq!(remap_index.page_lookup, index.page_lookup); + + let original_pages = test_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); + let remapped_pages = remap_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); + + assert_eq!(original_pages.num_rows(), remapped_pages.num_rows()); + + let original_data = original_pages + .read_record_batch(0, original_pages.num_rows() as u64) + .await + .unwrap(); + let remapped_data = remapped_pages + .read_record_batch(0, remapped_pages.num_rows() as u64) + .await + .unwrap(); + + assert_eq!(original_data, remapped_data); + } + + #[tokio::test] + async fn test_nan_ordering() { + let tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let values = vec![ + 0.0, + 1.0, + 2.0, + 3.0, + f64::NAN, + f64::NEG_INFINITY, + f64::INFINITY, + ]; + + // This is a bit overkill but we've had bugs in the past where DF's sort + // didn't agree with Arrow's sort so we do an end-to-end test here + // and use DF to sort the data like we would in a real dataset. + let data = gen_batch() + .col("value", array::cycle::(values.clone())) + .col("_rowid", array::step::()) + .into_df_exec(RowCount::from(10), BatchCount::from(100)); + let schema = data.schema(); + let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); + let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); + let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); + let stream = break_stream(stream, 64); + let stream = stream.map_err(DataFusionError::from); + let stream = + Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; + + let sub_index_trainer = FlatIndexMetadata::new(DataType::Float64); + + train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64, None) + .await + .unwrap(); + + let index = BTreeIndex::load(test_store, None, LanceCache::no_cache()) + .await + .unwrap(); + + for (idx, value) in values.into_iter().enumerate() { + let query = SargableQuery::Equals(ScalarValue::Float64(Some(value))); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + assert_eq!( + result, + SearchResult::Exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7))) + ); + } + } + + #[tokio::test] + async fn test_page_cache() { + let tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let data = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_exec(RowCount::from(1000), BatchCount::from(10)); + let schema = data.schema(); + let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); + let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); + let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); + let stream = 
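+        // break_stream re-chunks the sorted stream into 64-row batches, so each page
+        // written by train_btree_index below holds 64 values (matching its batch_size
+        // argument).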
break_stream(stream, 64); + let stream = stream.map_err(DataFusionError::from); + let stream = + Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; + let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + + train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64, None) + .await + .unwrap(); + + let index = BTreeIndex::load( + test_store, + None, + LanceCache::with_capacity(100 * 1024 * 1024), + ) + .await + .unwrap(); + + let query = SargableQuery::Equals(ScalarValue::Float32(Some(0.0))); + let metrics = LocalMetricsCollector::default(); + let query1 = index.search(&query, &metrics); + let query2 = index.search(&query, &metrics); + tokio::join!(query1, query2).0.unwrap(); + assert_eq!(metrics.parts_loaded.load(Ordering::Relaxed), 1); + } + + /// Test that fragment-based btree index construction produces exactly the same results as building a complete index + #[tokio::test] + async fn test_fragment_btree_index_consistency() { + // Setup stores for both indexes + let full_tmpdir = Arc::new(tempdir().unwrap()); + let full_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(full_tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let fragment_tmpdir = Arc::new(tempdir().unwrap()); + let fragment_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(fragment_tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let sub_index_trainer = FlatIndexMetadata::new(DataType::Int32); + + // Method 1: Build complete index directly using the same data + // Create deterministic data for comparison - use 2 * DEFAULT_BTREE_BATCH_SIZE for testing + let total_count = (2 * DEFAULT_BTREE_BATCH_SIZE) as u64; + let full_data_gen = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(total_count / 2), BatchCount::from(2)); + let full_data_source = Box::pin(RecordBatchStreamAdapter::new( + full_data_gen.schema(), + full_data_gen, + )); + + train_btree_index( + full_data_source, + &sub_index_trainer, + full_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + None, + ) + .await + .unwrap(); + + // Method 2: Build fragment-based index using the same data split into fragments + // Create fragment 1 index - first half of the data (0 to DEFAULT_BTREE_BATCH_SIZE-1) + let half_count = DEFAULT_BTREE_BATCH_SIZE; + let fragment1_gen = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(half_count), BatchCount::from(1)); + let fragment1_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment1_gen.schema(), + fragment1_gen, + )); + + train_btree_index( + fragment1_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![1]), // fragment_id = 1 + ) + .await + .unwrap(); + + // Create fragment 2 index - second half of the data (DEFAULT_BTREE_BATCH_SIZE to 2*DEFAULT_BTREE_BATCH_SIZE-1) + let start_val = DEFAULT_BTREE_BATCH_SIZE as i32; + let end_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let values_second_half: Vec = (start_val..end_val).collect(); + let row_ids_second_half: Vec = (start_val as u64..end_val as u64).collect(); + let fragment2_gen = gen_batch() + .col("value", array::cycle::(values_second_half)) + .col("_rowid", array::cycle::(row_ids_second_half)) + .into_df_stream(RowCount::from(half_count), BatchCount::from(1)); + let fragment2_data_source = 
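+        // Fragment 2 supplies the second half of the keyspace ([BATCH_SIZE, 2*BATCH_SIZE)),
+        // so after merge_metadata_files runs below, queries against the merged index are
+        // expected to match the full single-pass build above.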
Box::pin(RecordBatchStreamAdapter::new( + fragment2_gen.schema(), + fragment2_gen, + )); + + train_btree_index( + fragment2_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![2]), // fragment_id = 2 + ) + .await + .unwrap(); + + // Merge the fragment files + let part_page_files = vec![ + part_page_data_file_path(1 << 32), + part_page_data_file_path(2 << 32), + ]; + + let part_lookup_files = vec![ + part_lookup_file_path(1 << 32), + part_lookup_file_path(2 << 32), + ]; + + super::merge_metadata_files( + fragment_store.clone(), + &part_page_files, + &part_lookup_files, + Option::from(1usize), + ) + .await + .unwrap(); + + // Load both indexes + let full_index = BTreeIndex::load(full_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + let merged_index = BTreeIndex::load(fragment_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + // Test queries one by one to identify the exact problem + + // Test 1: Query for value 0 (should be in first page) + let query_0 = SargableQuery::Equals(ScalarValue::Int32(Some(0))); + let full_result_0 = full_index + .search(&query_0, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_0 = merged_index + .search(&query_0, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!(full_result_0, merged_result_0, "Query for value 0 failed"); + + // Test 2: Query for value in middle of first batch (should be in first page) + let mid_first_batch = (DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_mid_first = SargableQuery::Equals(ScalarValue::Int32(Some(mid_first_batch))); + let full_result_mid_first = full_index + .search(&query_mid_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_mid_first = merged_index + .search(&query_mid_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_mid_first, merged_result_mid_first, + "Query for value {} failed", + mid_first_batch + ); + + // Test 3: Query for first value in second batch (should be in second page) + let first_second_batch = DEFAULT_BTREE_BATCH_SIZE as i32; + let query_first_second = + SargableQuery::Equals(ScalarValue::Int32(Some(first_second_batch))); + let full_result_first_second = full_index + .search(&query_first_second, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_first_second = merged_index + .search(&query_first_second, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_first_second, merged_result_first_second, + "Query for value {} failed", + first_second_batch + ); + + // Test 4: Query for value in middle of second batch (should be in second page) + let mid_second_batch = (DEFAULT_BTREE_BATCH_SIZE + DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_mid_second = SargableQuery::Equals(ScalarValue::Int32(Some(mid_second_batch))); + + let full_result_mid_second = full_index + .search(&query_mid_second, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_mid_second = merged_index + .search(&query_mid_second, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_mid_second, merged_result_mid_second, + "Query for value {} failed", + mid_second_batch + ); } #[tokio::test] - async fn test_null_ids() { - let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( + async fn test_fragment_btree_index_boundary_queries() { + // Setup stores for both indexes + let full_tmpdir = Arc::new(tempdir().unwrap()); + let full_store = Arc::new(LanceIndexStore::new( 
Arc::new(ObjectStore::local()), - Path::from_filesystem_path(tmpdir.path()).unwrap(), + Path::from_filesystem_path(full_tmpdir.path()).unwrap(), Arc::new(LanceCache::no_cache()), )); - // Generate 50,000 rows of random data with 80% nulls - let stream = gen_batch() - .col( - "value", - array::rand::().with_nulls(&[true, false, false, false, false]), - ) + let fragment_tmpdir = Arc::new(tempdir().unwrap()); + let fragment_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(fragment_tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let sub_index_trainer = FlatIndexMetadata::new(DataType::Int32); + + // Use 3 * DEFAULT_BTREE_BATCH_SIZE for more comprehensive boundary testing + let total_count = (3 * DEFAULT_BTREE_BATCH_SIZE) as u64; + + // Method 1: Build complete index directly + let full_data_gen = gen_batch() + .col("value", array::step::()) .col("_rowid", array::step::()) - .into_df_stream(RowCount::from(5000), BatchCount::from(10)); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + .into_df_stream(RowCount::from(total_count / 3), BatchCount::from(3)); + let full_data_source = Box::pin(RecordBatchStreamAdapter::new( + full_data_gen.schema(), + full_data_gen, + )); + + train_btree_index( + full_data_source, + &sub_index_trainer, + full_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + None, + ) + .await + .unwrap(); + + // Method 2: Build fragment-based index using 3 fragments + // Fragment 1: 0 to DEFAULT_BTREE_BATCH_SIZE-1 + let fragment_size = DEFAULT_BTREE_BATCH_SIZE; + let fragment1_gen = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment1_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment1_gen.schema(), + fragment1_gen, + )); + + train_btree_index( + fragment1_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![1]), + ) + .await + .unwrap(); + + // Fragment 2: DEFAULT_BTREE_BATCH_SIZE to 2*DEFAULT_BTREE_BATCH_SIZE-1 + let start_val2 = DEFAULT_BTREE_BATCH_SIZE as i32; + let end_val2 = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let values_fragment2: Vec = (start_val2..end_val2).collect(); + let row_ids_fragment2: Vec = (start_val2 as u64..end_val2 as u64).collect(); + let fragment2_gen = gen_batch() + .col("value", array::cycle::(values_fragment2)) + .col("_rowid", array::cycle::(row_ids_fragment2)) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment2_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment2_gen.schema(), + fragment2_gen, + )); + + train_btree_index( + fragment2_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![2]), + ) + .await + .unwrap(); - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 5000) + // Fragment 3: 2*DEFAULT_BTREE_BATCH_SIZE to 3*DEFAULT_BTREE_BATCH_SIZE-1 + let start_val3 = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let end_val3 = (3 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let values_fragment3: Vec = (start_val3..end_val3).collect(); + let row_ids_fragment3: Vec = (start_val3 as u64..end_val3 as u64).collect(); + let fragment3_gen = gen_batch() + .col("value", array::cycle::(values_fragment3)) + .col("_rowid", array::cycle::(row_ids_fragment3)) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment3_data_source = Box::pin(RecordBatchStreamAdapter::new( + 
fragment3_gen.schema(), + fragment3_gen, + )); + + train_btree_index( + fragment3_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![3]), + ) + .await + .unwrap(); + + // Merge all fragment files + let part_page_files = vec![ + part_page_data_file_path(1 << 32), + part_page_data_file_path(2 << 32), + part_page_data_file_path(3 << 32), + ]; + + let part_lookup_files = vec![ + part_lookup_file_path(1 << 32), + part_lookup_file_path(2 << 32), + part_lookup_file_path(3 << 32), + ]; + + super::merge_metadata_files( + fragment_store.clone(), + &part_page_files, + &part_lookup_files, + Option::from(1usize), + ) + .await + .unwrap(); + + // Load both indexes + let full_index = BTreeIndex::load(full_store.clone(), None, LanceCache::no_cache()) .await .unwrap(); - let index = BTreeIndex::load(test_store.clone(), None, LanceCache::no_cache()) + let merged_index = BTreeIndex::load(fragment_store.clone(), None, LanceCache::no_cache()) .await .unwrap(); - assert_eq!(index.page_lookup.null_pages.len(), 10); + // === Boundary Value Tests === - let remap_dir = Arc::new(tempdir().unwrap()); - let remap_store = Arc::new(LanceIndexStore::new( - Arc::new(ObjectStore::local()), - Path::from_filesystem_path(remap_dir.path()).unwrap(), - Arc::new(LanceCache::no_cache()), - )); + // Test 1: Query minimum value (boundary: data start) + let query_min = SargableQuery::Equals(ScalarValue::Int32(Some(0))); + let full_result_min = full_index + .search(&query_min, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_min = merged_index + .search(&query_min, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_min, merged_result_min, + "Query for minimum value 0 failed" + ); - // Remap with a no-op mapping. 
The remapped index should be identical to the original - index - .remap(&HashMap::default(), remap_store.as_ref()) + // Test 2: Query maximum value (boundary: data end) + let max_val = (3 * DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val))); + let full_result_max = full_index + .search(&query_max, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_max = merged_index + .search(&query_max, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_max, merged_result_max, + "Query for maximum value {} failed", + max_val + ); - let remap_index = BTreeIndex::load(remap_store.clone(), None, LanceCache::no_cache()) + // Test 3: Query fragment boundary value (last value of first fragment) + let fragment1_last = (DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_frag1_last = SargableQuery::Equals(ScalarValue::Int32(Some(fragment1_last))); + let full_result_frag1_last = full_index + .search(&query_frag1_last, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_frag1_last = merged_index + .search(&query_frag1_last, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag1_last, merged_result_frag1_last, + "Query for fragment 1 last value {} failed", + fragment1_last + ); - assert_eq!(remap_index.page_lookup, index.page_lookup); + // Test 4: Query fragment boundary value (first value of second fragment) + let fragment2_first = DEFAULT_BTREE_BATCH_SIZE as i32; + let query_frag2_first = SargableQuery::Equals(ScalarValue::Int32(Some(fragment2_first))); + let full_result_frag2_first = full_index + .search(&query_frag2_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag2_first = merged_index + .search(&query_frag2_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag2_first, merged_result_frag2_first, + "Query for fragment 2 first value {} failed", + fragment2_first + ); - let original_pages = test_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); - let remapped_pages = remap_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); + // Test 5: Query fragment boundary value (last value of second fragment) + let fragment2_last = (2 * DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_frag2_last = SargableQuery::Equals(ScalarValue::Int32(Some(fragment2_last))); + let full_result_frag2_last = full_index + .search(&query_frag2_last, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag2_last = merged_index + .search(&query_frag2_last, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag2_last, merged_result_frag2_last, + "Query for fragment 2 last value {} failed", + fragment2_last + ); - assert_eq!(original_pages.num_rows(), remapped_pages.num_rows()); + // Test 6: Query fragment boundary value (first value of third fragment) + let fragment3_first = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_frag3_first = SargableQuery::Equals(ScalarValue::Int32(Some(fragment3_first))); + let full_result_frag3_first = full_index + .search(&query_frag3_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag3_first = merged_index + .search(&query_frag3_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag3_first, merged_result_frag3_first, + "Query for fragment 3 first value {} failed", + fragment3_first + ); - let original_data = original_pages - .read_record_batch(0, original_pages.num_rows() as u64) + // === Non-existent Value Tests === + + // Test 7: 
Query value below minimum + let query_below_min = SargableQuery::Equals(ScalarValue::Int32(Some(-1))); + let full_result_below = full_index + .search(&query_below_min, &NoOpMetricsCollector) .await .unwrap(); - let remapped_data = remapped_pages - .read_record_batch(0, remapped_pages.num_rows() as u64) + let merged_result_below = merged_index + .search(&query_below_min, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_below, merged_result_below, + "Query for value below minimum (-1) failed" + ); + + // Test 8: Query value above maximum + let query_above_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val + 1))); + let full_result_above = full_index + .search(&query_above_max, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_above = merged_index + .search(&query_above_max, &NoOpMetricsCollector) .await .unwrap(); + assert_eq!( + full_result_above, + merged_result_above, + "Query for value above maximum ({}) failed", + max_val + 1 + ); - assert_eq!(original_data, remapped_data); - } + // === Range Query Tests === - #[tokio::test] - async fn test_nan_ordering() { - let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( - Arc::new(ObjectStore::local()), - Path::from_filesystem_path(tmpdir.path()).unwrap(), - Arc::new(LanceCache::no_cache()), - )); + // Test 9: Cross-fragment range query (from first fragment to second fragment) + let range_start = (DEFAULT_BTREE_BATCH_SIZE - 100) as i32; + let range_end = (DEFAULT_BTREE_BATCH_SIZE + 100) as i32; + let query_cross_frag = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(range_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(range_end))), + ); + let full_result_cross = full_index + .search(&query_cross_frag, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_cross = merged_index + .search(&query_cross_frag, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_cross, merged_result_cross, + "Cross-fragment range query [{}, {}] failed", + range_start, range_end + ); - let values = vec![ - 0.0, - 1.0, - 2.0, - 3.0, - f64::NAN, - f64::NEG_INFINITY, - f64::INFINITY, - ]; + // Test 10: Range query within single fragment + let single_frag_start = 100i32; + let single_frag_end = 200i32; + let query_single_frag = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(single_frag_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(single_frag_end))), + ); + let full_result_single = full_index + .search(&query_single_frag, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_single = merged_index + .search(&query_single_frag, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_single, merged_result_single, + "Single fragment range query [{}, {}] failed", + single_frag_start, single_frag_end + ); - // This is a bit overkill but we've had bugs in the past where DF's sort - // didn't agree with Arrow's sort so we do an end-to-end test here - // and use DF to sort the data like we would in a real dataset. 
- let data = gen_batch() - .col("value", array::cycle::(values.clone())) - .col("_rowid", array::step::()) - .into_df_exec(RowCount::from(10), BatchCount::from(100)); - let schema = data.schema(); - let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); - let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); - let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); - let stream = break_stream(stream, 64); - let stream = stream.map_err(DataFusionError::from); - let stream = - Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; + // Test 11: Large range query spanning all fragments + let large_range_start = 100i32; + let large_range_end = (3 * DEFAULT_BTREE_BATCH_SIZE - 100) as i32; + let query_large_range = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(large_range_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(large_range_end))), + ); + let full_result_large = full_index + .search(&query_large_range, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_large = merged_index + .search(&query_large_range, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_large, merged_result_large, + "Large range query [{}, {}] failed", + large_range_start, large_range_end + ); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float64); + // === Range Boundary Query Tests === - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64) + // Test 12: Less than query (implemented using range query, from minimum to specified value) + let lt_val = (DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_lt = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(0))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(lt_val))), + ); + let full_result_lt = full_index + .search(&query_lt, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_lt = merged_index + .search(&query_lt, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_lt, merged_result_lt, + "Less than query (<{}) failed", + lt_val + ); - let index = BTreeIndex::load(test_store, None, LanceCache::no_cache()) + // Test 13: Greater than query (implemented using range query, from specified value to maximum) + let gt_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let max_range_val = (3 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_gt = SargableQuery::Range( + std::collections::Bound::Excluded(ScalarValue::Int32(Some(gt_val))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(max_range_val))), + ); + let full_result_gt = full_index + .search(&query_gt, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_gt = merged_index + .search(&query_gt, &NoOpMetricsCollector) .await .unwrap(); + assert_eq!( + full_result_gt, merged_result_gt, + "Greater than query (>{}) failed", + gt_val + ); - for (idx, value) in values.into_iter().enumerate() { - let query = SargableQuery::Equals(ScalarValue::Float64(Some(value))); - let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::Exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7))) - ); - } + // Test 14: Less than or equal query (implemented using range query, including boundary value) + let lte_val = (DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_lte = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(0))), + 
std::collections::Bound::Included(ScalarValue::Int32(Some(lte_val))), + ); + let full_result_lte = full_index + .search(&query_lte, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_lte = merged_index + .search(&query_lte, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_lte, merged_result_lte, + "Less than or equal query (<={}) failed", + lte_val + ); + + // Test 15: Greater than or equal query (implemented using range query, including boundary value) + let gte_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_gte = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(gte_val))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(max_range_val))), + ); + let full_result_gte = full_index + .search(&query_gte, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_gte = merged_index + .search(&query_gte, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_gte, merged_result_gte, + "Greater than or equal query (>={}) failed", + gte_val + ); + } + + #[test] + fn test_extract_partition_id() { + // Test valid partition file names + assert_eq!( + super::extract_partition_id("part_123_page_data.lance").unwrap(), + 123 + ); + assert_eq!( + super::extract_partition_id("part_456_page_lookup.lance").unwrap(), + 456 + ); + assert_eq!( + super::extract_partition_id("part_4294967296_page_data.lance").unwrap(), + 4294967296 + ); + + // Test invalid file names + assert!(super::extract_partition_id("invalid_filename.lance").is_err()); + assert!(super::extract_partition_id("part_abc_page_data.lance").is_err()); + assert!(super::extract_partition_id("part_123").is_err()); + assert!(super::extract_partition_id("part_").is_err()); } #[tokio::test] - async fn test_page_cache() { + async fn test_cleanup_partition_files() { + use crate::scalar::lance_format::LanceIndexStore; + use lance_core::cache::LanceCache; + use lance_io::object_store::ObjectStore; + use object_store::path::Path; + use std::sync::Arc; + use tempfile::tempdir; + + // Create a test store let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( + let test_store: Arc = Arc::new(LanceIndexStore::new( Arc::new(ObjectStore::local()), Path::from_filesystem_path(tmpdir.path()).unwrap(), Arc::new(LanceCache::no_cache()), )); - let data = gen_batch() - .col("value", array::step::()) - .col("_rowid", array::step::()) - .into_df_exec(RowCount::from(1000), BatchCount::from(10)); - let schema = data.schema(); - let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); - let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); - let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); - let stream = break_stream(stream, 64); - let stream = stream.map_err(DataFusionError::from); - let stream = - Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + // Test files with different patterns + let lookup_files = vec![ + "part_123_page_lookup.lance".to_string(), + "invalid_lookup_file.lance".to_string(), + "part_456_page_lookup.lance".to_string(), + ]; - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64) - .await - .unwrap(); + let page_files = vec![ + "part_123_page_data.lance".to_string(), + "invalid_page_file.lance".to_string(), + "part_456_page_data.lance".to_string(), + ]; - let index = BTreeIndex::load( - test_store, - None, - 
LanceCache::with_capacity(100 * 1024 * 1024), - ) - .await - .unwrap(); + // The cleanup function should handle both valid and invalid file patterns gracefully + // This test mainly verifies that the function doesn't panic and handles edge cases + super::cleanup_partition_files(&test_store, &lookup_files, &page_files).await; - let query = SargableQuery::Equals(ScalarValue::Float32(Some(0.0))); - let metrics = LocalMetricsCollector::default(); - let query1 = index.search(&query, &metrics); - let query2 = index.search(&query, &metrics); - tokio::join!(query1, query2).0.unwrap(); - assert_eq!(metrics.parts_loaded.load(Ordering::Relaxed), 1); + // If we get here without panicking, the cleanup function handled all cases correctly + assert!(true); } } diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs index 4edb1cb6a0a..a4506020782 100644 --- a/rust/lance-index/src/scalar/inverted.rs +++ b/rust/lance-index/src/scalar/inverted.rs @@ -163,6 +163,7 @@ impl ScalarIndexPlugin for InvertedIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result { let request = (request as Box) .downcast::() @@ -170,7 +171,8 @@ impl ScalarIndexPlugin for InvertedIndexPlugin { source: "must provide training request created by new_training_request".into(), location: location!(), })?; - Self::train_inverted_index(data, index_store, request.parameters.clone(), None).await + Self::train_inverted_index(data, index_store, request.parameters.clone(), fragment_ids) + .await } /// Load an index from storage diff --git a/rust/lance-index/src/scalar/json.rs b/rust/lance-index/src/scalar/json.rs index 0b8a43efbe7..e36feaacfc7 100644 --- a/rust/lance-index/src/scalar/json.rs +++ b/rust/lance-index/src/scalar/json.rs @@ -768,6 +768,7 @@ impl ScalarIndexPlugin for JsonIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result { let request = (request as Box) .downcast::() @@ -797,7 +798,7 @@ impl ScalarIndexPlugin for JsonIndexPlugin { )?; let target_index = target_plugin - .train_index(converted_stream, index_store, target_request) + .train_index(converted_stream, index_store, target_request, fragment_ids) .await?; let index_details = crate::pb::JsonIndexDetails { diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs index 542aa2bc97a..64e932c47c5 100644 --- a/rust/lance-index/src/scalar/label_list.rs +++ b/rust/lance-index/src/scalar/label_list.rs @@ -398,6 +398,7 @@ impl ScalarIndexPlugin for LabelListIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result { let schema = data.schema(); let field = schema @@ -427,7 +428,7 @@ impl ScalarIndexPlugin for LabelListIndexPlugin { let data = unnest_chunks(data)?; let bitmap_plugin = BitmapIndexPlugin; bitmap_plugin - .train_index(data, index_store, request) + .train_index(data, index_store, request, fragment_ids) .await?; Ok(CreatedIndex { index_details: prost_types::Any::from_msg(&pb::LabelListIndexDetails::default()) diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs index d8a95de1eeb..4df502ead09 100644 --- a/rust/lance-index/src/scalar/lance_format.rs +++ b/rust/lance-index/src/scalar/lance_format.rs @@ -368,7 +368,7 @@ pub mod tests { ) .unwrap(); btree_plugin - .train_index(data, index_store.as_ref(), request) + .train_index(data, 
index_store.as_ref(), request, None)
             .await
             .unwrap();
     }
@@ -866,6 +866,7 @@ pub mod tests {
             &sub_index_trainer,
             index_store.as_ref(),
             DEFAULT_BTREE_BATCH_SIZE,
+            None,
         )
         .await
         .unwrap();
@@ -911,7 +912,7 @@ pub mod tests {
             .new_training_request("{}", &Field::new(VALUE_COLUMN_NAME, DataType::Int32, false))
             .unwrap();
         BitmapIndexPlugin
-            .train_index(data, index_store.as_ref(), request)
+            .train_index(data, index_store.as_ref(), request, None)
             .await
             .unwrap();
     }
@@ -1399,7 +1400,7 @@ pub mod tests {
         )
         .unwrap();
         LabelListIndexPlugin
-            .train_index(data, index_store.as_ref(), request)
+            .train_index(data, index_store.as_ref(), request, None)
             .await
             .unwrap();
     }
diff --git a/rust/lance-index/src/scalar/ngram.rs b/rust/lance-index/src/scalar/ngram.rs
index ff559dd9292..586b0a4da9a 100644
--- a/rust/lance-index/src/scalar/ngram.rs
+++ b/rust/lance-index/src/scalar/ngram.rs
@@ -1285,6 +1285,7 @@ impl ScalarIndexPlugin for NGramIndexPlugin {
         data: SendableRecordBatchStream,
         index_store: &dyn IndexStore,
         _request: Box<dyn TrainingRequest>,
+        _fragment_ids: Option<Vec<u32>>,
     ) -> Result<CreatedIndex> {
         Self::train_ngram_index(data, index_store).await?;
         Ok(CreatedIndex {
diff --git a/rust/lance-index/src/scalar/registry.rs b/rust/lance-index/src/scalar/registry.rs
index 022da729f0c..3880aad4dbe 100644
--- a/rust/lance-index/src/scalar/registry.rs
+++ b/rust/lance-index/src/scalar/registry.rs
@@ -119,6 +119,7 @@ pub trait ScalarIndexPlugin: Send + Sync + std::fmt::Debug {
         data: SendableRecordBatchStream,
         index_store: &dyn IndexStore,
         request: Box<dyn TrainingRequest>,
+        fragment_ids: Option<Vec<u32>>,
     ) -> Result<CreatedIndex>;

     /// Returns true if the index returns an exact answer (e.g. not AtMost)
diff --git a/rust/lance-index/src/scalar/zonemap.rs b/rust/lance-index/src/scalar/zonemap.rs
index 748ab003863..c9097cbb8cc 100644
--- a/rust/lance-index/src/scalar/zonemap.rs
+++ b/rust/lance-index/src/scalar/zonemap.rs
@@ -961,6 +961,7 @@ impl ScalarIndexPlugin for ZoneMapIndexPlugin {
         data: SendableRecordBatchStream,
         index_store: &dyn IndexStore,
         request: Box<dyn TrainingRequest>,
+        _fragment_ids: Option<Vec<u32>>,
     ) -> Result<CreatedIndex> {
         let request = (request as Box<dyn Any>)
             .downcast::<ZoneMapTrainingRequest>()
diff --git a/rust/lance/benches/scalar_index.rs b/rust/lance/benches/scalar_index.rs
index 0742ff7f878..58b94f56318 100644
--- a/rust/lance/benches/scalar_index.rs
+++ b/rust/lance/benches/scalar_index.rs
@@ -71,6 +71,7 @@ impl BenchmarkFixture {
             &sub_index_trainer,
             index_store.as_ref(),
             DEFAULT_BTREE_BATCH_SIZE,
+            None,
         )
         .await
         .unwrap();
diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs
index cdebc399547..ccae72a4865 100644
--- a/rust/lance/src/index/scalar.rs
+++ b/rust/lance/src/index/scalar.rs
@@ -284,12 +284,11 @@ pub(super) async fn build_scalar_index(
         training_request.criteria(),
         None,
         train,
-        fragment_ids,
+        fragment_ids.clone(),
     )
     .await?;
     plugin
-        .train_index(training_data, &index_store, training_request)
+        .train_index(training_data, &index_store, training_request, fragment_ids)
         .await
 }

From 7105d77d26e428d985f1178b8f6ff69dfe9cd71c Mon Sep 17 00:00:00 2001
From: xloya
Date: Fri, 5 Sep 2025 12:48:21 +0800
Subject: [PATCH 03/13] support btree distributely
---
 java/.gitignore                               |    2 +-
 java/.project                                 |   17 +
 .../org.eclipse.core.resources.prefs          |    2 +
 java/.settings/org.eclipse.m2e.core.prefs     |    4 +
 java/core/.classpath                          |   50 +
 java/core/.project                            |   23 +
 .../org.eclipse.core.resources.prefs          |    5 +
 .../.settings/org.eclipse.jdt.apt.core.prefs  |    2 +
 .../core/.settings/org.eclipse.jdt.core.prefs |    9 +
 .../core/.settings/org.eclipse.m2e.core.prefs |    4 +
 python/python/lance/dataset.py                |   36 +-
python/python/lance/lance/__init__.pyi | 4 +- python/python/tests/test_scalar_index.py | 416 +++- python/src/dataset.rs | 126 +- rust/lance-index/src/scalar/bitmap.rs | 1 + rust/lance-index/src/scalar/btree.rs | 2185 ++++++++++++++--- rust/lance-index/src/scalar/inverted.rs | 4 +- rust/lance-index/src/scalar/json.rs | 3 +- rust/lance-index/src/scalar/label_list.rs | 3 +- rust/lance-index/src/scalar/lance_format.rs | 7 +- rust/lance-index/src/scalar/ngram.rs | 1 + rust/lance-index/src/scalar/registry.rs | 1 + rust/lance-index/src/scalar/zonemap.rs | 1 + rust/lance/benches/scalar_index.rs | 1 + rust/lance/src/index/scalar.rs | 5 +- 25 files changed, 2576 insertions(+), 336 deletions(-) create mode 100644 java/.project create mode 100644 java/.settings/org.eclipse.core.resources.prefs create mode 100644 java/.settings/org.eclipse.m2e.core.prefs create mode 100644 java/core/.classpath create mode 100644 java/core/.project create mode 100644 java/core/.settings/org.eclipse.core.resources.prefs create mode 100644 java/core/.settings/org.eclipse.jdt.apt.core.prefs create mode 100644 java/core/.settings/org.eclipse.jdt.core.prefs create mode 100644 java/core/.settings/org.eclipse.m2e.core.prefs diff --git a/java/.gitignore b/java/.gitignore index d9074bd2835..2b82c700d45 100644 --- a/java/.gitignore +++ b/java/.gitignore @@ -1,2 +1,2 @@ *.iml -.java-version +.java-version \ No newline at end of file diff --git a/java/.project b/java/.project new file mode 100644 index 00000000000..9e430d58fe9 --- /dev/null +++ b/java/.project @@ -0,0 +1,17 @@ + + + lance-parent + + + + + + org.eclipse.m2e.core.maven2Builder + + + + + + org.eclipse.m2e.core.maven2Nature + + diff --git a/java/.settings/org.eclipse.core.resources.prefs b/java/.settings/org.eclipse.core.resources.prefs new file mode 100644 index 00000000000..99f26c0203a --- /dev/null +++ b/java/.settings/org.eclipse.core.resources.prefs @@ -0,0 +1,2 @@ +eclipse.preferences.version=1 +encoding/=UTF-8 diff --git a/java/.settings/org.eclipse.m2e.core.prefs b/java/.settings/org.eclipse.m2e.core.prefs new file mode 100644 index 00000000000..f897a7f1cb2 --- /dev/null +++ b/java/.settings/org.eclipse.m2e.core.prefs @@ -0,0 +1,4 @@ +activeProfiles= +eclipse.preferences.version=1 +resolveWorkspaceProjects=true +version=1 diff --git a/java/core/.classpath b/java/core/.classpath new file mode 100644 index 00000000000..5c8072ecc61 --- /dev/null +++ b/java/core/.classpath @@ -0,0 +1,50 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/java/core/.project b/java/core/.project new file mode 100644 index 00000000000..4a9eedb6505 --- /dev/null +++ b/java/core/.project @@ -0,0 +1,23 @@ + + + lance-core + + + + + + org.eclipse.jdt.core.javabuilder + + + + + org.eclipse.m2e.core.maven2Builder + + + + + + org.eclipse.jdt.core.javanature + org.eclipse.m2e.core.maven2Nature + + diff --git a/java/core/.settings/org.eclipse.core.resources.prefs b/java/core/.settings/org.eclipse.core.resources.prefs new file mode 100644 index 00000000000..cdfe4f1b669 --- /dev/null +++ b/java/core/.settings/org.eclipse.core.resources.prefs @@ -0,0 +1,5 @@ +eclipse.preferences.version=1 +encoding//src/main/java=UTF-8 +encoding//src/test/java=UTF-8 +encoding//src/test/resources=UTF-8 +encoding/=UTF-8 diff --git a/java/core/.settings/org.eclipse.jdt.apt.core.prefs b/java/core/.settings/org.eclipse.jdt.apt.core.prefs new file mode 100644 index 00000000000..d4313d4b25e --- /dev/null +++ 
b/java/core/.settings/org.eclipse.jdt.apt.core.prefs
@@ -0,0 +1,2 @@
+eclipse.preferences.version=1
+org.eclipse.jdt.apt.aptEnabled=false
diff --git a/java/core/.settings/org.eclipse.jdt.core.prefs b/java/core/.settings/org.eclipse.jdt.core.prefs
new file mode 100644
index 00000000000..1b6e1ef22f9
--- /dev/null
+++ b/java/core/.settings/org.eclipse.jdt.core.prefs
@@ -0,0 +1,9 @@
+eclipse.preferences.version=1
+org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8
+org.eclipse.jdt.core.compiler.compliance=1.8
+org.eclipse.jdt.core.compiler.problem.enablePreviewFeatures=disabled
+org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning
+org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=ignore
+org.eclipse.jdt.core.compiler.processAnnotations=disabled
+org.eclipse.jdt.core.compiler.release=disabled
+org.eclipse.jdt.core.compiler.source=1.8
diff --git a/java/core/.settings/org.eclipse.m2e.core.prefs b/java/core/.settings/org.eclipse.m2e.core.prefs
new file mode 100644
index 00000000000..f897a7f1cb2
--- /dev/null
+++ b/java/core/.settings/org.eclipse.m2e.core.prefs
@@ -0,0 +1,4 @@
+activeProfiles=
+eclipse.preferences.version=1
+resolveWorkspaceProjects=true
+version=1
diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index 376476e43af..54bd432201a 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -2731,8 +2731,40 @@ def prewarm_index(self, name: str):
         """
         return self._ds.prewarm_index(name)

-    def merge_index_metadata(self, index_uuid: str):
-        return self._ds.merge_index_metadata(index_uuid)
+    def merge_index_metadata(
+        self,
+        index_uuid: str,
+        index_type: Union[
+            Literal["BTREE"],
+            Literal["INVERTED"],
+        ],
+        prefetch_batch: Optional[int] = None,
+    ):
+        """
+        Merge the partial metadata files of an index that has not been
+        committed yet.
+
+        Parameters
+        ----------
+        index_uuid: str
+            The uuid of the index to merge.
+        index_type: Literal["BTREE", "INVERTED"]
+            The type of the index.
+        prefetch_batch: int, optional
+            The number of batches of sub-page files to prefetch while
+            merging. Defaults to 1.
+        """
+        index_type = index_type.upper()
+        if index_type not in [
+            "BTREE",
+            "INVERTED",
+        ]:
+            raise NotImplementedError(
+                'Only "BTREE" or "INVERTED" are supported for '
+                f"merge index metadata. Received {index_type}"
+            )
+        return self._ds.merge_index_metadata(index_uuid, index_type, prefetch_batch)

     def session(self) -> Session:
         """
diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi
index c2a72b7b1b5..0bae8e2f1aa 100644
--- a/python/python/lance/lance/__init__.pyi
+++ b/python/python/lance/lance/__init__.pyi
@@ -282,7 +282,9 @@ class _Dataset:
     ): ...
     def drop_index(self, name: str): ...
     def prewarm_index(self, name: str): ...
-    def merge_index_metadata(self, index_uuid: str): ...
+    def merge_index_metadata(
+        self, index_uuid: str, index_type: str, prefetch_batch: Optional[int] = None
+    ): ...
     def count_fragments(self) -> int: ...
     def num_small_files(self, max_rows_per_group: int) -> int: ...
     def get_fragments(self) -> List[_Fragment]: ...
diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index c2370a17a9e..5f92a1e4d11 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -1982,7 +1982,7 @@ def build_distributed_fts_index( ) # Merge the inverted index metadata - dataset.merge_index_metadata(index_id) + dataset.merge_index_metadata(index_id, index_type="INVERTED") # Create Index object for commit field_id = dataset.schema.get_field_index(column) @@ -2856,7 +2856,7 @@ def test_distribute_fts_index_build(tmp_path): print(f"Fragment {fragment_id} index created successfully") # Merge the inverted index metadata - ds.merge_index_metadata(index_id) + ds.merge_index_metadata(index_id, index_type="INVERTED") # Create an Index object using the new dataclass format from lance.dataset import Index @@ -2983,3 +2983,415 @@ def test_backward_compatibility_no_fragment_ids(tmp_path): results = ds.scanner(full_text_query=search_word).to_table() assert results.num_rows > 0 + + +def test_distribute_btree_index_build(tmp_path): + """ + Test distributed B-tree index build similar to test_distribute_fts_index_build. + This test creates B-tree indices on individual fragments and then + commits them as a single index. + """ + # Generate test dataset with multiple fragments + ds = generate_multi_fragment_dataset( + tmp_path, num_fragments=4, rows_per_fragment=10000 + ) + + import uuid + + index_id = str(uuid.uuid4()) + print(f"Using index ID: {index_id}") + index_name = "btree_multiple_fragment_idx" + + fragments = ds.get_fragments() + fragment_ids = [fragment.fragment_id for fragment in fragments] + print(f"Fragment IDs: {fragment_ids}") + + for fragment in ds.get_fragments(): + fragment_id = fragment.fragment_id + print(f"Creating B-tree index for fragment {fragment_id}") + + # Create B-tree scalar index for each fragment + # Use the same index_name for all fragments (like in FTS test) + ds.create_scalar_index( + column="id", # Use integer column for B-tree + index_type="BTREE", + name=index_name, + replace=False, + fragment_uuid=index_id, + fragment_ids=[fragment_id], + ) + + # For fragment-level indexing, we expect the method to return successfully + # but not commit the index yet + print(f"Fragment {fragment_id} B-tree index created successfully") + + # Merge the B-tree index metadata + ds.merge_index_metadata(index_id, index_type="BTREE") + print(ds.uri) + + # Create an Index object using the new dataclass format + from lance.dataset import Index + + # Get the schema field for the indexed column + field_id = ds.schema.get_field_index("id") + + index = Index( + uuid=index_id, + name=index_name, + fields=[field_id], # Use field index instead of field object + dataset_version=ds.version, + fragment_ids=set(fragment_ids), + index_version=0, + ) + + # Create the index operation + create_index_op = lance.LanceOperation.CreateIndex( + new_indices=[index], + removed_indices=[], + ) + + # Commit the index + ds_committed = lance.LanceDataset.commit( + ds.uri, + create_index_op, + read_version=ds.version, + ) + + print("Successfully committed multiple fragment B-tree index") + + # Verify the index was created and is functional + indices = ds_committed.list_indices() + assert len(indices) > 0, "No indices found after commit" + + # Find our index + our_index = None + for idx in indices: + if idx["name"] == index_name: + our_index = idx + break + + assert our_index is not None, f"Index '{index_name}' not found in indices list" + assert our_index["type"] == 
"BTree", ( + f"Expected BTree index, got {our_index['type']}" + ) + + # Test that the index works for searching + # Test exact equality queries + test_id = 100 # Should be in first fragment + results = ds_committed.scanner( + filter=f"id = {test_id}", + columns=["id", "text"], + ).to_table() + + print(f"Search for id = {test_id} returned {results.num_rows} results") + assert results.num_rows > 0, f"No results found for id = {test_id}" + + # Test range queries across fragments + results_range = ds_committed.scanner( + filter="id >= 200 AND id < 800", + columns=["id", "text"], + ).to_table() + + print(f"Range query returned {results_range.num_rows} results") + assert results_range.num_rows > 0, "No results found for range query" + + # Compare with complete index results to ensure consistency + # Create a reference dataset with complete index + reference_ds = generate_multi_fragment_dataset( + tmp_path / "reference", num_fragments=4, rows_per_fragment=10000 + ) + + # Create complete B-tree index for comparison + reference_ds.create_scalar_index( + column="id", + index_type="BTREE", + name="reference_btree_idx", + ) + + # Compare exact query results + reference_results = reference_ds.scanner( + filter=f"id = {test_id}", + columns=["id", "text"], + ).to_table() + + assert results.num_rows == reference_results.num_rows, ( + f"Distributed index returned {results.num_rows} results, " + f"but complete index returned {reference_results.num_rows} results" + ) + + # Compare range query results + reference_range_results = reference_ds.scanner( + filter="id >= 200 AND id < 800", + columns=["id", "text"], + ).to_table() + + assert results_range.num_rows == reference_range_results.num_rows, ( + f"Distributed index range query returned {results_range.num_rows} results, " + f"but complete index returned {reference_range_results.num_rows} results" + ) + + +def test_btree_precise_query_comparison(tmp_path): + """ + Precise comparison test between fragment-level B-tree index and complete + B-tree index. + This test creates identical datasets and compares query results in detail. 
+ """ + # Test configuration + num_fragments = 3 + rows_per_fragment = 10000 + total_rows = num_fragments * rows_per_fragment + + print( + f"Creating datasets with {num_fragments} fragments," + f" {rows_per_fragment} rows each" + ) + + # Create dataset for fragment-level indexing + fragment_ds = generate_multi_fragment_dataset( + tmp_path / "fragment", + num_fragments=num_fragments, + rows_per_fragment=rows_per_fragment, + ) + + # Create dataset for complete indexing (same data structure) + complete_ds = generate_multi_fragment_dataset( + tmp_path / "complete", + num_fragments=num_fragments, + rows_per_fragment=rows_per_fragment, + ) + + import uuid + + # Build fragment-level B-tree index + fragment_index_id = str(uuid.uuid4()) + fragment_index_name = "fragment_btree_precise_test" + + fragments = fragment_ds.get_fragments() + fragment_ids = [fragment.fragment_id for fragment in fragments] + print(f"Fragment IDs: {fragment_ids}") + + # Create fragment-level indices + for fragment in fragments: + fragment_id = fragment.fragment_id + print(f"Creating B-tree index for fragment {fragment_id}") + + fragment_ds.create_scalar_index( + column="id", + index_type="BTREE", + name=fragment_index_name, + replace=False, + fragment_uuid=fragment_index_id, + fragment_ids=[fragment_id], + ) + + # Merge fragment indices + fragment_ds.merge_index_metadata(fragment_index_id, index_type="BTREE") + + # Create Index object for fragment-based index + from lance.dataset import Index + + field_id = fragment_ds.schema.get_field_index("id") + + fragment_index = Index( + uuid=fragment_index_id, + name=fragment_index_name, + fields=[field_id], + dataset_version=fragment_ds.version, + fragment_ids=set(fragment_ids), + index_version=0, + ) + + # Commit fragment-based index + create_fragment_index_op = lance.LanceOperation.CreateIndex( + new_indices=[fragment_index], + removed_indices=[], + ) + + fragment_ds_committed = lance.LanceDataset.commit( + fragment_ds.uri, + create_fragment_index_op, + read_version=fragment_ds.version, + ) + + # Build complete B-tree index + complete_index_name = "complete_btree_precise_test" + complete_ds.create_scalar_index( + column="id", + index_type="BTREE", + name=complete_index_name, + ) + + print("Both indices created successfully") + + # Detailed query comparison tests + test_cases = [ + # Test 1: Boundary values at fragment edges + {"name": "First value", "filter": "id = 0"}, + {"name": "Fragment 0 last value", "filter": f"id = {rows_per_fragment - 1}"}, + {"name": "Fragment 1 first value", "filter": f"id = {rows_per_fragment}"}, + { + "name": "Fragment 1 last value", + "filter": f"id = {2 * rows_per_fragment - 1}", + }, + {"name": "Fragment 2 first value", "filter": f"id = {2 * rows_per_fragment}"}, + {"name": "Last value", "filter": f"id = {total_rows - 1}"}, + # Test 2: Values in the middle of fragments + {"name": "Fragment 0 middle", "filter": f"id = {rows_per_fragment // 2}"}, + { + "name": "Fragment 1 middle", + "filter": f"id = {rows_per_fragment + rows_per_fragment // 2}", + }, + { + "name": "Fragment 2 middle", + "filter": f"id = {2 * rows_per_fragment + rows_per_fragment // 2}", + }, + # Test 3: Range queries within single fragments + {"name": "Range within fragment 0", "filter": "id >= 10 AND id < 20"}, + { + "name": "Range within fragment 1", + "filter": f"id >= {rows_per_fragment + 10}" + f" AND id < {rows_per_fragment + 20}", + }, + { + "name": "Range within fragment 2", + "filter": f"id >= {2 * rows_per_fragment + 10}" + f" AND id < {2 * rows_per_fragment + 20}", + }, + # 
Test 4: Range queries spanning multiple fragments + { + "name": "Cross fragment 0-1", + "filter": f"id >= {rows_per_fragment - 5} AND id < {rows_per_fragment + 5}", + }, + { + "name": "Cross fragment 1-2", + "filter": f"id >= {2 * rows_per_fragment - 5}" + f" AND id < {2 * rows_per_fragment + 5}", + }, + { + "name": "Cross all fragments", + "filter": f"id >= {rows_per_fragment // 2} AND" + f" id < {2 * rows_per_fragment + rows_per_fragment // 2}", + }, + # Test 5: Edge cases + {"name": "Non-existent small value", "filter": "id = -1"}, + {"name": "Non-existent large value", "filter": f"id = {total_rows + 100}"}, + {"name": "Large range", "filter": f"id >= 0 AND id < {total_rows}"}, + # Test 6: Comparison operators + {"name": "Less than boundary", "filter": f"id < {rows_per_fragment}"}, + { + "name": "Greater than boundary", + "filter": f"id > {2 * rows_per_fragment - 1}", + }, + {"name": "Less than or equal", "filter": f"id <= {rows_per_fragment + 50}"}, + {"name": "Greater than or equal", "filter": f"id >= {rows_per_fragment + 50}"}, + ] + + print(f"\nRunning {len(test_cases)} detailed comparison tests:") + + for i, test_case in enumerate(test_cases, 1): + test_name = test_case["name"] + filter_expr = test_case["filter"] + + print(f" {i:2d}. Testing {test_name}: {filter_expr}") + + # Query fragment-based index + fragment_results = fragment_ds_committed.scanner( + filter=filter_expr, + columns=["id", "text"], + ).to_table() + + # Query complete index + complete_results = complete_ds.scanner( + filter=filter_expr, + columns=["id", "text"], + ).to_table() + + # Compare row counts + assert fragment_results.num_rows == complete_results.num_rows, ( + f"Test '{test_name}' failed: Fragment index " + f"returned {fragment_results.num_rows} rows, " + f"but complete index returned {complete_results.num_rows}" + f" rows for filter: {filter_expr}" + ) + + # Compare actual results if there are any + if fragment_results.num_rows > 0: + # Sort both results by id for comparison + fragment_ids = sorted(fragment_results.column("id").to_pylist()) + complete_ids = sorted(complete_results.column("id").to_pylist()) + + assert fragment_ids == complete_ids, ( + f"Test '{test_name}' failed: Fragment index" + f" returned different IDs than complete index. " + f"Fragment IDs:" + f" {fragment_ids[:10]}{'...' if len(fragment_ids) > 10 else ''}, " + f"Complete IDs:" + f" {complete_ids[:10]}{'...' if len(complete_ids) > 10 else ''}" + ) + + print(f" āœ“ Passed ({fragment_results.num_rows} rows)") + + print(f"\nāœ… All {len(test_cases)} precision tests passed!") + print( + "Fragment-level B-tree index produces identical results" + " to complete B-tree index." + ) + + +def test_btree_fragment_ids_parameter_validation(tmp_path): + """ + Test validation of fragment_ids parameter for B-tree indices. 
+ """ + ds = generate_multi_fragment_dataset( + tmp_path, num_fragments=2, rows_per_fragment=10000 + ) + + # Test with valid fragment IDs + fragments = ds.get_fragments() + valid_fragment_id = fragments[0].fragment_id + + # This should work without errors + ds.create_scalar_index( + column="id", + index_type="BTREE", + fragment_ids=[valid_fragment_id], + ) + + # Test with invalid fragment ID (should handle gracefully) + try: + ds.create_scalar_index( + column="id", + index_type="BTREE", + fragment_ids=[999999], # Non-existent fragment ID + ) + except Exception as e: + # It's acceptable for this to fail with an appropriate error + print(f"Expected error for invalid fragment ID: {e}") + + +def test_btree_backward_compatibility_no_fragment_ids(tmp_path): + """ + Test that B-tree indexing remains backward compatible + when fragment_ids is not provided. + """ + ds = generate_multi_fragment_dataset( + tmp_path, num_fragments=2, rows_per_fragment=10000 + ) + + # This should work exactly as before (full dataset indexing) + ds.create_scalar_index( + column="id", + index_type="BTREE", + name="full_dataset_btree_idx", + ) + + # Verify the index was created + indices = ds.list_indices() + assert len(indices) == 1 + assert indices[0]["name"] == "full_dataset_btree_idx" + assert indices[0]["type"] == "BTree" + + # Test that the index works + results = ds.scanner(filter="id = 50").to_table() + assert results.num_rows > 0 diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 83da7f26dc9..f3d0fd10e83 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1670,47 +1670,111 @@ impl Dataset { .infer_error() } - #[pyo3(signature = (index_uuid))] - fn merge_index_metadata(&self, index_uuid: &str) -> PyResult<()> { + #[pyo3(signature = (index_uuid, index_type, prefetch_batch))] + fn merge_index_metadata( + &self, + index_uuid: &str, + index_type: &str, + prefetch_batch: Option, + ) -> PyResult<()> { RT.block_on(None, async { + let index_type = index_type.to_uppercase(); + let idx_type = match index_type.as_str() { + "BTREE" => IndexType::BTree, + "INVERTED" => IndexType::Inverted, + _ => { + return Err(Error::InvalidInput { + source: format!( + "Index type {} is not supported.", + index_type + ).into(), + location: location!(), + }); + } + }; + let store = LanceIndexStore::from_dataset_for_new(self.ds.as_ref(), index_uuid)?; let index_dir = self.ds.indices_dir().child(index_uuid); + if idx_type == IndexType::Inverted { + // List all partition metadata files in the index directory + let mut part_metadata_files = Vec::new(); + let mut list_stream = self.ds.object_store().list(Some(index_dir.clone())); + + while let Some(item) = list_stream.next().await { + match item { + Ok(meta) => { + let file_name = meta.location.filename().unwrap_or_default(); + // Filter files matching the pattern part_*_metadata.lance + if file_name.starts_with("part_") && file_name.ends_with("_metadata.lance") + { + part_metadata_files.push(file_name.to_string()); + } + } + Err(_) => continue, + } + } + + if part_metadata_files.is_empty() { + return Err(Error::InvalidInput { + source: format!( + "No partition metadata files found in index directory: {}", + index_dir + ) + .into(), + location: location!(), + }); + } - // List all partition metadata files in the index directory - let mut part_metadata_files = Vec::new(); - let mut list_stream = self.ds.object_store().list(Some(index_dir.clone())); - - while let Some(item) = list_stream.next().await { - match item { - Ok(meta) => { - let file_name = 
meta.location.filename().unwrap_or_default(); - // Filter files matching the pattern part_*_metadata.lance - if file_name.starts_with("part_") && file_name.ends_with("_metadata.lance") - { - part_metadata_files.push(file_name.to_string()); + // Call merge_metadata_files function for inverted index + lance_index::scalar::inverted::builder::merge_metadata_files( + Arc::new(store), + &part_metadata_files, + ) + .await + } else { + // List all partition page / lookup files in the index directory + let mut part_page_files = Vec::new(); + let mut part_lookup_files = Vec::new(); + let mut list_stream = self.ds.object_store().list(Some(index_dir.clone())); + + while let Some(item) = list_stream.next().await { + match item { + Ok(meta) => { + let file_name = meta.location.filename().unwrap_or_default(); + // Filter files matching the pattern part_*_metadata.lance + if file_name.starts_with("part_") && file_name.ends_with("_page_data.lance") + { + part_page_files.push(file_name.to_string()); + } + if file_name.starts_with("part_") && file_name.ends_with("_page_lookup.lance") + { + part_lookup_files.push(file_name.to_string()); + } } + Err(_) => continue, } - Err(_) => continue, } - } + if part_page_files.is_empty() || part_lookup_files.is_empty() { + return Err(Error::InvalidInput { + source: format!( + "No partition metadata files found in index directory: {} (page_files: {}, lookup_files: {})", + index_dir, part_page_files.len(), part_lookup_files.len() + ) + .into(), + location: location!(), + }); + } - if part_metadata_files.is_empty() { - return Err(Error::InvalidInput { - source: format!( - "No partition metadata files found in index directory: {}", - index_dir - ) - .into(), - location: location!(), - }); + // Call merge_metadata_files function for btree index + lance_index::scalar::btree::merge_metadata_files( + Arc::new(store), + &part_page_files, + &part_lookup_files, + prefetch_batch, + ).await } - // Call merge_metadata_files function for inverted index - lance_index::scalar::inverted::builder::merge_metadata_files( - Arc::new(store), - &part_metadata_files, - ) - .await + })? 
.map_err(|err| PyValueError::new_err(err.to_string()))
    }
diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs
index 1b5d0d530bd..09dde5297b4 100644
--- a/rust/lance-index/src/scalar/bitmap.rs
+++ b/rust/lance-index/src/scalar/bitmap.rs
@@ -528,6 +528,7 @@ impl ScalarIndexPlugin for BitmapIndexPlugin {
         data: SendableRecordBatchStream,
         index_store: &dyn IndexStore,
         _request: Box<dyn TrainingRequest>,
+        _fragment_ids: Option<Vec<u32>>,
     ) -> Result<CreatedIndex> {
         Self::train_bitmap_index(data, index_store).await?;
         Ok(CreatedIndex {
diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs
index b2760fe214d..9c3e6685703 100644
--- a/rust/lance-index/src/scalar/btree.rs
+++ b/rust/lance-index/src/scalar/btree.rs
@@ -4,10 +4,10 @@ use std::{
     any::Any,
     cmp::Ordering,
-    collections::{BTreeMap, BinaryHeap, HashMap},
+    collections::{BTreeMap, BinaryHeap, HashMap, VecDeque},
     fmt::{Debug, Display},
     ops::Bound,
-    sync::Arc,
+    sync::{Arc, LazyLock},
 };

 use super::{
@@ -38,7 +38,7 @@ use deepsize::DeepSizeOf;
 use futures::{
     future::BoxFuture,
     stream::{self},
-    FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
+    Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
 };
 use lance_core::{
     cache::{CacheKey, LanceCache},
@@ -54,16 +54,37 @@ use lance_datafusion::{
     chunker::chunk_concat_stream,
     exec::{execute_plan, LanceExecutionOptions, OneShotExec},
 };
-use log::debug;
+use log::{debug, warn};
 use roaring::RoaringBitmap;
 use serde::{Deserialize, Serialize, Serializer};
 use snafu::location;
+use tokio::runtime::{Builder, Runtime};
 use tracing::info;

 const BTREE_LOOKUP_NAME: &str = "page_lookup.lance";
 const BTREE_PAGES_NAME: &str = "page_data.lance";
 pub const DEFAULT_BTREE_BATCH_SIZE: u64 = 4096;
 const BATCH_SIZE_META_KEY: &str = "batch_size";
+
+/// Global thread pool for B-tree prefetch operations
+static BTREE_PREFETCH_RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
+    Builder::new_multi_thread()
+        .worker_threads(get_num_compute_intensive_cpus())
+        .max_blocking_threads(get_num_compute_intensive_cpus())
+        .thread_name("lance-btree-prefetch")
+        .enable_time()
+        .build()
+        .expect("Failed to create B-tree prefetch runtime")
+});
+
+/// Spawn a prefetch task on the B-tree thread pool
+fn spawn_btree_prefetch<F>(future: F) -> tokio::task::JoinHandle<F::Output>
+where
+    F: Future + Send + 'static,
+    F::Output: Send + 'static,
+{
+    BTREE_PREFETCH_RUNTIME.spawn(future)
+}
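The `LazyLock<Runtime>` above is what keeps prefetch I/O off the caller's executor: the pool is built once, on first use, and every prefetch task shares it. Below is a minimal, self-contained sketch of the same pattern; the fixed thread count and the thread name are illustrative placeholders rather than the patch's values (the patch sizes the pool from `get_num_compute_intensive_cpus()`).

```rust
use std::sync::LazyLock;

use tokio::runtime::{Builder, Runtime};

// Built lazily on first access; all tasks spawned through it share one pool
// that is isolated from whatever runtime the caller happens to be on.
static PREFETCH_RT: LazyLock<Runtime> = LazyLock::new(|| {
    Builder::new_multi_thread()
        .worker_threads(4) // placeholder; the patch derives this from the CPU count
        .thread_name("prefetch-demo")
        .enable_time()
        .build()
        .expect("failed to build prefetch runtime")
});

fn main() {
    // Spawn onto the dedicated pool, then wait for the result from sync code.
    let handle = PREFETCH_RT.spawn(async { 21 * 2 });
    assert_eq!(PREFETCH_RT.block_on(handle).unwrap(), 42);
}
```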
 const BTREE_INDEX_VERSION: u32 = 0;
 pub(crate) const BTREE_VALUES_COLUMN: &str = "values";
 pub(crate) const BTREE_IDS_COLUMN: &str = "ids";
@@ -1231,6 +1252,7 @@ impl ScalarIndex for BTreeIndex {
             self.sub_index.as_ref(),
             dest_store,
             DEFAULT_BTREE_BATCH_SIZE,
+            None,
         )
         .await?;
@@ -1366,10 +1388,33 @@ pub async fn train_btree_index(
     sub_index_trainer: &dyn BTreeSubIndex,
     index_store: &dyn IndexStore,
     batch_size: u64,
+    fragment_ids: Option<Vec<u32>>,
 ) -> Result<()> {
-    let mut sub_index_file = index_store
-        .new_index_file(BTREE_PAGES_NAME, sub_index_trainer.schema().clone())
-        .await?;
+    let fragment_mask = fragment_ids.as_ref().and_then(|frag_ids| {
+        if !frag_ids.is_empty() {
+            // Create a mask with the fragment_id in the high 32 bits for distributed
+            // indexing. The mask identifies the partition files that belong to a
+            // specific fragment. If multiple fragments are processed together, the
+            // first fragment_id << 32 is used as the mask.
+            Some((frag_ids[0] as u64) << 32)
+        } else {
+            None
+        }
+    });
+
+    let mut sub_index_file;
+    if fragment_mask.is_none() {
+        sub_index_file = index_store
+            .new_index_file(BTREE_PAGES_NAME, sub_index_trainer.schema().clone())
+            .await?;
+    } else {
+        sub_index_file = index_store
+            .new_index_file(
+                part_page_data_file_path(fragment_mask.unwrap()).as_str(),
+                sub_index_trainer.schema().clone(),
+            )
+            .await?;
+    }

     let mut encoded_batches = Vec::new();
     let mut batch_idx = 0;
@@ -1393,385 +1438,1945 @@
     file_schema
         .metadata
         .insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string());
-    let mut btree_index_file = index_store
-        .new_index_file(BTREE_LOOKUP_NAME, Arc::new(file_schema))
-        .await?;
+    let mut btree_index_file;
+    if fragment_mask.is_none() {
+        btree_index_file = index_store
+            .new_index_file(BTREE_LOOKUP_NAME, Arc::new(file_schema))
+            .await?;
+    } else {
+        btree_index_file = index_store
+            .new_index_file(
+                part_lookup_file_path(fragment_mask.unwrap()).as_str(),
+                Arc::new(file_schema),
+            )
+            .await?;
+    }
     btree_index_file.write_record_batch(record_batch).await?;
     btree_index_file.finish().await?;
     Ok(())
 }

-/// A stream that reads the original training data back out of the index
-///
-/// This is used for updating the index
-struct IndexReaderStream {
-    reader: Arc<dyn IndexReader>,
-    batch_size: u64,
-    num_batches: u32,
-    batch_idx: u32,
-}
+/// Extract partition ID from partition file name
+/// Expected format: "part_{partition_id}_{suffix}.lance"
+fn extract_partition_id(filename: &str) -> Result<u64> {
+    if !filename.starts_with("part_") {
+        return Err(Error::Internal {
+            message: format!("Invalid partition file name format: {}", filename),
+            location: location!(),
+        });
+    }
+
+    let parts: Vec<&str> = filename.split('_').collect();
+    if parts.len() < 3 {
+        return Err(Error::Internal {
+            message: format!("Invalid partition file name format: {}", filename),
+            location: location!(),
+        });
+    }
+
+    parts[1].parse::<u64>().map_err(|_| Error::Internal {
+        message: format!("Failed to parse partition ID from filename: {}", filename),
+        location: location!(),
+    })
+}
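`extract_partition_id` is the inverse of the naming scheme the trainer uses when `fragment_mask` is set. The round trip below is a small sketch under the assumption that `part_page_data_file_path` (defined elsewhere in this patch, not shown here) simply formats `part_{mask}_{suffix}.lance`; for fragment 1 the mask is `1 << 32 = 4294967296`, which is exactly the value the tests earlier in this diff probe.

```rust
// Hypothetical stand-in for the patch's part_page_data_file_path helper,
// assuming it just formats the mask into the partition file name.
fn part_page_data_file_path(mask: u64) -> String {
    format!("part_{}_page_data.lance", mask)
}

fn main() {
    let fragment_id: u32 = 1;
    // The fragment id is shifted into the high 32 bits, as in train_btree_index.
    let mask = (fragment_id as u64) << 32;
    assert_eq!(mask, 4294967296);

    let name = part_page_data_file_path(mask);
    assert_eq!(name, "part_4294967296_page_data.lance");

    // Parsing the second '_'-separated token recovers the same mask,
    // mirroring what extract_partition_id does.
    let parsed: u64 = name.split('_').nth(1).unwrap().parse().unwrap();
    assert_eq!(parsed, mask);
}
```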
-impl IndexReaderStream {
-    async fn new(reader: Arc<dyn IndexReader>, batch_size: u64) -> Self {
-        let num_batches = reader.num_batches(batch_size).await;
-        Self {
-            reader,
-            batch_size,
-            num_batches,
-            batch_idx: 0,
-        }
-    }
-}
+/// Merge multiple partition page / lookup files into complete metadata files
+///
+/// In a distributed environment, each worker node writes partition page / lookup
+/// files for the partitions it processes; this function merges those files into
+/// the final metadata files.
+pub async fn merge_metadata_files(
+    store: Arc<dyn IndexStore>,
+    part_page_files: &[String],
+    part_lookup_files: &[String],
+    prefetch_batch: Option<usize>,
+) -> Result<()> {
+    if part_lookup_files.is_empty() || part_page_files.is_empty() {
+        return Err(Error::Internal {
+            message: "No partition files provided for merging".to_string(),
+            location: location!(),
+        });
+    }
+
+    // Step 1: Create a lookup map for page files by partition ID
+    let mut page_files_map = HashMap::new();
+    for page_file in part_page_files {
+        let partition_id = extract_partition_id(page_file)?;
+        page_files_map.insert(partition_id, page_file);
+    }
+
+    // Step 2: Validate that all lookup files have corresponding page files
+    for lookup_file in part_lookup_files {
+        let partition_id = extract_partition_id(lookup_file)?;
+        if !page_files_map.contains_key(&partition_id) {
+            return Err(Error::Internal {
+                message: format!(
+                    "No corresponding page file found for lookup file: {} (partition_id: {})",
+                    lookup_file, partition_id
+                ),
+                location: location!(),
+            });
+        }
+    }
+
+    // Step 3: Extract metadata from the lookup files
+    let first_lookup_reader = store.open_index_file(&part_lookup_files[0]).await?;
+    let batch_size = first_lookup_reader
+        .schema()
+        .metadata
+        .get(BATCH_SIZE_META_KEY)
+        .map(|bs| bs.parse().unwrap_or(DEFAULT_BTREE_BATCH_SIZE))
+        .unwrap_or(DEFAULT_BTREE_BATCH_SIZE);
+
+    // Get the value type from the lookup schema (min column)
+    let lookup_batch = first_lookup_reader.read_range(0..1, None).await?;
+    let value_type = lookup_batch.column(0).data_type().clone();
+
+    // Get the page schema first
+    let partition_id = extract_partition_id(part_lookup_files[0].as_str())?;
+    let page_file = page_files_map.get(&partition_id).unwrap();
+    let page_reader = store.open_index_file(page_file).await?;
+    let page_schema = page_reader.schema().clone();
+
+    let arrow_schema = Arc::new(Schema::from(&page_schema));
+    let mut page_file = store
+        .new_index_file(BTREE_PAGES_NAME, arrow_schema.clone())
+        .await?;
+
+    let mut prefetch_config = PrefetchConfig::default();
+    if let Some(prefetch_batch) = prefetch_batch {
+        prefetch_config = prefetch_config.with_prefetch_batch(prefetch_batch);
+    }
+
+    let lookup_entries = merge_page(
+        part_lookup_files,
+        &page_files_map,
+        &store,
+        batch_size,
+        &mut page_file,
+        arrow_schema.clone(),
+        prefetch_config,
+    )
+    .await?;
+
+    page_file.finish().await?;
+
+    // Step 4: Generate a new lookup file based on the reorganized pages.
+    // Add batch_size to the schema metadata.
+    let mut metadata = HashMap::new();
+    metadata.insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string());
+
+    let lookup_schema_with_metadata = Arc::new(Schema::new_with_metadata(
+        vec![
+            Field::new("min", value_type.clone(), true),
+            Field::new("max", value_type, true),
+            Field::new("null_count", DataType::UInt32, false),
+            Field::new("page_idx", DataType::UInt32, false),
+        ],
+        metadata,
+    ));
+
+    let lookup_batch = RecordBatch::try_new(
+        lookup_schema_with_metadata.clone(),
+        vec![
+            ScalarValue::iter_to_array(lookup_entries.iter().map(|(min, _, _, _)| min.clone()))?,
+            ScalarValue::iter_to_array(lookup_entries.iter().map(|(_, max, _, _)| max.clone()))?,
+            Arc::new(UInt32Array::from_iter_values(
+                lookup_entries
+                    .iter()
+                    .map(|(_, _, null_count, _)| *null_count),
+            )),
+            Arc::new(UInt32Array::from_iter_values(
+                lookup_entries.iter().map(|(_, _, _, page_idx)| *page_idx),
+            )),
+        ],
+    )?;
+
+    let mut lookup_file = store
+        .new_index_file(BTREE_LOOKUP_NAME, lookup_schema_with_metadata)
+        .await?;
+    lookup_file.write_record_batch(lookup_batch).await?;
+    lookup_file.finish().await?;
+
+    // After the merged files are successfully written, delete all partition files.
+    // Deleting only after a successful write ensures debug information is not lost
+    // if the merge fails part-way.
+    cleanup_partition_files(&store, part_lookup_files, part_page_files).await;
+
+    Ok(())
+}
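For orientation, here is a hypothetical driver-side wrapper (not part of the patch) showing how a coordinator might call `merge_metadata_files` once every worker has written its partition files. The fragment ids and the prefetch depth of 1 are illustrative assumptions; in practice the file lists are discovered by listing the index directory, as the `merge_index_metadata` binding earlier in this diff does.

```rust
use std::sync::Arc;

// Sketch only: `IndexStore`, `Result`, and `merge_metadata_files` are the items
// from this diff; everything else here is assumed for illustration.
async fn merge_worker_outputs(
    store: Arc<dyn IndexStore>,
    fragment_ids: &[u32],
) -> Result<()> {
    // Reconstruct the partition file names from the fragment-id masks.
    let masks: Vec<u64> = fragment_ids.iter().map(|f| (*f as u64) << 32).collect();
    let part_page_files: Vec<String> = masks
        .iter()
        .map(|m| format!("part_{}_page_data.lance", m))
        .collect();
    let part_lookup_files: Vec<String> = masks
        .iter()
        .map(|m| format!("part_{}_page_lookup.lance", m))
        .collect();

    // Prefetch one batch of sub-page data ahead while pages are re-chunked.
    merge_metadata_files(store, &part_page_files, &part_lookup_files, Some(1)).await
}
```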
-impl Stream for IndexReaderStream {
-    type Item = BoxFuture<'static, Result<RecordBatch>>;
-
-    fn poll_next(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
-        let this = self.get_mut();
-        if this.batch_idx >= this.num_batches {
-            return std::task::Poll::Ready(None);
-        }
-        let batch_num = this.batch_idx;
-        this.batch_idx += 1;
-        let reader_copy = this.reader.clone();
-        let batch_size = this.batch_size;
-        let read_task = async move {
-            reader_copy
-                .read_record_batch(batch_num as u64, batch_size)
-                .await
-        }
-        .boxed();
-        std::task::Poll::Ready(Some(read_task))
-    }
-}
+/// Clean up partition files after a successful merge
+///
+/// This function deletes the partition lookup and page files once a merge has
+/// succeeded. Individual deletion failures are logged but do not fail the merge.
+async fn cleanup_partition_files(
+    store: &Arc<dyn IndexStore>,
+    part_lookup_files: &[String],
+    part_page_files: &[String],
+) {
+    // Clean up partition lookup files
+    for file_name in part_lookup_files {
+        cleanup_single_file(
+            store,
+            file_name,
+            "part_",
+            "_page_lookup.lance",
+            "partition lookup",
+        )
+        .await;
+    }
+
+    // Clean up partition page files
+    for file_name in part_page_files {
+        cleanup_single_file(
+            store,
+            file_name,
+            "part_",
+            "_page_data.lance",
+            "partition page",
+        )
+        .await;
+    }
+}

-/// Parameters for a btree index
-#[derive(Debug, Serialize, Deserialize)]
-pub struct BTreeParameters {
-    /// The number of rows to include in each zone
-    pub zone_size: Option<u64>,
-}
+/// Helper function to clean up a single partition file
+///
+/// Performs safety checks on the filename pattern before attempting deletion.
+async fn cleanup_single_file(
+    store: &Arc<dyn IndexStore>,
+    file_name: &str,
+    expected_prefix: &str,
+    expected_suffix: &str,
+    file_type: &str,
+) {
+    // Ensure we only delete files that match the expected pattern (safety check)
+    if file_name.starts_with(expected_prefix) && file_name.ends_with(expected_suffix) {
+        match store.delete_index_file(file_name).await {
+            Ok(()) => {
+                debug!("Successfully deleted {} file: {}", file_type, file_name);
+            }
+            Err(e) => {
+                // File deletion failures should not affect the overall success of the
+                // function; log the error and continue with the remaining files.
+                warn!(
+                    "Failed to delete {} file '{}': {}. \
                     This does not affect the merge operation, but may leave \
+                     partition files that should be cleaned up manually.",
+                    file_type, file_name, e
+                );
+            }
+        }
+    } else {
+        // If the filename doesn't match the expected format, log a warning but
+        // don't attempt deletion.
+        warn!(
+            "Skipping deletion of file '{}' as it does not match the expected \
+             {} file pattern ({}*{})",
+            file_name, file_type, expected_prefix, expected_suffix
+        );
+    }
+}

-struct BTreeTrainingRequest {
-    parameters: BTreeParameters,
-    criteria: TrainingCriteria,
-}
+/// Prefetch configuration for partition iterators
+#[derive(Debug, Clone)]
+pub struct PrefetchConfig {
+    /// Number of batches to prefetch ahead (0 means no prefetching)
+    pub prefetch_batches: usize,
+}

-impl BTreeTrainingRequest {
-    pub fn new(parameters: BTreeParameters) -> Self {
-        Self {
-            parameters,
-            // BTree indexes need data sorted by the value column
-            criteria: TrainingCriteria::new(TrainingOrdering::Values).with_row_id(),
-        }
-    }
-}
+impl Default for PrefetchConfig {
+    fn default() -> Self {
+        Self {
+            prefetch_batches: 1,
+        }
+    }
+}

-impl TrainingRequest for BTreeTrainingRequest {
-    fn as_any(&self) -> &dyn std::any::Any {
-        self
-    }
-
-    fn criteria(&self) -> &TrainingCriteria {
-        &self.criteria
-    }
-}
+impl PrefetchConfig {
+    /// Set the prefetch batch count
+    pub fn with_prefetch_batch(&self, batch_count: usize) -> Self {
+        Self {
+            prefetch_batches: batch_count,
+        }
+    }
+}

-#[derive(Debug, Default)]
-pub struct BTreeIndexPlugin;
+/// Buffer entry for the prefetch queue
+#[derive(Debug)]
+struct BufferEntry {
+    batch: RecordBatch,
+    start_row: usize,
+    end_row: usize,
+}

-#[async_trait]
-impl ScalarIndexPlugin for BTreeIndexPlugin {
-    fn new_training_request(
-        &self,
-        params: &str,
-        field: &Field,
-    ) -> Result<Box<dyn TrainingRequest>> {
-        if field.data_type().is_nested() {
-            return Err(Error::InvalidInput {
-                source: "A btree index can only be created on a non-nested field.".into(),
-                location: location!(),
-            });
-        }
-
-        let params = serde_json::from_str::<BTreeParameters>(params)?;
-        Ok(Box::new(BTreeTrainingRequest::new(params)))
-    }
+/// Running prefetch task information
+#[derive(Debug)]
+struct RunningPrefetchTask {
+    /// Task handle
+    handle: tokio::task::JoinHandle<()>,
+    /// Range being prefetched
+    range: std::ops::Range<usize>,
+}
+
+/// Check if two ranges overlap
+fn ranges_overlap(range1: &std::ops::Range<usize>, range2: &std::ops::Range<usize>) -> bool {
+    range1.start < range2.end && range2.start < range1.end
+}
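`ranges_overlap` treats ranges as half-open, so back-to-back prefetch ranges such as `0..10` and `10..20` are deliberately not considered conflicting; only a true intersection blocks a new task. A quick standalone check of that property:

```rust
// Two half-open ranges overlap iff each starts strictly before the other ends.
fn ranges_overlap(a: &std::ops::Range<usize>, b: &std::ops::Range<usize>) -> bool {
    a.start < b.end && b.start < a.end
}

fn main() {
    assert!(ranges_overlap(&(0..10), &(5..15)));   // partial overlap
    assert!(ranges_overlap(&(0..10), &(3..4)));    // containment counts as overlap
    assert!(!ranges_overlap(&(0..10), &(10..20))); // adjacent ranges do not conflict
    assert!(!ranges_overlap(&(0..0), &(0..10)));   // empty ranges never overlap
}
```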
initialize_partition(&mut self, partition_id: u64, reader: Arc) { + let total_rows = reader.num_rows(); + let buffer = Arc::new(tokio::sync::Mutex::new(VecDeque::new())); + let running_tasks = Arc::new(tokio::sync::Mutex::new(VecDeque::new())); + let next_prefetch_position = Arc::new(tokio::sync::Mutex::new(0)); - async fn train_index( - &self, - data: SendableRecordBatchStream, - index_store: &dyn IndexStore, - request: Box, - ) -> Result { - let request = request - .as_any() - .downcast_ref::() - .unwrap(); - let value_type = data - .schema() - .field_with_name(VALUE_COLUMN_NAME)? - .data_type() - .clone(); - let flat_index_trainer = FlatIndexMetadata::new(value_type); - train_btree_index( - data, - &flat_index_trainer, - index_store, - request - .parameters - .zone_size - .unwrap_or(DEFAULT_BTREE_BATCH_SIZE), - ) - .await?; - Ok(CreatedIndex { - index_details: prost_types::Any::from_msg(&pb::BTreeIndexDetails::default()).unwrap(), - index_version: BTREE_INDEX_VERSION, - }) - } + let state = PartitionPrefetchState { + buffer, + reader, + total_rows, + running_tasks, + next_prefetch_position, + }; - async fn load_index( - &self, - index_store: Arc, - _index_details: &prost_types::Any, - frag_reuse_index: Option>, - cache: LanceCache, - ) -> Result> { - Ok(BTreeIndex::load(index_store, frag_reuse_index, cache).await? as Arc) + self.partition_states.insert(partition_id, state); + debug!( + "Initialized partition {} for task-based prefetching", + partition_id + ); } -} -#[cfg(test)] -mod tests { - use std::sync::atomic::Ordering; - use std::{collections::HashMap, sync::Arc}; + /// Submit a prefetch task for a partition to the thread pool + pub async fn submit_prefetch_task(&self, partition_id: u64, batch_size: usize) -> Result<()> { + if self.config.prefetch_batches == 0 { + return Ok(()); + } - use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type}; - use arrow_array::FixedSizeListArray; - use arrow_schema::DataType; - use datafusion::{ - execution::{SendableRecordBatchStream, TaskContext}, - physical_plan::{sorts::sort::SortExec, stream::RecordBatchStreamAdapter, ExecutionPlan}, - }; - use datafusion_common::{DataFusionError, ScalarValue}; - use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; - use deepsize::DeepSizeOf; - use futures::TryStreamExt; - use lance_core::{cache::LanceCache, utils::mask::RowIdTreeMap}; - use lance_datafusion::{chunker::break_stream, datagen::DatafusionDatagenExt}; - use lance_datagen::{array, gen_batch, ArrayGeneratorExt, BatchCount, RowCount}; - use lance_io::object_store::ObjectStore; - use object_store::path::Path; - use tempfile::tempdir; + let Some(state) = self.partition_states.get(&partition_id) else { + return Ok(()); + }; - use crate::metrics::LocalMetricsCollector; - use crate::{ - metrics::NoOpMetricsCollector, - scalar::{ - btree::{BTreeIndex, BTREE_PAGES_NAME}, - flat::FlatIndexMetadata, - lance_format::LanceIndexStore, - IndexStore, SargableQuery, ScalarIndex, SearchResult, - }, - }; + let reader = state.reader.clone(); + let buffer = state.buffer.clone(); + let running_tasks = state.running_tasks.clone(); + let next_prefetch_position = state.next_prefetch_position.clone(); + let total_rows = state.total_rows; + let effective_batch_size = self.config.prefetch_batches * batch_size; - use super::{train_btree_index, OrderableScalarValue}; + const MAX_BUFFER_SIZE: usize = 4; + const MAX_RUNNING_TASKS: usize = 2; - #[test] - fn test_scalar_value_size() { - let size_of_i32 = 
OrderableScalarValue(ScalarValue::Int32(Some(0))).deep_size_of(); - let size_of_many_i32 = OrderableScalarValue(ScalarValue::FixedSizeList(Arc::new( - FixedSizeListArray::from_iter_primitive::( - vec![Some(vec![Some(0); 128])], - 128, - ), - ))) - .deep_size_of(); + // Clean up completed tasks and check limits + { + let mut tasks_guard = running_tasks.lock().await; - // deep_size_of should account for the rust type overhead - assert!(size_of_i32 > 4); - assert!(size_of_many_i32 > 128 * 4); + // Remove completed tasks from the front + while let Some(task) = tasks_guard.front() { + if task.handle.is_finished() { + tasks_guard.pop_front(); + } else { + break; + } + } + + // Check if we have too many running tasks + if tasks_guard.len() >= MAX_RUNNING_TASKS { + debug!( + "Skipping prefetch for partition {} - too many running tasks ({})", + partition_id, + tasks_guard.len() + ); + return Ok(()); + } + + // Check if any running task already covers to the end of file + for task in tasks_guard.iter() { + if task.range.end >= total_rows { + debug!( + "Skipping prefetch for partition {} - task already covers to EOF (range {}..{})", + partition_id, task.range.start, task.range.end + ); + return Ok(()); + } + } + } + + // Check if buffer is full + { + let buffer_guard = buffer.lock().await; + if buffer_guard.len() >= MAX_BUFFER_SIZE { + debug!( + "Skipping prefetch for partition {} - buffer full", + partition_id + ); + return Ok(()); + } + } + + // Determine the next range to prefetch + let next_range = { + let mut pos_guard = next_prefetch_position.lock().await; + let start_pos = *pos_guard; + + if start_pos >= total_rows { + debug!( + "Skipping prefetch for partition {} - no more data to prefetch", + partition_id + ); + return Ok(()); + } + + let end_pos = std::cmp::min(start_pos + effective_batch_size, total_rows); + *pos_guard = end_pos; // Update next prefetch position + start_pos..end_pos + }; + + // Check if this range is already being prefetched + { + let tasks_guard = running_tasks.lock().await; + + // Check for range overlap + for task in tasks_guard.iter() { + if ranges_overlap(&task.range, &next_range) { + debug!( + "Skipping prefetch for partition {} - range {}..{} overlaps with running task {}..{}", + partition_id, next_range.start, next_range.end, task.range.start, task.range.end + ); + return Ok(()); + } + } + } + + // All checks passed, create the actual prefetch task (only this part is async) + let range_clone = next_range.clone(); + let running_tasks_for_cleanup = running_tasks.clone(); + + let prefetch_task = spawn_btree_prefetch(async move { + // Perform the actual read + match reader.read_range(range_clone.clone(), None).await { + Ok(batch) => { + let entry = BufferEntry { + batch, + start_row: range_clone.start, + end_row: range_clone.end, + }; + + // Add to buffer + { + let mut buffer_guard = buffer.lock().await; + buffer_guard.push_back(entry); + } + + debug!( + "Prefetched {} rows ({}..{}) for partition {}", + range_clone.end - range_clone.start, + range_clone.start, + range_clone.end, + partition_id + ); + } + Err(err) => { + warn!( + "Prefetch task failed for partition {} range {}..{}: {}", + partition_id, range_clone.start, range_clone.end, err + ); + } + } + + // Remove this task from running tasks when completed + { + let mut tasks_guard = running_tasks_for_cleanup.lock().await; + tasks_guard.retain(|task| !task.handle.is_finished()); + } + }); + + // Add the task to running tasks + { + let mut tasks_guard = running_tasks.lock().await; + 
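
// ---- Illustrative example (not part of this patch) ----
// A minimal, single-threaded sketch of the scheduling rule used by
// `submit_prefetch_task` above: a new prefetch range is scheduled only when it
// does not overlap a range that a running task already covers. `ranges_overlap`
// mirrors the helper defined earlier in this file; the row counts are made up.
use std::ops::Range;

fn ranges_overlap(a: &Range<usize>, b: &Range<usize>) -> bool {
    a.start < b.end && b.start < a.end
}

fn main() {
    let total_rows = 10_000usize;
    let effective_batch_size = 4_096usize;
    // One task is already reading rows 0..4096.
    let running: Vec<Range<usize>> = vec![0..4_096];
    let mut next_prefetch_position = 4_096usize;

    let candidate =
        next_prefetch_position..(next_prefetch_position + effective_batch_size).min(total_rows);
    if running.iter().all(|r| !ranges_overlap(r, &candidate)) {
        // No overlap: advance the cursor and (in the real code) spawn the read task.
        next_prefetch_position = candidate.end;
        println!("would prefetch rows {:?}", candidate);
    }
    assert_eq!(next_prefetch_position, 8_192);
}
// ---- end example ----
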
tasks_guard.push_back(RunningPrefetchTask { + handle: prefetch_task, + range: next_range.clone(), + }); + } + + debug!( + "Submitted prefetch task for partition {} range {}..{}", + partition_id, next_range.start, next_range.end + ); + + Ok(()) + } + + /// Get data from buffer or fallback to direct read + pub async fn get_data_with_fallback( + &self, + partition_id: u64, + start_row: usize, + end_row: usize, + ) -> Result { + if let Some(state) = self.partition_states.get(&partition_id) { + // First try to get from buffer + { + let mut buffer_guard = state.buffer.lock().await; + + // Remove outdated entries from the front + while let Some(entry) = buffer_guard.front() { + if entry.end_row <= start_row { + buffer_guard.pop_front(); + } else { + break; + } + } + + // Check if we have suitable data in buffer + if let Some(entry) = buffer_guard.front() { + if entry.start_row <= start_row && entry.end_row >= end_row { + // Found matching data, extract it + let entry = buffer_guard.pop_front().unwrap(); + drop(buffer_guard); + + let slice_start = start_row - entry.start_row; + let slice_len = end_row - start_row; + + debug!( + "Using buffered data for partition {} ({}..{})", + partition_id, start_row, end_row + ); + + return Ok(entry.batch.slice(slice_start, slice_len)); + } + } + } + + // Fallback to direct read + debug!( + "Direct read fallback for partition {} ({}..{})", + partition_id, start_row, end_row + ); + + state.reader.read_range(start_row..end_row, None).await + } else { + Err(Error::Internal { + message: format!("Partition {} not found in prefetch manager", partition_id), + location: location!(), + }) + } + } +} + +/// Simplified partition iterator with immediate loading since all partitions need to be accessed +struct PartitionIterator { + reader: Arc, + current_batch: Option, + current_position: usize, + rows_read: usize, + partition_id: u64, + batch_size: u64, +} + +impl PartitionIterator { + async fn new( + store: Arc, + page_file_name: String, + partition_id: u64, + batch_size: u64, + ) -> Result { + let reader = store.open_index_file(&page_file_name).await?; + Ok(Self { + reader, + current_batch: None, + current_position: 0, + rows_read: 0, + partition_id, + batch_size, + }) + } + + /// Get the next element, working with the prefetch manager + async fn next( + &mut self, + prefetch_manager: &PrefetchManager, + ) -> Result> { + // Load new batch if current one is exhausted + if self.needs_new_batch() { + if self.rows_read >= self.reader.num_rows() { + return Ok(None); + } + self.load_next_batch(prefetch_manager).await?; + + // Submit next prefetch task + if let Err(err) = prefetch_manager + .submit_prefetch_task(self.partition_id, self.batch_size as usize) + .await + { + warn!( + "Failed to submit prefetch task for partition {}: {}", + self.partition_id, err + ); + } + } else { + // Check if we've read half of the current batch, submit next prefetch task + let batch_half = self.current_batch.as_ref().unwrap().num_rows() / 2; + if self.current_position == batch_half && batch_half > 0 { + if let Err(err) = prefetch_manager + .submit_prefetch_task(self.partition_id, self.batch_size as usize) + .await + { + warn!( + "Failed to submit prefetch task for partition {}: {}", + self.partition_id, err + ); + } + } + } + + // Extract next value from current batch + if let Some(batch) = &self.current_batch { + let value = ScalarValue::try_from_array(batch.column(0), self.current_position)?; + let row_id = ScalarValue::try_from_array(batch.column(1), self.current_position)?; + 
self.current_position += 1;
+            self.rows_read += 1;
+            Ok(Some((value, row_id)))
+        } else {
+            Ok(None)
+        }
+    }
+
+    /// Check if we need to load a new batch
+    fn needs_new_batch(&self) -> bool {
+        self.current_batch.is_none()
+            || self.current_position >= self.current_batch.as_ref().unwrap().num_rows()
+    }
+
+    async fn load_next_batch(&mut self, prefetch_manager: &PrefetchManager) -> Result<()> {
+        let remaining_rows = self.reader.num_rows() - self.rows_read;
+        if remaining_rows == 0 {
+            self.current_batch = None;
+            return Ok(());
+        }
+
+        let rows_to_read = std::cmp::min(self.batch_size as usize, remaining_rows);
+        let end_row = self.rows_read + rows_to_read;
+
+        // Use the fallback mechanism - try the prefetch buffer first, then a direct read
+        let batch = prefetch_manager
+            .get_data_with_fallback(self.partition_id, self.rows_read, end_row)
+            .await?;
+
+        self.current_batch = Some(batch);
+        self.current_position = 0;
+
+        Ok(())
+    }
+
+    fn get_reader(&self) -> Arc<dyn IndexReader> {
+        self.reader.clone()
+    }
+}
+
+/// Heap element used by the priority queue in the multi-way merge
+#[derive(Debug)]
+struct HeapElement {
+    value: ScalarValue,
+    row_id: ScalarValue,
+    partition_id: u64,
+}
+
+impl PartialEq for HeapElement {
+    fn eq(&self, other: &Self) -> bool {
+        self.value.eq(&other.value)
+    }
+}
+
+impl Eq for HeapElement {}
+
+impl PartialOrd for HeapElement {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        // Note: BinaryHeap is a max-heap and we need a min-heap,
+        // so the comparison is reversed
+        other.value.partial_cmp(&self.value)
+    }
+}
+
+impl Ord for HeapElement {
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.partial_cmp(other).unwrap_or(Ordering::Equal)
+    }
+}
+
+async fn merge_page(
+    part_lookup_files: &[String],
+    page_files_map: &HashMap<u64, String>,
+    store: &Arc<dyn IndexStore>,
+    batch_size: u64,
+    page_file: &mut Box<dyn IndexWriter>,
+    arrow_schema: Arc<Schema>,
+    prefetch_config: PrefetchConfig,
+) -> Result<Vec<(ScalarValue, ScalarValue, u32, u32)>> {
+    let mut lookup_entries = Vec::new();
+    let mut page_idx = 0u32;
+
+    debug!(
+        "Starting multi-way merge with {} partitions using prefetch manager",
+        part_lookup_files.len()
+    );
+
+    // Create prefetch manager
+    let mut prefetch_manager = PrefetchManager::new(prefetch_config.clone());
+
+    // Create the iterators directly and read each partition's first element
+    let mut partition_map = HashMap::new();
+    let mut heap = BinaryHeap::new();
+
+    debug!("Initializing {} partitions", part_lookup_files.len());
+
+    // Initialize all partitions
+    for lookup_file in part_lookup_files {
+        let partition_id = extract_partition_id(lookup_file)?;
+        let page_file_name = page_files_map
+            .get(&partition_id)
+            .ok_or_else(|| Error::Internal {
+                message: format!("Page file not found for partition ID: {}", partition_id),
+                location: location!(),
+            })?
+ .to_string(); + + let mut iterator = + PartitionIterator::new(store.clone(), page_file_name, partition_id, batch_size).await?; + + // Initialize partition in prefetch manager + let reader = iterator.get_reader(); + prefetch_manager.initialize_partition(partition_id, reader); + + // Submit initial prefetch task + if let Err(err) = prefetch_manager + .submit_prefetch_task(partition_id, batch_size as usize) + .await + { + warn!( + "Failed to submit prefetch task for partition {}: {}", + partition_id, err + ); + } + + let first_element = iterator.next(&prefetch_manager).await?; + + if let Some((value, row_id)) = first_element { + // Put the first element into the heap + heap.push(HeapElement { + value, + row_id, + partition_id, + }); + } + + partition_map.insert(partition_id, iterator); + } + + debug!( + "Initialized {} partitions, heap size: {}", + partition_map.len(), + heap.len() + ); + + let mut current_batch_rows = Vec::with_capacity(batch_size as usize); + let mut total_merged = 0usize; + + // Multi-way merge main loop + while let Some(min_element) = heap.pop() { + // Add current minimum element to batch + current_batch_rows.push((min_element.value, min_element.row_id)); + total_merged += 1; + + // Read next element from corresponding partition + if let Some(iterator) = partition_map.get_mut(&min_element.partition_id) { + if let Some((next_value, next_row_id)) = iterator.next(&prefetch_manager).await? { + heap.push(HeapElement { + value: next_value, + row_id: next_row_id, + partition_id: min_element.partition_id, + }); + } + } + + // Write when batch reaches specified size + if current_batch_rows.len() >= batch_size as usize { + write_batch_and_lookup_entry( + &mut current_batch_rows, + page_file, + &arrow_schema, + &mut lookup_entries, + &mut page_idx, + ) + .await?; + } + } + + // Write the remaining data + if !current_batch_rows.is_empty() { + write_batch_and_lookup_entry( + &mut current_batch_rows, + page_file, + &arrow_schema, + &mut lookup_entries, + &mut page_idx, + ) + .await?; + } + + debug!( + "Completed multi-way merge: merged {} rows into {} lookup entries", + total_merged, + lookup_entries.len() + ); + Ok(lookup_entries) +} + +/// Helper function to prepare batch data in parallel +async fn prepare_batch_data( + batch_rows: Vec<(ScalarValue, ScalarValue)>, + arrow_schema: Arc, + page_idx: u32, +) -> Result<(RecordBatch, (ScalarValue, ScalarValue, u32, u32))> { + if batch_rows.is_empty() { + return Err(Error::Internal { + message: "Cannot prepare empty batch".to_string(), + location: location!(), + }); + } + + // Parallelize data preparation + let (values, row_ids): (Vec<_>, Vec<_>) = batch_rows.into_iter().unzip(); + + // Convert to arrays in parallel using rayon or manually spawn tasks + let values_array = ScalarValue::iter_to_array(values.into_iter())?; + let row_ids_array = ScalarValue::iter_to_array(row_ids.into_iter())?; + + let batch = RecordBatch::try_new(arrow_schema, vec![values_array, row_ids_array])?; + + // Calculate min/max/null_count for lookup entry + let min_val = ScalarValue::try_from_array(batch.column(0), 0)?; + let max_val = ScalarValue::try_from_array(batch.column(0), batch.num_rows() - 1)?; + let null_count = batch.column(0).null_count() as u32; + + let lookup_entry = (min_val, max_val, null_count, page_idx); + + Ok((batch, lookup_entry)) +} + +/// Helper function to write a batch and create lookup entry +async fn write_batch_and_lookup_entry( + batch_rows: &mut Vec<(ScalarValue, ScalarValue)>, + page_file: &mut Box, + arrow_schema: &Arc, + 
lookup_entries: &mut Vec<(ScalarValue, ScalarValue, u32, u32)>, + page_idx: &mut u32, +) -> Result<()> { + if batch_rows.is_empty() { + return Ok(()); + } + + // Take ownership of the batch data + let batch_data = std::mem::take(batch_rows); + let current_page_idx = *page_idx; + + // Prepare batch data + let (batch, lookup_entry) = + prepare_batch_data(batch_data, arrow_schema.clone(), current_page_idx).await?; + + lookup_entries.push(lookup_entry); + page_file.write_record_batch(batch).await?; + *page_idx += 1; + + Ok(()) +} + +pub(crate) fn part_page_data_file_path(partition_id: u64) -> String { + format!("part_{}_{}", partition_id, BTREE_PAGES_NAME) +} + +pub(crate) fn part_lookup_file_path(partition_id: u64) -> String { + format!("part_{}_{}", partition_id, BTREE_LOOKUP_NAME) +} + +/// A stream that reads the original training data back out of the index +/// +/// This is used for updating the index +struct IndexReaderStream { + reader: Arc, + batch_size: u64, + num_batches: u32, + batch_idx: u32, +} + +impl IndexReaderStream { + async fn new(reader: Arc, batch_size: u64) -> Self { + let num_batches = reader.num_batches(batch_size).await; + Self { + reader, + batch_size, + num_batches, + batch_idx: 0, + } + } +} + +impl Stream for IndexReaderStream { + type Item = BoxFuture<'static, Result>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + if this.batch_idx >= this.num_batches { + return std::task::Poll::Ready(None); + } + let batch_num = this.batch_idx; + this.batch_idx += 1; + let reader_copy = this.reader.clone(); + let batch_size = this.batch_size; + let read_task = async move { + reader_copy + .read_record_batch(batch_num as u64, batch_size) + .await + } + .boxed(); + std::task::Poll::Ready(Some(read_task)) + } +} + +/// Parameters for a btree index +#[derive(Debug, Serialize, Deserialize)] +pub struct BTreeParameters { + /// The number of rows to include in each zone + pub zone_size: Option, +} + +struct BTreeTrainingRequest { + parameters: BTreeParameters, + criteria: TrainingCriteria, +} + +impl BTreeTrainingRequest { + pub fn new(parameters: BTreeParameters) -> Self { + Self { + parameters, + // BTree indexes need data sorted by the value column + criteria: TrainingCriteria::new(TrainingOrdering::Values).with_row_id(), + } + } +} + +impl TrainingRequest for BTreeTrainingRequest { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn criteria(&self) -> &TrainingCriteria { + &self.criteria + } +} + +#[derive(Debug, Default)] +pub struct BTreeIndexPlugin; + +#[async_trait] +impl ScalarIndexPlugin for BTreeIndexPlugin { + fn new_training_request( + &self, + params: &str, + field: &Field, + ) -> Result> { + if field.data_type().is_nested() { + return Err(Error::InvalidInput { + source: "A btree index can only be created on a non-nested field.".into(), + location: location!(), + }); + } + + let params = serde_json::from_str::(params)?; + Ok(Box::new(BTreeTrainingRequest::new(params))) + } + + fn provides_exact_answer(&self) -> bool { + true + } + + fn version(&self) -> u32 { + 0 + } + + fn new_query_parser( + &self, + index_name: String, + _index_details: &prost_types::Any, + ) -> Option> { + Some(Box::new(SargableQueryParser::new(index_name, false))) + } + + async fn train_index( + &self, + data: SendableRecordBatchStream, + index_store: &dyn IndexStore, + request: Box, + fragment_ids: Option>, + ) -> Result { + let request = request + .as_any() + .downcast_ref::() + .unwrap(); + 
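
// ---- Illustrative example (not part of this patch) ----
// A self-contained sketch of the round trip between the `part_{id}_...` naming
// helpers above and `extract_partition_id` (defined elsewhere in this file and
// exercised by `test_extract_partition_id` later in this patch). The parser
// below is a simplified stand-in, not the actual implementation: it reads the
// numeric segment that follows the "part_" prefix. BTREE_PAGES_NAME is assumed
// to be "page_data.lance", matching the suffix used by the cleanup helpers.
fn part_page_data_file_path(partition_id: u64) -> String {
    // Mirrors `format!("part_{}_{}", partition_id, BTREE_PAGES_NAME)`.
    format!("part_{}_page_data.lance", partition_id)
}

fn extract_partition_id(file_name: &str) -> Option<u64> {
    file_name.strip_prefix("part_")?.split('_').next()?.parse().ok()
}

fn main() {
    // The fragment-merge tests shift fragment ids into the upper 32 bits
    // (e.g. `part_page_data_file_path(1 << 32)`), so ids can exceed u32::MAX.
    let name = part_page_data_file_path(1u64 << 32);
    assert_eq!(name, "part_4294967296_page_data.lance");
    assert_eq!(extract_partition_id(&name), Some(1u64 << 32));
    assert_eq!(extract_partition_id("invalid_filename.lance"), None);
}
// ---- end example ----
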
let value_type = data + .schema() + .field_with_name(VALUE_COLUMN_NAME)? + .data_type() + .clone(); + let flat_index_trainer = FlatIndexMetadata::new(value_type); + train_btree_index( + data, + &flat_index_trainer, + index_store, + request + .parameters + .zone_size + .unwrap_or(DEFAULT_BTREE_BATCH_SIZE), + fragment_ids, + ) + .await?; + Ok(CreatedIndex { + index_details: prost_types::Any::from_msg(&pb::BTreeIndexDetails::default()).unwrap(), + index_version: BTREE_INDEX_VERSION, + }) + } + + async fn load_index( + &self, + index_store: Arc, + _index_details: &prost_types::Any, + frag_reuse_index: Option>, + cache: LanceCache, + ) -> Result> { + Ok(BTreeIndex::load(index_store, frag_reuse_index, cache).await? as Arc) + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::Ordering; + use std::{collections::HashMap, sync::Arc}; + + use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type}; + use arrow_array::FixedSizeListArray; + use arrow_schema::DataType; + use datafusion::{ + execution::{SendableRecordBatchStream, TaskContext}, + physical_plan::{sorts::sort::SortExec, stream::RecordBatchStreamAdapter, ExecutionPlan}, + }; + use datafusion_common::{DataFusionError, ScalarValue}; + use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; + use deepsize::DeepSizeOf; + use futures::TryStreamExt; + use lance_core::{cache::LanceCache, utils::mask::RowIdTreeMap}; + use lance_datafusion::{chunker::break_stream, datagen::DatafusionDatagenExt}; + use lance_datagen::{array, gen_batch, ArrayGeneratorExt, BatchCount, RowCount}; + use lance_io::object_store::ObjectStore; + use object_store::path::Path; + use tempfile::tempdir; + + use crate::metrics::LocalMetricsCollector; + use crate::{ + metrics::NoOpMetricsCollector, + scalar::{ + btree::{BTreeIndex, BTREE_PAGES_NAME}, + flat::FlatIndexMetadata, + lance_format::LanceIndexStore, + IndexStore, SargableQuery, ScalarIndex, SearchResult, + }, + }; + + use super::{ + part_lookup_file_path, part_page_data_file_path, train_btree_index, OrderableScalarValue, + DEFAULT_BTREE_BATCH_SIZE, + }; + + #[test] + fn test_scalar_value_size() { + let size_of_i32 = OrderableScalarValue(ScalarValue::Int32(Some(0))).deep_size_of(); + let size_of_many_i32 = OrderableScalarValue(ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::from_iter_primitive::( + vec![Some(vec![Some(0); 128])], + 128, + ), + ))) + .deep_size_of(); + + // deep_size_of should account for the rust type overhead + assert!(size_of_i32 > 4); + assert!(size_of_many_i32 > 128 * 4); + } + + #[tokio::test] + async fn test_null_ids() { + let tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + // Generate 50,000 rows of random data with 80% nulls + let stream = gen_batch() + .col( + "value", + array::rand::().with_nulls(&[true, false, false, false, false]), + ) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(5000), BatchCount::from(10)); + let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + + train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 5000, None) + .await + .unwrap(); + + let index = BTreeIndex::load(test_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + assert_eq!(index.page_lookup.null_pages.len(), 10); + + let remap_dir = Arc::new(tempdir().unwrap()); + let remap_store = Arc::new(LanceIndexStore::new( + 
Arc::new(ObjectStore::local()), + Path::from_filesystem_path(remap_dir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + // Remap with a no-op mapping. The remapped index should be identical to the original + index + .remap(&HashMap::default(), remap_store.as_ref()) + .await + .unwrap(); + + let remap_index = BTreeIndex::load(remap_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + assert_eq!(remap_index.page_lookup, index.page_lookup); + + let original_pages = test_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); + let remapped_pages = remap_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); + + assert_eq!(original_pages.num_rows(), remapped_pages.num_rows()); + + let original_data = original_pages + .read_record_batch(0, original_pages.num_rows() as u64) + .await + .unwrap(); + let remapped_data = remapped_pages + .read_record_batch(0, remapped_pages.num_rows() as u64) + .await + .unwrap(); + + assert_eq!(original_data, remapped_data); + } + + #[tokio::test] + async fn test_nan_ordering() { + let tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let values = vec![ + 0.0, + 1.0, + 2.0, + 3.0, + f64::NAN, + f64::NEG_INFINITY, + f64::INFINITY, + ]; + + // This is a bit overkill but we've had bugs in the past where DF's sort + // didn't agree with Arrow's sort so we do an end-to-end test here + // and use DF to sort the data like we would in a real dataset. + let data = gen_batch() + .col("value", array::cycle::(values.clone())) + .col("_rowid", array::step::()) + .into_df_exec(RowCount::from(10), BatchCount::from(100)); + let schema = data.schema(); + let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); + let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); + let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); + let stream = break_stream(stream, 64); + let stream = stream.map_err(DataFusionError::from); + let stream = + Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; + + let sub_index_trainer = FlatIndexMetadata::new(DataType::Float64); + + train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64, None) + .await + .unwrap(); + + let index = BTreeIndex::load(test_store, None, LanceCache::no_cache()) + .await + .unwrap(); + + for (idx, value) in values.into_iter().enumerate() { + let query = SargableQuery::Equals(ScalarValue::Float64(Some(value))); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + assert_eq!( + result, + SearchResult::Exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7))) + ); + } + } + + #[tokio::test] + async fn test_page_cache() { + let tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let data = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_exec(RowCount::from(1000), BatchCount::from(10)); + let schema = data.schema(); + let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); + let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); + let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); + let stream = 
break_stream(stream, 64); + let stream = stream.map_err(DataFusionError::from); + let stream = + Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; + let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + + train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64, None) + .await + .unwrap(); + + let index = BTreeIndex::load( + test_store, + None, + LanceCache::with_capacity(100 * 1024 * 1024), + ) + .await + .unwrap(); + + let query = SargableQuery::Equals(ScalarValue::Float32(Some(0.0))); + let metrics = LocalMetricsCollector::default(); + let query1 = index.search(&query, &metrics); + let query2 = index.search(&query, &metrics); + tokio::join!(query1, query2).0.unwrap(); + assert_eq!(metrics.parts_loaded.load(Ordering::Relaxed), 1); + } + + /// Test that fragment-based btree index construction produces exactly the same results as building a complete index + #[tokio::test] + async fn test_fragment_btree_index_consistency() { + // Setup stores for both indexes + let full_tmpdir = Arc::new(tempdir().unwrap()); + let full_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(full_tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let fragment_tmpdir = Arc::new(tempdir().unwrap()); + let fragment_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(fragment_tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let sub_index_trainer = FlatIndexMetadata::new(DataType::Int32); + + // Method 1: Build complete index directly using the same data + // Create deterministic data for comparison - use 2 * DEFAULT_BTREE_BATCH_SIZE for testing + let total_count = (2 * DEFAULT_BTREE_BATCH_SIZE) as u64; + let full_data_gen = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(total_count / 2), BatchCount::from(2)); + let full_data_source = Box::pin(RecordBatchStreamAdapter::new( + full_data_gen.schema(), + full_data_gen, + )); + + train_btree_index( + full_data_source, + &sub_index_trainer, + full_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + None, + ) + .await + .unwrap(); + + // Method 2: Build fragment-based index using the same data split into fragments + // Create fragment 1 index - first half of the data (0 to DEFAULT_BTREE_BATCH_SIZE-1) + let half_count = DEFAULT_BTREE_BATCH_SIZE; + let fragment1_gen = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(half_count), BatchCount::from(1)); + let fragment1_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment1_gen.schema(), + fragment1_gen, + )); + + train_btree_index( + fragment1_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![1]), // fragment_id = 1 + ) + .await + .unwrap(); + + // Create fragment 2 index - second half of the data (DEFAULT_BTREE_BATCH_SIZE to 2*DEFAULT_BTREE_BATCH_SIZE-1) + let start_val = DEFAULT_BTREE_BATCH_SIZE as i32; + let end_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let values_second_half: Vec = (start_val..end_val).collect(); + let row_ids_second_half: Vec = (start_val as u64..end_val as u64).collect(); + let fragment2_gen = gen_batch() + .col("value", array::cycle::(values_second_half)) + .col("_rowid", array::cycle::(row_ids_second_half)) + .into_df_stream(RowCount::from(half_count), BatchCount::from(1)); + let fragment2_data_source = 
Box::pin(RecordBatchStreamAdapter::new( + fragment2_gen.schema(), + fragment2_gen, + )); + + train_btree_index( + fragment2_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![2]), // fragment_id = 2 + ) + .await + .unwrap(); + + // Merge the fragment files + let part_page_files = vec![ + part_page_data_file_path(1 << 32), + part_page_data_file_path(2 << 32), + ]; + + let part_lookup_files = vec![ + part_lookup_file_path(1 << 32), + part_lookup_file_path(2 << 32), + ]; + + super::merge_metadata_files( + fragment_store.clone(), + &part_page_files, + &part_lookup_files, + Option::from(1usize), + ) + .await + .unwrap(); + + // Load both indexes + let full_index = BTreeIndex::load(full_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + let merged_index = BTreeIndex::load(fragment_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + // Test queries one by one to identify the exact problem + + // Test 1: Query for value 0 (should be in first page) + let query_0 = SargableQuery::Equals(ScalarValue::Int32(Some(0))); + let full_result_0 = full_index + .search(&query_0, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_0 = merged_index + .search(&query_0, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!(full_result_0, merged_result_0, "Query for value 0 failed"); + + // Test 2: Query for value in middle of first batch (should be in first page) + let mid_first_batch = (DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_mid_first = SargableQuery::Equals(ScalarValue::Int32(Some(mid_first_batch))); + let full_result_mid_first = full_index + .search(&query_mid_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_mid_first = merged_index + .search(&query_mid_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_mid_first, merged_result_mid_first, + "Query for value {} failed", + mid_first_batch + ); + + // Test 3: Query for first value in second batch (should be in second page) + let first_second_batch = DEFAULT_BTREE_BATCH_SIZE as i32; + let query_first_second = + SargableQuery::Equals(ScalarValue::Int32(Some(first_second_batch))); + let full_result_first_second = full_index + .search(&query_first_second, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_first_second = merged_index + .search(&query_first_second, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_first_second, merged_result_first_second, + "Query for value {} failed", + first_second_batch + ); + + // Test 4: Query for value in middle of second batch (should be in second page) + let mid_second_batch = (DEFAULT_BTREE_BATCH_SIZE + DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_mid_second = SargableQuery::Equals(ScalarValue::Int32(Some(mid_second_batch))); + + let full_result_mid_second = full_index + .search(&query_mid_second, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_mid_second = merged_index + .search(&query_mid_second, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_mid_second, merged_result_mid_second, + "Query for value {} failed", + mid_second_batch + ); } #[tokio::test] - async fn test_null_ids() { - let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( + async fn test_fragment_btree_index_boundary_queries() { + // Setup stores for both indexes + let full_tmpdir = Arc::new(tempdir().unwrap()); + let full_store = Arc::new(LanceIndexStore::new( 
Arc::new(ObjectStore::local()), - Path::from_filesystem_path(tmpdir.path()).unwrap(), + Path::from_filesystem_path(full_tmpdir.path()).unwrap(), Arc::new(LanceCache::no_cache()), )); - // Generate 50,000 rows of random data with 80% nulls - let stream = gen_batch() - .col( - "value", - array::rand::().with_nulls(&[true, false, false, false, false]), - ) + let fragment_tmpdir = Arc::new(tempdir().unwrap()); + let fragment_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(fragment_tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let sub_index_trainer = FlatIndexMetadata::new(DataType::Int32); + + // Use 3 * DEFAULT_BTREE_BATCH_SIZE for more comprehensive boundary testing + let total_count = (3 * DEFAULT_BTREE_BATCH_SIZE) as u64; + + // Method 1: Build complete index directly + let full_data_gen = gen_batch() + .col("value", array::step::()) .col("_rowid", array::step::()) - .into_df_stream(RowCount::from(5000), BatchCount::from(10)); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + .into_df_stream(RowCount::from(total_count / 3), BatchCount::from(3)); + let full_data_source = Box::pin(RecordBatchStreamAdapter::new( + full_data_gen.schema(), + full_data_gen, + )); + + train_btree_index( + full_data_source, + &sub_index_trainer, + full_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + None, + ) + .await + .unwrap(); + + // Method 2: Build fragment-based index using 3 fragments + // Fragment 1: 0 to DEFAULT_BTREE_BATCH_SIZE-1 + let fragment_size = DEFAULT_BTREE_BATCH_SIZE; + let fragment1_gen = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment1_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment1_gen.schema(), + fragment1_gen, + )); + + train_btree_index( + fragment1_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![1]), + ) + .await + .unwrap(); + + // Fragment 2: DEFAULT_BTREE_BATCH_SIZE to 2*DEFAULT_BTREE_BATCH_SIZE-1 + let start_val2 = DEFAULT_BTREE_BATCH_SIZE as i32; + let end_val2 = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let values_fragment2: Vec = (start_val2..end_val2).collect(); + let row_ids_fragment2: Vec = (start_val2 as u64..end_val2 as u64).collect(); + let fragment2_gen = gen_batch() + .col("value", array::cycle::(values_fragment2)) + .col("_rowid", array::cycle::(row_ids_fragment2)) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment2_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment2_gen.schema(), + fragment2_gen, + )); + + train_btree_index( + fragment2_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![2]), + ) + .await + .unwrap(); - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 5000) + // Fragment 3: 2*DEFAULT_BTREE_BATCH_SIZE to 3*DEFAULT_BTREE_BATCH_SIZE-1 + let start_val3 = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let end_val3 = (3 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let values_fragment3: Vec = (start_val3..end_val3).collect(); + let row_ids_fragment3: Vec = (start_val3 as u64..end_val3 as u64).collect(); + let fragment3_gen = gen_batch() + .col("value", array::cycle::(values_fragment3)) + .col("_rowid", array::cycle::(row_ids_fragment3)) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment3_data_source = Box::pin(RecordBatchStreamAdapter::new( + 
fragment3_gen.schema(), + fragment3_gen, + )); + + train_btree_index( + fragment3_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![3]), + ) + .await + .unwrap(); + + // Merge all fragment files + let part_page_files = vec![ + part_page_data_file_path(1 << 32), + part_page_data_file_path(2 << 32), + part_page_data_file_path(3 << 32), + ]; + + let part_lookup_files = vec![ + part_lookup_file_path(1 << 32), + part_lookup_file_path(2 << 32), + part_lookup_file_path(3 << 32), + ]; + + super::merge_metadata_files( + fragment_store.clone(), + &part_page_files, + &part_lookup_files, + Option::from(1usize), + ) + .await + .unwrap(); + + // Load both indexes + let full_index = BTreeIndex::load(full_store.clone(), None, LanceCache::no_cache()) .await .unwrap(); - let index = BTreeIndex::load(test_store.clone(), None, LanceCache::no_cache()) + let merged_index = BTreeIndex::load(fragment_store.clone(), None, LanceCache::no_cache()) .await .unwrap(); - assert_eq!(index.page_lookup.null_pages.len(), 10); + // === Boundary Value Tests === - let remap_dir = Arc::new(tempdir().unwrap()); - let remap_store = Arc::new(LanceIndexStore::new( - Arc::new(ObjectStore::local()), - Path::from_filesystem_path(remap_dir.path()).unwrap(), - Arc::new(LanceCache::no_cache()), - )); + // Test 1: Query minimum value (boundary: data start) + let query_min = SargableQuery::Equals(ScalarValue::Int32(Some(0))); + let full_result_min = full_index + .search(&query_min, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_min = merged_index + .search(&query_min, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_min, merged_result_min, + "Query for minimum value 0 failed" + ); - // Remap with a no-op mapping. 
The remapped index should be identical to the original - index - .remap(&HashMap::default(), remap_store.as_ref()) + // Test 2: Query maximum value (boundary: data end) + let max_val = (3 * DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val))); + let full_result_max = full_index + .search(&query_max, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_max = merged_index + .search(&query_max, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_max, merged_result_max, + "Query for maximum value {} failed", + max_val + ); - let remap_index = BTreeIndex::load(remap_store.clone(), None, LanceCache::no_cache()) + // Test 3: Query fragment boundary value (last value of first fragment) + let fragment1_last = (DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_frag1_last = SargableQuery::Equals(ScalarValue::Int32(Some(fragment1_last))); + let full_result_frag1_last = full_index + .search(&query_frag1_last, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_frag1_last = merged_index + .search(&query_frag1_last, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag1_last, merged_result_frag1_last, + "Query for fragment 1 last value {} failed", + fragment1_last + ); - assert_eq!(remap_index.page_lookup, index.page_lookup); + // Test 4: Query fragment boundary value (first value of second fragment) + let fragment2_first = DEFAULT_BTREE_BATCH_SIZE as i32; + let query_frag2_first = SargableQuery::Equals(ScalarValue::Int32(Some(fragment2_first))); + let full_result_frag2_first = full_index + .search(&query_frag2_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag2_first = merged_index + .search(&query_frag2_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag2_first, merged_result_frag2_first, + "Query for fragment 2 first value {} failed", + fragment2_first + ); - let original_pages = test_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); - let remapped_pages = remap_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); + // Test 5: Query fragment boundary value (last value of second fragment) + let fragment2_last = (2 * DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_frag2_last = SargableQuery::Equals(ScalarValue::Int32(Some(fragment2_last))); + let full_result_frag2_last = full_index + .search(&query_frag2_last, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag2_last = merged_index + .search(&query_frag2_last, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag2_last, merged_result_frag2_last, + "Query for fragment 2 last value {} failed", + fragment2_last + ); - assert_eq!(original_pages.num_rows(), remapped_pages.num_rows()); + // Test 6: Query fragment boundary value (first value of third fragment) + let fragment3_first = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_frag3_first = SargableQuery::Equals(ScalarValue::Int32(Some(fragment3_first))); + let full_result_frag3_first = full_index + .search(&query_frag3_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag3_first = merged_index + .search(&query_frag3_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag3_first, merged_result_frag3_first, + "Query for fragment 3 first value {} failed", + fragment3_first + ); - let original_data = original_pages - .read_record_batch(0, original_pages.num_rows() as u64) + // === Non-existent Value Tests === + + // Test 7: 
Query value below minimum + let query_below_min = SargableQuery::Equals(ScalarValue::Int32(Some(-1))); + let full_result_below = full_index + .search(&query_below_min, &NoOpMetricsCollector) .await .unwrap(); - let remapped_data = remapped_pages - .read_record_batch(0, remapped_pages.num_rows() as u64) + let merged_result_below = merged_index + .search(&query_below_min, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_below, merged_result_below, + "Query for value below minimum (-1) failed" + ); + + // Test 8: Query value above maximum + let query_above_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val + 1))); + let full_result_above = full_index + .search(&query_above_max, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_above = merged_index + .search(&query_above_max, &NoOpMetricsCollector) .await .unwrap(); + assert_eq!( + full_result_above, + merged_result_above, + "Query for value above maximum ({}) failed", + max_val + 1 + ); - assert_eq!(original_data, remapped_data); - } + // === Range Query Tests === - #[tokio::test] - async fn test_nan_ordering() { - let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( - Arc::new(ObjectStore::local()), - Path::from_filesystem_path(tmpdir.path()).unwrap(), - Arc::new(LanceCache::no_cache()), - )); + // Test 9: Cross-fragment range query (from first fragment to second fragment) + let range_start = (DEFAULT_BTREE_BATCH_SIZE - 100) as i32; + let range_end = (DEFAULT_BTREE_BATCH_SIZE + 100) as i32; + let query_cross_frag = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(range_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(range_end))), + ); + let full_result_cross = full_index + .search(&query_cross_frag, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_cross = merged_index + .search(&query_cross_frag, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_cross, merged_result_cross, + "Cross-fragment range query [{}, {}] failed", + range_start, range_end + ); - let values = vec![ - 0.0, - 1.0, - 2.0, - 3.0, - f64::NAN, - f64::NEG_INFINITY, - f64::INFINITY, - ]; + // Test 10: Range query within single fragment + let single_frag_start = 100i32; + let single_frag_end = 200i32; + let query_single_frag = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(single_frag_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(single_frag_end))), + ); + let full_result_single = full_index + .search(&query_single_frag, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_single = merged_index + .search(&query_single_frag, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_single, merged_result_single, + "Single fragment range query [{}, {}] failed", + single_frag_start, single_frag_end + ); - // This is a bit overkill but we've had bugs in the past where DF's sort - // didn't agree with Arrow's sort so we do an end-to-end test here - // and use DF to sort the data like we would in a real dataset. 
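
// ---- Illustrative example (not part of this patch) ----
// A minimal sketch of how the tests below express comparison predicates as
// `SargableQuery::Range` bound pairs. The `in_bounds` helper is hypothetical;
// it only mirrors the standard `Bound` semantics (Included/Excluded/Unbounded)
// on plain i32 values, which is what "less than", ">=", etc. map onto in the
// range-query tests.
use std::ops::Bound;

fn in_bounds(v: i32, lower: Bound<i32>, upper: Bound<i32>) -> bool {
    let lo_ok = match lower {
        Bound::Included(l) => v >= l,
        Bound::Excluded(l) => v > l,
        Bound::Unbounded => true,
    };
    let hi_ok = match upper {
        Bound::Included(u) => v <= u,
        Bound::Excluded(u) => v < u,
        Bound::Unbounded => true,
    };
    lo_ok && hi_ok
}

fn main() {
    // "v < 100" is written as [0, 100): Included(0), Excluded(100).
    assert!(in_bounds(99, Bound::Included(0), Bound::Excluded(100)));
    assert!(!in_bounds(100, Bound::Included(0), Bound::Excluded(100)));
    // "v >= 100" keeps the lower bound inclusive.
    assert!(in_bounds(100, Bound::Included(100), Bound::Unbounded));
}
// ---- end example ----
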
- let data = gen_batch() - .col("value", array::cycle::(values.clone())) - .col("_rowid", array::step::()) - .into_df_exec(RowCount::from(10), BatchCount::from(100)); - let schema = data.schema(); - let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); - let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); - let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); - let stream = break_stream(stream, 64); - let stream = stream.map_err(DataFusionError::from); - let stream = - Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; + // Test 11: Large range query spanning all fragments + let large_range_start = 100i32; + let large_range_end = (3 * DEFAULT_BTREE_BATCH_SIZE - 100) as i32; + let query_large_range = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(large_range_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(large_range_end))), + ); + let full_result_large = full_index + .search(&query_large_range, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_large = merged_index + .search(&query_large_range, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_large, merged_result_large, + "Large range query [{}, {}] failed", + large_range_start, large_range_end + ); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float64); + // === Range Boundary Query Tests === - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64) + // Test 12: Less than query (implemented using range query, from minimum to specified value) + let lt_val = (DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_lt = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(0))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(lt_val))), + ); + let full_result_lt = full_index + .search(&query_lt, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_lt = merged_index + .search(&query_lt, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_lt, merged_result_lt, + "Less than query (<{}) failed", + lt_val + ); - let index = BTreeIndex::load(test_store, None, LanceCache::no_cache()) + // Test 13: Greater than query (implemented using range query, from specified value to maximum) + let gt_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let max_range_val = (3 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_gt = SargableQuery::Range( + std::collections::Bound::Excluded(ScalarValue::Int32(Some(gt_val))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(max_range_val))), + ); + let full_result_gt = full_index + .search(&query_gt, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_gt = merged_index + .search(&query_gt, &NoOpMetricsCollector) .await .unwrap(); + assert_eq!( + full_result_gt, merged_result_gt, + "Greater than query (>{}) failed", + gt_val + ); - for (idx, value) in values.into_iter().enumerate() { - let query = SargableQuery::Equals(ScalarValue::Float64(Some(value))); - let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::Exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7))) - ); - } + // Test 14: Less than or equal query (implemented using range query, including boundary value) + let lte_val = (DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_lte = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(0))), + 
std::collections::Bound::Included(ScalarValue::Int32(Some(lte_val))), + ); + let full_result_lte = full_index + .search(&query_lte, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_lte = merged_index + .search(&query_lte, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_lte, merged_result_lte, + "Less than or equal query (<={}) failed", + lte_val + ); + + // Test 15: Greater than or equal query (implemented using range query, including boundary value) + let gte_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_gte = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(gte_val))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(max_range_val))), + ); + let full_result_gte = full_index + .search(&query_gte, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_gte = merged_index + .search(&query_gte, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_gte, merged_result_gte, + "Greater than or equal query (>={}) failed", + gte_val + ); + } + + #[test] + fn test_extract_partition_id() { + // Test valid partition file names + assert_eq!( + super::extract_partition_id("part_123_page_data.lance").unwrap(), + 123 + ); + assert_eq!( + super::extract_partition_id("part_456_page_lookup.lance").unwrap(), + 456 + ); + assert_eq!( + super::extract_partition_id("part_4294967296_page_data.lance").unwrap(), + 4294967296 + ); + + // Test invalid file names + assert!(super::extract_partition_id("invalid_filename.lance").is_err()); + assert!(super::extract_partition_id("part_abc_page_data.lance").is_err()); + assert!(super::extract_partition_id("part_123").is_err()); + assert!(super::extract_partition_id("part_").is_err()); } #[tokio::test] - async fn test_page_cache() { + async fn test_cleanup_partition_files() { + use crate::scalar::lance_format::LanceIndexStore; + use lance_core::cache::LanceCache; + use lance_io::object_store::ObjectStore; + use object_store::path::Path; + use std::sync::Arc; + use tempfile::tempdir; + + // Create a test store let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( + let test_store: Arc = Arc::new(LanceIndexStore::new( Arc::new(ObjectStore::local()), Path::from_filesystem_path(tmpdir.path()).unwrap(), Arc::new(LanceCache::no_cache()), )); - let data = gen_batch() - .col("value", array::step::()) - .col("_rowid", array::step::()) - .into_df_exec(RowCount::from(1000), BatchCount::from(10)); - let schema = data.schema(); - let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); - let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); - let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); - let stream = break_stream(stream, 64); - let stream = stream.map_err(DataFusionError::from); - let stream = - Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + // Test files with different patterns + let lookup_files = vec![ + "part_123_page_lookup.lance".to_string(), + "invalid_lookup_file.lance".to_string(), + "part_456_page_lookup.lance".to_string(), + ]; - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64) - .await - .unwrap(); + let page_files = vec![ + "part_123_page_data.lance".to_string(), + "invalid_page_file.lance".to_string(), + "part_456_page_data.lance".to_string(), + ]; - let index = BTreeIndex::load( - test_store, - None, - 
LanceCache::with_capacity(100 * 1024 * 1024), - ) - .await - .unwrap(); + // The cleanup function should handle both valid and invalid file patterns gracefully + // This test mainly verifies that the function doesn't panic and handles edge cases + super::cleanup_partition_files(&test_store, &lookup_files, &page_files).await; - let query = SargableQuery::Equals(ScalarValue::Float32(Some(0.0))); - let metrics = LocalMetricsCollector::default(); - let query1 = index.search(&query, &metrics); - let query2 = index.search(&query, &metrics); - tokio::join!(query1, query2).0.unwrap(); - assert_eq!(metrics.parts_loaded.load(Ordering::Relaxed), 1); + // If we get here without panicking, the cleanup function handled all cases correctly + assert!(true); } } diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs index 4edb1cb6a0a..a4506020782 100644 --- a/rust/lance-index/src/scalar/inverted.rs +++ b/rust/lance-index/src/scalar/inverted.rs @@ -163,6 +163,7 @@ impl ScalarIndexPlugin for InvertedIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result { let request = (request as Box) .downcast::() @@ -170,7 +171,8 @@ impl ScalarIndexPlugin for InvertedIndexPlugin { source: "must provide training request created by new_training_request".into(), location: location!(), })?; - Self::train_inverted_index(data, index_store, request.parameters.clone(), None).await + Self::train_inverted_index(data, index_store, request.parameters.clone(), fragment_ids) + .await } /// Load an index from storage diff --git a/rust/lance-index/src/scalar/json.rs b/rust/lance-index/src/scalar/json.rs index 0b8a43efbe7..e36feaacfc7 100644 --- a/rust/lance-index/src/scalar/json.rs +++ b/rust/lance-index/src/scalar/json.rs @@ -768,6 +768,7 @@ impl ScalarIndexPlugin for JsonIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result { let request = (request as Box) .downcast::() @@ -797,7 +798,7 @@ impl ScalarIndexPlugin for JsonIndexPlugin { )?; let target_index = target_plugin - .train_index(converted_stream, index_store, target_request) + .train_index(converted_stream, index_store, target_request, fragment_ids) .await?; let index_details = crate::pb::JsonIndexDetails { diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs index 542aa2bc97a..64e932c47c5 100644 --- a/rust/lance-index/src/scalar/label_list.rs +++ b/rust/lance-index/src/scalar/label_list.rs @@ -398,6 +398,7 @@ impl ScalarIndexPlugin for LabelListIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result { let schema = data.schema(); let field = schema @@ -427,7 +428,7 @@ impl ScalarIndexPlugin for LabelListIndexPlugin { let data = unnest_chunks(data)?; let bitmap_plugin = BitmapIndexPlugin; bitmap_plugin - .train_index(data, index_store, request) + .train_index(data, index_store, request, fragment_ids) .await?; Ok(CreatedIndex { index_details: prost_types::Any::from_msg(&pb::LabelListIndexDetails::default()) diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs index d8a95de1eeb..4df502ead09 100644 --- a/rust/lance-index/src/scalar/lance_format.rs +++ b/rust/lance-index/src/scalar/lance_format.rs @@ -368,7 +368,7 @@ pub mod tests { ) .unwrap(); btree_plugin - .train_index(data, index_store.as_ref(), request) + .train_index(data, 
index_store.as_ref(), request, None) .await .unwrap(); } @@ -866,6 +866,7 @@ pub mod tests { &sub_index_trainer, index_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, + None, ) .await .unwrap(); @@ -911,7 +912,7 @@ pub mod tests { .new_training_request("{}", &Field::new(VALUE_COLUMN_NAME, DataType::Int32, false)) .unwrap(); BitmapIndexPlugin - .train_index(data, index_store.as_ref(), request) + .train_index(data, index_store.as_ref(), request, None) .await .unwrap(); } @@ -1399,7 +1400,7 @@ pub mod tests { ) .unwrap(); LabelListIndexPlugin - .train_index(data, index_store.as_ref(), request) + .train_index(data, index_store.as_ref(), request, None) .await .unwrap(); } diff --git a/rust/lance-index/src/scalar/ngram.rs b/rust/lance-index/src/scalar/ngram.rs index ff559dd9292..586b0a4da9a 100644 --- a/rust/lance-index/src/scalar/ngram.rs +++ b/rust/lance-index/src/scalar/ngram.rs @@ -1285,6 +1285,7 @@ impl ScalarIndexPlugin for NGramIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, _request: Box, + _fragment_ids: Option>, ) -> Result { Self::train_ngram_index(data, index_store).await?; Ok(CreatedIndex { diff --git a/rust/lance-index/src/scalar/registry.rs b/rust/lance-index/src/scalar/registry.rs index 022da729f0c..3880aad4dbe 100644 --- a/rust/lance-index/src/scalar/registry.rs +++ b/rust/lance-index/src/scalar/registry.rs @@ -119,6 +119,7 @@ pub trait ScalarIndexPlugin: Send + Sync + std::fmt::Debug { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result; /// Returns true if the index returns an exact answer (e.g. not AtMost) diff --git a/rust/lance-index/src/scalar/zonemap.rs b/rust/lance-index/src/scalar/zonemap.rs index 748ab003863..c9097cbb8cc 100644 --- a/rust/lance-index/src/scalar/zonemap.rs +++ b/rust/lance-index/src/scalar/zonemap.rs @@ -961,6 +961,7 @@ impl ScalarIndexPlugin for ZoneMapIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + _fragment_ids: Option>, ) -> Result { let request = (request as Box) .downcast::() diff --git a/rust/lance/benches/scalar_index.rs b/rust/lance/benches/scalar_index.rs index 0742ff7f878..58b94f56318 100644 --- a/rust/lance/benches/scalar_index.rs +++ b/rust/lance/benches/scalar_index.rs @@ -71,6 +71,7 @@ impl BenchmarkFixture { &sub_index_trainer, index_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, + None, ) .await .unwrap(); diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index cdebc399547..ccae72a4865 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -284,12 +284,11 @@ pub(super) async fn build_scalar_index( training_request.criteria(), None, train, - fragment_ids, + fragment_ids.clone(), ) .await?; - plugin - .train_index(training_data, &index_store, training_request) + .train_index(training_data, &index_store, training_request, fragment_ids) .await } From d1afc86794dab7529fa43af2fdf968a089951bf0 Mon Sep 17 00:00:00 2001 From: xloya Date: Fri, 5 Sep 2025 12:48:21 +0800 Subject: [PATCH 04/13] support btree distributely --- java/.project | 17 + .../org.eclipse.core.resources.prefs | 2 + java/.settings/org.eclipse.m2e.core.prefs | 4 + java/core/.classpath | 50 + java/core/.project | 23 + .../org.eclipse.core.resources.prefs | 5 + .../.settings/org.eclipse.jdt.apt.core.prefs | 2 + .../core/.settings/org.eclipse.jdt.core.prefs | 9 + .../core/.settings/org.eclipse.m2e.core.prefs | 4 + python/python/lance/dataset.py | 36 +- python/python/lance/lance/__init__.pyi | 4 +- 
python/python/tests/test_scalar_index.py | 416 +++- python/src/dataset.rs | 126 +- rust/lance-index/src/scalar/bitmap.rs | 1 + rust/lance-index/src/scalar/btree.rs | 2185 ++++++++++++++--- rust/lance-index/src/scalar/inverted.rs | 4 +- rust/lance-index/src/scalar/json.rs | 3 +- rust/lance-index/src/scalar/label_list.rs | 3 +- rust/lance-index/src/scalar/lance_format.rs | 7 +- rust/lance-index/src/scalar/ngram.rs | 1 + rust/lance-index/src/scalar/registry.rs | 1 + rust/lance-index/src/scalar/zonemap.rs | 1 + rust/lance/benches/scalar_index.rs | 1 + rust/lance/src/index/scalar.rs | 5 +- 24 files changed, 2575 insertions(+), 335 deletions(-) create mode 100644 java/.project create mode 100644 java/.settings/org.eclipse.core.resources.prefs create mode 100644 java/.settings/org.eclipse.m2e.core.prefs create mode 100644 java/core/.classpath create mode 100644 java/core/.project create mode 100644 java/core/.settings/org.eclipse.core.resources.prefs create mode 100644 java/core/.settings/org.eclipse.jdt.apt.core.prefs create mode 100644 java/core/.settings/org.eclipse.jdt.core.prefs create mode 100644 java/core/.settings/org.eclipse.m2e.core.prefs diff --git a/java/.project b/java/.project new file mode 100644 index 00000000000..9e430d58fe9 --- /dev/null +++ b/java/.project @@ -0,0 +1,17 @@ + + + lance-parent + + + + + + org.eclipse.m2e.core.maven2Builder + + + + + + org.eclipse.m2e.core.maven2Nature + + diff --git a/java/.settings/org.eclipse.core.resources.prefs b/java/.settings/org.eclipse.core.resources.prefs new file mode 100644 index 00000000000..99f26c0203a --- /dev/null +++ b/java/.settings/org.eclipse.core.resources.prefs @@ -0,0 +1,2 @@ +eclipse.preferences.version=1 +encoding/=UTF-8 diff --git a/java/.settings/org.eclipse.m2e.core.prefs b/java/.settings/org.eclipse.m2e.core.prefs new file mode 100644 index 00000000000..f897a7f1cb2 --- /dev/null +++ b/java/.settings/org.eclipse.m2e.core.prefs @@ -0,0 +1,4 @@ +activeProfiles= +eclipse.preferences.version=1 +resolveWorkspaceProjects=true +version=1 diff --git a/java/core/.classpath b/java/core/.classpath new file mode 100644 index 00000000000..5c8072ecc61 --- /dev/null +++ b/java/core/.classpath @@ -0,0 +1,50 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/java/core/.project b/java/core/.project new file mode 100644 index 00000000000..4a9eedb6505 --- /dev/null +++ b/java/core/.project @@ -0,0 +1,23 @@ + + + lance-core + + + + + + org.eclipse.jdt.core.javabuilder + + + + + org.eclipse.m2e.core.maven2Builder + + + + + + org.eclipse.jdt.core.javanature + org.eclipse.m2e.core.maven2Nature + + diff --git a/java/core/.settings/org.eclipse.core.resources.prefs b/java/core/.settings/org.eclipse.core.resources.prefs new file mode 100644 index 00000000000..cdfe4f1b669 --- /dev/null +++ b/java/core/.settings/org.eclipse.core.resources.prefs @@ -0,0 +1,5 @@ +eclipse.preferences.version=1 +encoding//src/main/java=UTF-8 +encoding//src/test/java=UTF-8 +encoding//src/test/resources=UTF-8 +encoding/=UTF-8 diff --git a/java/core/.settings/org.eclipse.jdt.apt.core.prefs b/java/core/.settings/org.eclipse.jdt.apt.core.prefs new file mode 100644 index 00000000000..d4313d4b25e --- /dev/null +++ b/java/core/.settings/org.eclipse.jdt.apt.core.prefs @@ -0,0 +1,2 @@ +eclipse.preferences.version=1 +org.eclipse.jdt.apt.aptEnabled=false diff --git a/java/core/.settings/org.eclipse.jdt.core.prefs b/java/core/.settings/org.eclipse.jdt.core.prefs new file mode 100644 index 00000000000..1b6e1ef22f9 --- 
/dev/null +++ b/java/core/.settings/org.eclipse.jdt.core.prefs @@ -0,0 +1,9 @@ +eclipse.preferences.version=1 +org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8 +org.eclipse.jdt.core.compiler.compliance=1.8 +org.eclipse.jdt.core.compiler.problem.enablePreviewFeatures=disabled +org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning +org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=ignore +org.eclipse.jdt.core.compiler.processAnnotations=disabled +org.eclipse.jdt.core.compiler.release=disabled +org.eclipse.jdt.core.compiler.source=1.8 diff --git a/java/core/.settings/org.eclipse.m2e.core.prefs b/java/core/.settings/org.eclipse.m2e.core.prefs new file mode 100644 index 00000000000..f897a7f1cb2 --- /dev/null +++ b/java/core/.settings/org.eclipse.m2e.core.prefs @@ -0,0 +1,4 @@ +activeProfiles= +eclipse.preferences.version=1 +resolveWorkspaceProjects=true +version=1 diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 376476e43af..54bd432201a 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -2731,8 +2731,40 @@ def prewarm_index(self, name: str): """ return self._ds.prewarm_index(name) - def merge_index_metadata(self, index_uuid: str): - return self._ds.merge_index_metadata(index_uuid) + def merge_index_metadata( + self, + index_uuid: str, + index_type: Union[ + Literal["BTREE"], + Literal["INVERTED"], + ], + prefetch_batch: Optional[int] = None, + ): + """ + Merge an index which not commit at present. + + Parameters + ---------- + index_uuid: str + The uuid of the index which want to merge. + index_type: Literal["BTREE", "INVERTED"] + The type of the index. + prefetch_batch: int, optional + The number of prefetch batches of sub-page files for merging. + Default 1. + """ + index_type = index_type.upper() + if index_type not in [ + "BTREE", + "INVERTED", + ]: + raise NotImplementedError( + ( + 'Only "BTREE" or "INVERTED" are supported for ' + f"merge index metadata. Received {index_type}", + ) + ) + return self._ds.merge_index_metadata(index_uuid, index_type, prefetch_batch) def session(self) -> Session: """ diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi index c2a72b7b1b5..0bae8e2f1aa 100644 --- a/python/python/lance/lance/__init__.pyi +++ b/python/python/lance/lance/__init__.pyi @@ -282,7 +282,9 @@ class _Dataset: ): ... def drop_index(self, name: str): ... def prewarm_index(self, name: str): ... - def merge_index_metadata(self, index_uuid: str): ... + def merge_index_metadata( + self, index_uuid: str, index_type: str, prefetch_batch: Optional[int] = None + ): ... def count_fragments(self) -> int: ... def num_small_files(self, max_rows_per_group: int) -> int: ... def get_fragments(self) -> List[_Fragment]: ... 
diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index c2370a17a9e..5f92a1e4d11 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -1982,7 +1982,7 @@ def build_distributed_fts_index( ) # Merge the inverted index metadata - dataset.merge_index_metadata(index_id) + dataset.merge_index_metadata(index_id, index_type="INVERTED") # Create Index object for commit field_id = dataset.schema.get_field_index(column) @@ -2856,7 +2856,7 @@ def test_distribute_fts_index_build(tmp_path): print(f"Fragment {fragment_id} index created successfully") # Merge the inverted index metadata - ds.merge_index_metadata(index_id) + ds.merge_index_metadata(index_id, index_type="INVERTED") # Create an Index object using the new dataclass format from lance.dataset import Index @@ -2983,3 +2983,415 @@ def test_backward_compatibility_no_fragment_ids(tmp_path): results = ds.scanner(full_text_query=search_word).to_table() assert results.num_rows > 0 + + +def test_distribute_btree_index_build(tmp_path): + """ + Test distributed B-tree index build similar to test_distribute_fts_index_build. + This test creates B-tree indices on individual fragments and then + commits them as a single index. + """ + # Generate test dataset with multiple fragments + ds = generate_multi_fragment_dataset( + tmp_path, num_fragments=4, rows_per_fragment=10000 + ) + + import uuid + + index_id = str(uuid.uuid4()) + print(f"Using index ID: {index_id}") + index_name = "btree_multiple_fragment_idx" + + fragments = ds.get_fragments() + fragment_ids = [fragment.fragment_id for fragment in fragments] + print(f"Fragment IDs: {fragment_ids}") + + for fragment in ds.get_fragments(): + fragment_id = fragment.fragment_id + print(f"Creating B-tree index for fragment {fragment_id}") + + # Create B-tree scalar index for each fragment + # Use the same index_name for all fragments (like in FTS test) + ds.create_scalar_index( + column="id", # Use integer column for B-tree + index_type="BTREE", + name=index_name, + replace=False, + fragment_uuid=index_id, + fragment_ids=[fragment_id], + ) + + # For fragment-level indexing, we expect the method to return successfully + # but not commit the index yet + print(f"Fragment {fragment_id} B-tree index created successfully") + + # Merge the B-tree index metadata + ds.merge_index_metadata(index_id, index_type="BTREE") + print(ds.uri) + + # Create an Index object using the new dataclass format + from lance.dataset import Index + + # Get the schema field for the indexed column + field_id = ds.schema.get_field_index("id") + + index = Index( + uuid=index_id, + name=index_name, + fields=[field_id], # Use field index instead of field object + dataset_version=ds.version, + fragment_ids=set(fragment_ids), + index_version=0, + ) + + # Create the index operation + create_index_op = lance.LanceOperation.CreateIndex( + new_indices=[index], + removed_indices=[], + ) + + # Commit the index + ds_committed = lance.LanceDataset.commit( + ds.uri, + create_index_op, + read_version=ds.version, + ) + + print("Successfully committed multiple fragment B-tree index") + + # Verify the index was created and is functional + indices = ds_committed.list_indices() + assert len(indices) > 0, "No indices found after commit" + + # Find our index + our_index = None + for idx in indices: + if idx["name"] == index_name: + our_index = idx + break + + assert our_index is not None, f"Index '{index_name}' not found in indices list" + assert our_index["type"] == 
"BTree", ( + f"Expected BTree index, got {our_index['type']}" + ) + + # Test that the index works for searching + # Test exact equality queries + test_id = 100 # Should be in first fragment + results = ds_committed.scanner( + filter=f"id = {test_id}", + columns=["id", "text"], + ).to_table() + + print(f"Search for id = {test_id} returned {results.num_rows} results") + assert results.num_rows > 0, f"No results found for id = {test_id}" + + # Test range queries across fragments + results_range = ds_committed.scanner( + filter="id >= 200 AND id < 800", + columns=["id", "text"], + ).to_table() + + print(f"Range query returned {results_range.num_rows} results") + assert results_range.num_rows > 0, "No results found for range query" + + # Compare with complete index results to ensure consistency + # Create a reference dataset with complete index + reference_ds = generate_multi_fragment_dataset( + tmp_path / "reference", num_fragments=4, rows_per_fragment=10000 + ) + + # Create complete B-tree index for comparison + reference_ds.create_scalar_index( + column="id", + index_type="BTREE", + name="reference_btree_idx", + ) + + # Compare exact query results + reference_results = reference_ds.scanner( + filter=f"id = {test_id}", + columns=["id", "text"], + ).to_table() + + assert results.num_rows == reference_results.num_rows, ( + f"Distributed index returned {results.num_rows} results, " + f"but complete index returned {reference_results.num_rows} results" + ) + + # Compare range query results + reference_range_results = reference_ds.scanner( + filter="id >= 200 AND id < 800", + columns=["id", "text"], + ).to_table() + + assert results_range.num_rows == reference_range_results.num_rows, ( + f"Distributed index range query returned {results_range.num_rows} results, " + f"but complete index returned {reference_range_results.num_rows} results" + ) + + +def test_btree_precise_query_comparison(tmp_path): + """ + Precise comparison test between fragment-level B-tree index and complete + B-tree index. + This test creates identical datasets and compares query results in detail. 
+ """ + # Test configuration + num_fragments = 3 + rows_per_fragment = 10000 + total_rows = num_fragments * rows_per_fragment + + print( + f"Creating datasets with {num_fragments} fragments," + f" {rows_per_fragment} rows each" + ) + + # Create dataset for fragment-level indexing + fragment_ds = generate_multi_fragment_dataset( + tmp_path / "fragment", + num_fragments=num_fragments, + rows_per_fragment=rows_per_fragment, + ) + + # Create dataset for complete indexing (same data structure) + complete_ds = generate_multi_fragment_dataset( + tmp_path / "complete", + num_fragments=num_fragments, + rows_per_fragment=rows_per_fragment, + ) + + import uuid + + # Build fragment-level B-tree index + fragment_index_id = str(uuid.uuid4()) + fragment_index_name = "fragment_btree_precise_test" + + fragments = fragment_ds.get_fragments() + fragment_ids = [fragment.fragment_id for fragment in fragments] + print(f"Fragment IDs: {fragment_ids}") + + # Create fragment-level indices + for fragment in fragments: + fragment_id = fragment.fragment_id + print(f"Creating B-tree index for fragment {fragment_id}") + + fragment_ds.create_scalar_index( + column="id", + index_type="BTREE", + name=fragment_index_name, + replace=False, + fragment_uuid=fragment_index_id, + fragment_ids=[fragment_id], + ) + + # Merge fragment indices + fragment_ds.merge_index_metadata(fragment_index_id, index_type="BTREE") + + # Create Index object for fragment-based index + from lance.dataset import Index + + field_id = fragment_ds.schema.get_field_index("id") + + fragment_index = Index( + uuid=fragment_index_id, + name=fragment_index_name, + fields=[field_id], + dataset_version=fragment_ds.version, + fragment_ids=set(fragment_ids), + index_version=0, + ) + + # Commit fragment-based index + create_fragment_index_op = lance.LanceOperation.CreateIndex( + new_indices=[fragment_index], + removed_indices=[], + ) + + fragment_ds_committed = lance.LanceDataset.commit( + fragment_ds.uri, + create_fragment_index_op, + read_version=fragment_ds.version, + ) + + # Build complete B-tree index + complete_index_name = "complete_btree_precise_test" + complete_ds.create_scalar_index( + column="id", + index_type="BTREE", + name=complete_index_name, + ) + + print("Both indices created successfully") + + # Detailed query comparison tests + test_cases = [ + # Test 1: Boundary values at fragment edges + {"name": "First value", "filter": "id = 0"}, + {"name": "Fragment 0 last value", "filter": f"id = {rows_per_fragment - 1}"}, + {"name": "Fragment 1 first value", "filter": f"id = {rows_per_fragment}"}, + { + "name": "Fragment 1 last value", + "filter": f"id = {2 * rows_per_fragment - 1}", + }, + {"name": "Fragment 2 first value", "filter": f"id = {2 * rows_per_fragment}"}, + {"name": "Last value", "filter": f"id = {total_rows - 1}"}, + # Test 2: Values in the middle of fragments + {"name": "Fragment 0 middle", "filter": f"id = {rows_per_fragment // 2}"}, + { + "name": "Fragment 1 middle", + "filter": f"id = {rows_per_fragment + rows_per_fragment // 2}", + }, + { + "name": "Fragment 2 middle", + "filter": f"id = {2 * rows_per_fragment + rows_per_fragment // 2}", + }, + # Test 3: Range queries within single fragments + {"name": "Range within fragment 0", "filter": "id >= 10 AND id < 20"}, + { + "name": "Range within fragment 1", + "filter": f"id >= {rows_per_fragment + 10}" + f" AND id < {rows_per_fragment + 20}", + }, + { + "name": "Range within fragment 2", + "filter": f"id >= {2 * rows_per_fragment + 10}" + f" AND id < {2 * rows_per_fragment + 20}", + }, + # 
Test 4: Range queries spanning multiple fragments + { + "name": "Cross fragment 0-1", + "filter": f"id >= {rows_per_fragment - 5} AND id < {rows_per_fragment + 5}", + }, + { + "name": "Cross fragment 1-2", + "filter": f"id >= {2 * rows_per_fragment - 5}" + f" AND id < {2 * rows_per_fragment + 5}", + }, + { + "name": "Cross all fragments", + "filter": f"id >= {rows_per_fragment // 2} AND" + f" id < {2 * rows_per_fragment + rows_per_fragment // 2}", + }, + # Test 5: Edge cases + {"name": "Non-existent small value", "filter": "id = -1"}, + {"name": "Non-existent large value", "filter": f"id = {total_rows + 100}"}, + {"name": "Large range", "filter": f"id >= 0 AND id < {total_rows}"}, + # Test 6: Comparison operators + {"name": "Less than boundary", "filter": f"id < {rows_per_fragment}"}, + { + "name": "Greater than boundary", + "filter": f"id > {2 * rows_per_fragment - 1}", + }, + {"name": "Less than or equal", "filter": f"id <= {rows_per_fragment + 50}"}, + {"name": "Greater than or equal", "filter": f"id >= {rows_per_fragment + 50}"}, + ] + + print(f"\nRunning {len(test_cases)} detailed comparison tests:") + + for i, test_case in enumerate(test_cases, 1): + test_name = test_case["name"] + filter_expr = test_case["filter"] + + print(f" {i:2d}. Testing {test_name}: {filter_expr}") + + # Query fragment-based index + fragment_results = fragment_ds_committed.scanner( + filter=filter_expr, + columns=["id", "text"], + ).to_table() + + # Query complete index + complete_results = complete_ds.scanner( + filter=filter_expr, + columns=["id", "text"], + ).to_table() + + # Compare row counts + assert fragment_results.num_rows == complete_results.num_rows, ( + f"Test '{test_name}' failed: Fragment index " + f"returned {fragment_results.num_rows} rows, " + f"but complete index returned {complete_results.num_rows}" + f" rows for filter: {filter_expr}" + ) + + # Compare actual results if there are any + if fragment_results.num_rows > 0: + # Sort both results by id for comparison + fragment_ids = sorted(fragment_results.column("id").to_pylist()) + complete_ids = sorted(complete_results.column("id").to_pylist()) + + assert fragment_ids == complete_ids, ( + f"Test '{test_name}' failed: Fragment index" + f" returned different IDs than complete index. " + f"Fragment IDs:" + f" {fragment_ids[:10]}{'...' if len(fragment_ids) > 10 else ''}, " + f"Complete IDs:" + f" {complete_ids[:10]}{'...' if len(complete_ids) > 10 else ''}" + ) + + print(f" āœ“ Passed ({fragment_results.num_rows} rows)") + + print(f"\nāœ… All {len(test_cases)} precision tests passed!") + print( + "Fragment-level B-tree index produces identical results" + " to complete B-tree index." + ) + + +def test_btree_fragment_ids_parameter_validation(tmp_path): + """ + Test validation of fragment_ids parameter for B-tree indices. 
+ """ + ds = generate_multi_fragment_dataset( + tmp_path, num_fragments=2, rows_per_fragment=10000 + ) + + # Test with valid fragment IDs + fragments = ds.get_fragments() + valid_fragment_id = fragments[0].fragment_id + + # This should work without errors + ds.create_scalar_index( + column="id", + index_type="BTREE", + fragment_ids=[valid_fragment_id], + ) + + # Test with invalid fragment ID (should handle gracefully) + try: + ds.create_scalar_index( + column="id", + index_type="BTREE", + fragment_ids=[999999], # Non-existent fragment ID + ) + except Exception as e: + # It's acceptable for this to fail with an appropriate error + print(f"Expected error for invalid fragment ID: {e}") + + +def test_btree_backward_compatibility_no_fragment_ids(tmp_path): + """ + Test that B-tree indexing remains backward compatible + when fragment_ids is not provided. + """ + ds = generate_multi_fragment_dataset( + tmp_path, num_fragments=2, rows_per_fragment=10000 + ) + + # This should work exactly as before (full dataset indexing) + ds.create_scalar_index( + column="id", + index_type="BTREE", + name="full_dataset_btree_idx", + ) + + # Verify the index was created + indices = ds.list_indices() + assert len(indices) == 1 + assert indices[0]["name"] == "full_dataset_btree_idx" + assert indices[0]["type"] == "BTree" + + # Test that the index works + results = ds.scanner(filter="id = 50").to_table() + assert results.num_rows > 0 diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 83da7f26dc9..f3d0fd10e83 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1670,47 +1670,111 @@ impl Dataset { .infer_error() } - #[pyo3(signature = (index_uuid))] - fn merge_index_metadata(&self, index_uuid: &str) -> PyResult<()> { + #[pyo3(signature = (index_uuid, index_type, prefetch_batch))] + fn merge_index_metadata( + &self, + index_uuid: &str, + index_type: &str, + prefetch_batch: Option, + ) -> PyResult<()> { RT.block_on(None, async { + let index_type = index_type.to_uppercase(); + let idx_type = match index_type.as_str() { + "BTREE" => IndexType::BTree, + "INVERTED" => IndexType::Inverted, + _ => { + return Err(Error::InvalidInput { + source: format!( + "Index type {} is not supported.", + index_type + ).into(), + location: location!(), + }); + } + }; + let store = LanceIndexStore::from_dataset_for_new(self.ds.as_ref(), index_uuid)?; let index_dir = self.ds.indices_dir().child(index_uuid); + if idx_type == IndexType::Inverted { + // List all partition metadata files in the index directory + let mut part_metadata_files = Vec::new(); + let mut list_stream = self.ds.object_store().list(Some(index_dir.clone())); + + while let Some(item) = list_stream.next().await { + match item { + Ok(meta) => { + let file_name = meta.location.filename().unwrap_or_default(); + // Filter files matching the pattern part_*_metadata.lance + if file_name.starts_with("part_") && file_name.ends_with("_metadata.lance") + { + part_metadata_files.push(file_name.to_string()); + } + } + Err(_) => continue, + } + } + + if part_metadata_files.is_empty() { + return Err(Error::InvalidInput { + source: format!( + "No partition metadata files found in index directory: {}", + index_dir + ) + .into(), + location: location!(), + }); + } - // List all partition metadata files in the index directory - let mut part_metadata_files = Vec::new(); - let mut list_stream = self.ds.object_store().list(Some(index_dir.clone())); - - while let Some(item) = list_stream.next().await { - match item { - Ok(meta) => { - let file_name = 
meta.location.filename().unwrap_or_default();
-                            // Filter files matching the pattern part_*_metadata.lance
-                            if file_name.starts_with("part_") && file_name.ends_with("_metadata.lance")
-                            {
-                                part_metadata_files.push(file_name.to_string());
-                            }
-                        }
-                        Err(_) => continue,
-                    }
-                }
-
-                if part_metadata_files.is_empty() {
-                    return Err(Error::InvalidInput {
-                        source: format!(
-                            "No partition metadata files found in index directory: {}",
-                            index_dir
-                        )
-                        .into(),
-                        location: location!(),
-                    });
-                }
-
-                // Call merge_metadata_files function for inverted index
-                lance_index::scalar::inverted::builder::merge_metadata_files(
-                    Arc::new(store),
-                    &part_metadata_files,
-                )
-                .await
+                // Call merge_metadata_files function for inverted index
+                lance_index::scalar::inverted::builder::merge_metadata_files(
+                    Arc::new(store),
+                    &part_metadata_files,
+                )
+                .await
+            } else {
+                // List all partition page / lookup files in the index directory
+                let mut part_page_files = Vec::new();
+                let mut part_lookup_files = Vec::new();
+                let mut list_stream = self.ds.object_store().list(Some(index_dir.clone()));
+
+                while let Some(item) = list_stream.next().await {
+                    match item {
+                        Ok(meta) => {
+                            let file_name = meta.location.filename().unwrap_or_default();
+                            // Filter files matching the patterns part_*_page_data.lance
+                            // and part_*_page_lookup.lance
+                            if file_name.starts_with("part_") && file_name.ends_with("_page_data.lance")
+                            {
+                                part_page_files.push(file_name.to_string());
+                            }
+                            if file_name.starts_with("part_") && file_name.ends_with("_page_lookup.lance")
+                            {
+                                part_lookup_files.push(file_name.to_string());
+                            }
+                        }
+                        Err(_) => continue,
+                    }
+                }
+
+                if part_page_files.is_empty() || part_lookup_files.is_empty() {
+                    return Err(Error::InvalidInput {
+                        source: format!(
+                            "No partition page/lookup files found in index directory: {} (page_files: {}, lookup_files: {})",
+                            index_dir, part_page_files.len(), part_lookup_files.len()
+                        )
+                        .into(),
+                        location: location!(),
+                    });
+                }
+
+                // Call merge_metadata_files function for btree index
+                lance_index::scalar::btree::merge_metadata_files(
+                    Arc::new(store),
+                    &part_page_files,
+                    &part_lookup_files,
+                    prefetch_batch,
+                )
+                .await
+            }
+        })?
.map_err(|err| PyValueError::new_err(err.to_string()))
     }
diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs
index 1b5d0d530bd..09dde5297b4 100644
--- a/rust/lance-index/src/scalar/bitmap.rs
+++ b/rust/lance-index/src/scalar/bitmap.rs
@@ -528,6 +528,7 @@ impl ScalarIndexPlugin for BitmapIndexPlugin {
         data: SendableRecordBatchStream,
         index_store: &dyn IndexStore,
         _request: Box<dyn TrainingRequest>,
+        _fragment_ids: Option<Vec<u32>>,
     ) -> Result<CreatedIndex> {
         Self::train_bitmap_index(data, index_store).await?;
         Ok(CreatedIndex {
diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs
index b2760fe214d..9c3e6685703 100644
--- a/rust/lance-index/src/scalar/btree.rs
+++ b/rust/lance-index/src/scalar/btree.rs
@@ -4,10 +4,10 @@ use std::{
     any::Any,
     cmp::Ordering,
-    collections::{BTreeMap, BinaryHeap, HashMap},
+    collections::{BTreeMap, BinaryHeap, HashMap, VecDeque},
     fmt::{Debug, Display},
     ops::Bound,
-    sync::Arc,
+    sync::{Arc, LazyLock},
 };
 use super::{
@@ -38,7 +38,7 @@ use deepsize::DeepSizeOf;
 use futures::{
     future::BoxFuture,
     stream::{self},
-    FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
+    Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
 };
 use lance_core::{
     cache::{CacheKey, LanceCache},
@@ -54,16 +54,37 @@ use lance_datafusion::{
     chunker::chunk_concat_stream,
     exec::{execute_plan, LanceExecutionOptions, OneShotExec},
 };
-use log::debug;
+use log::{debug, warn};
 use roaring::RoaringBitmap;
 use serde::{Deserialize, Serialize, Serializer};
 use snafu::location;
+use tokio::runtime::{Builder, Runtime};
 use tracing::info;
 const BTREE_LOOKUP_NAME: &str = "page_lookup.lance";
 const BTREE_PAGES_NAME: &str = "page_data.lance";
 pub const DEFAULT_BTREE_BATCH_SIZE: u64 = 4096;
 const BATCH_SIZE_META_KEY: &str = "batch_size";
+
+/// Global thread pool for B-tree prefetch operations
+static BTREE_PREFETCH_RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
+    Builder::new_multi_thread()
+        .worker_threads(get_num_compute_intensive_cpus())
+        .max_blocking_threads(get_num_compute_intensive_cpus())
+        .thread_name("lance-btree-prefetch")
+        .enable_time()
+        .build()
+        .expect("Failed to create B-tree prefetch runtime")
+});
+
+/// Spawn a prefetch task on the B-tree thread pool
+fn spawn_btree_prefetch<F>(future: F) -> tokio::task::JoinHandle<F::Output>
+where
+    F: Future + Send + 'static,
+    F::Output: Send + 'static,
+{
+    BTREE_PREFETCH_RUNTIME.spawn(future)
+}
 const BTREE_INDEX_VERSION: u32 = 0;
 pub(crate) const BTREE_VALUES_COLUMN: &str = "values";
 pub(crate) const BTREE_IDS_COLUMN: &str = "ids";
@@ -1231,6 +1252,7 @@ impl ScalarIndex for BTreeIndex {
             self.sub_index.as_ref(),
             dest_store,
             DEFAULT_BTREE_BATCH_SIZE,
+            None,
         )
         .await?;
@@ -1366,10 +1388,33 @@ pub async fn train_btree_index(
     sub_index_trainer: &dyn BTreeSubIndex,
     index_store: &dyn IndexStore,
     batch_size: u64,
+    fragment_ids: Option<Vec<u32>>,
 ) -> Result<()> {
-    let mut sub_index_file = index_store
-        .new_index_file(BTREE_PAGES_NAME, sub_index_trainer.schema().clone())
-        .await?;
+    let fragment_mask = fragment_ids.as_ref().and_then(|frag_ids| {
+        if !frag_ids.is_empty() {
+            // Create a mask with the fragment_id in the high 32 bits for distributed
+            // indexing. The mask is used to filter partitions belonging to specific
+            // fragments. If multiple fragments are processed, the first
+            // fragment_id << 32 is used as the mask.
+            Some((frag_ids[0] as u64) << 32)
+        } else {
+            None
+        }
+    });
+
+    let mut sub_index_file;
+    if fragment_mask.is_none() {
+        sub_index_file = index_store
+            .new_index_file(BTREE_PAGES_NAME, sub_index_trainer.schema().clone())
+            .await?;
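+        // Illustrative: for fragment_ids = [1] the mask is 1u64 << 32 = 4294967296,
+        // so the else branch below names the files "part_4294967296_page_data.lance"
+        // and "part_4294967296_page_lookup.lance"; `extract_partition_id` later
+        // recovers 4294967296 from those names during the merge.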
+    } else {
+        sub_index_file = index_store
+            .new_index_file(
+                part_page_data_file_path(fragment_mask.unwrap()).as_str(),
+                sub_index_trainer.schema().clone(),
+            )
+            .await?;
+    }
+
     let mut encoded_batches = Vec::new();
     let mut batch_idx = 0;
@@ -1393,385 +1438,1945 @@ pub async fn train_btree_index(
     file_schema
         .metadata
         .insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string());
-    let mut btree_index_file = index_store
-        .new_index_file(BTREE_LOOKUP_NAME, Arc::new(file_schema))
-        .await?;
+    let mut btree_index_file;
+    if fragment_mask.is_none() {
+        btree_index_file = index_store
+            .new_index_file(BTREE_LOOKUP_NAME, Arc::new(file_schema))
+            .await?;
+    } else {
+        btree_index_file = index_store
+            .new_index_file(
+                part_lookup_file_path(fragment_mask.unwrap()).as_str(),
+                Arc::new(file_schema),
+            )
+            .await?;
+    }
     btree_index_file.write_record_batch(record_batch).await?;
     btree_index_file.finish().await?;
     Ok(())
 }
-/// A stream that reads the original training data back out of the index
-///
-/// This is used for updating the index
-struct IndexReaderStream {
-    reader: Arc<dyn IndexReader>,
-    batch_size: u64,
-    num_batches: u32,
-    batch_idx: u32,
+/// Extract the partition ID from a partition file name
+/// Expected format: "part_{partition_id}_{suffix}.lance"
+fn extract_partition_id(filename: &str) -> Result<u64> {
+    if !filename.starts_with("part_") {
+        return Err(Error::Internal {
+            message: format!("Invalid partition file name format: {}", filename),
+            location: location!(),
+        });
+    }
+
+    let parts: Vec<&str> = filename.split('_').collect();
+    if parts.len() < 3 {
+        return Err(Error::Internal {
+            message: format!("Invalid partition file name format: {}", filename),
+            location: location!(),
+        });
+    }
+
+    parts[1].parse::<u64>().map_err(|_| Error::Internal {
+        message: format!("Failed to parse partition ID from filename: {}", filename),
+        location: location!(),
+    })
 }
-impl IndexReaderStream {
-    async fn new(reader: Arc<dyn IndexReader>, batch_size: u64) -> Self {
-        let num_batches = reader.num_batches(batch_size).await;
-        Self {
-            reader,
-            batch_size,
-            num_batches,
-            batch_idx: 0,
+/// Merge multiple partition page / lookup files into the complete index files
+///
+/// In a distributed environment, each worker node writes partition page / lookup
+/// files for the partitions it processes, and this function merges those files
+/// into the final page / lookup pair.
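+///
+/// # Example (illustrative sketch)
+///
+/// Assuming a store whose index directory holds two partitions written by
+/// `train_btree_index` for fragments 1 and 2 (masks `1 << 32` and `2 << 32`):
+///
+/// ```text
+/// part_4294967296_page_data.lance    part_4294967296_page_lookup.lance
+/// part_8589934592_page_data.lance    part_8589934592_page_lookup.lance
+/// ```
+///
+/// a merge that prefetches two batches ahead could look like the following
+/// (hedged sketch; `store` is assumed to be an `Arc<dyn IndexStore>` opened for
+/// that index directory):
+///
+/// ```ignore
+/// merge_metadata_files(
+///     store,
+///     &["part_4294967296_page_data.lance".into(), "part_8589934592_page_data.lance".into()],
+///     &["part_4294967296_page_lookup.lance".into(), "part_8589934592_page_lookup.lance".into()],
+///     Some(2),
+/// )
+/// .await?;
+/// ```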
+pub async fn merge_metadata_files(
+    store: Arc<dyn IndexStore>,
+    part_page_files: &[String],
+    part_lookup_files: &[String],
+    prefetch_batch: Option<usize>,
+) -> Result<()> {
+    if part_lookup_files.is_empty() || part_page_files.is_empty() {
+        return Err(Error::Internal {
+            message: "No partition files provided for merging".to_string(),
+            location: location!(),
+        });
+    }
+
+    // Step 1: Create lookup map for page files by partition ID
+    let mut page_files_map = HashMap::new();
+    for page_file in part_page_files {
+        let partition_id = extract_partition_id(page_file)?;
+        page_files_map.insert(partition_id, page_file);
+    }
+
+    // Step 2: Validate that all lookup files have corresponding page files
+    for lookup_file in part_lookup_files {
+        let partition_id = extract_partition_id(lookup_file)?;
+        if !page_files_map.contains_key(&partition_id) {
+            return Err(Error::Internal {
+                message: format!(
+                    "No corresponding page file found for lookup file: {} (partition_id: {})",
+                    lookup_file, partition_id
+                ),
+                location: location!(),
+            });
+        }
+    }
+
+    // Step 3: Extract metadata from lookup files
+    let first_lookup_reader = store.open_index_file(&part_lookup_files[0]).await?;
+    let batch_size = first_lookup_reader
+        .schema()
+        .metadata
+        .get(BATCH_SIZE_META_KEY)
+        .map(|bs| bs.parse().unwrap_or(DEFAULT_BTREE_BATCH_SIZE))
+        .unwrap_or(DEFAULT_BTREE_BATCH_SIZE);
+
+    // Get the value type from lookup schema (min column)
+    let lookup_batch = first_lookup_reader.read_range(0..1, None).await?;
+    let value_type = lookup_batch.column(0).data_type().clone();
+
+    // Get page schema first
+    let partition_id = extract_partition_id(part_lookup_files[0].as_str())?;
+    let page_file = page_files_map.get(&partition_id).unwrap();
+    let page_reader = store.open_index_file(page_file).await?;
+    let page_schema = page_reader.schema().clone();
+
+    let arrow_schema = Arc::new(Schema::from(&page_schema));
+    let mut page_file = store
+        .new_index_file(BTREE_PAGES_NAME, arrow_schema.clone())
+        .await?;
+
+    let mut prefetch_config = PrefetchConfig::default();
+    if let Some(batch) = prefetch_batch {
+        prefetch_config = prefetch_config.with_prefetch_batch(batch);
+    }
+
+    let lookup_entries = merge_page(
+        part_lookup_files,
+        &page_files_map,
+        &store,
+        batch_size,
+        &mut page_file,
+        arrow_schema.clone(),
+        prefetch_config,
+    )
+    .await?;
+
+    page_file.finish().await?;
+
+    // Step 4: Generate new lookup file based on reorganized pages
+    // Add batch_size to schema metadata
+    let mut metadata = HashMap::new();
+    metadata.insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string());
+
+    let lookup_schema_with_metadata = Arc::new(Schema::new_with_metadata(
+        vec![
+            Field::new("min", value_type.clone(), true),
+            Field::new("max", value_type, true),
+            Field::new("null_count", DataType::UInt32, false),
+            Field::new("page_idx", DataType::UInt32, false),
+        ],
+        metadata,
+    ));
+
+    let lookup_batch = RecordBatch::try_new(
+        lookup_schema_with_metadata.clone(),
+        vec![
+            ScalarValue::iter_to_array(lookup_entries.iter().map(|(min, _, _, _)| min.clone()))?,
+            ScalarValue::iter_to_array(lookup_entries.iter().map(|(_, max, _, _)| max.clone()))?,
+            Arc::new(UInt32Array::from_iter_values(
+                lookup_entries
+                    .iter()
+                    .map(|(_, _, null_count, _)| *null_count),
+            )),
+            Arc::new(UInt32Array::from_iter_values(
+                lookup_entries.iter().map(|(_, _, _, page_idx)| *page_idx),
+            )),
+        ],
+    )?;
+
+    let mut lookup_file = store
+        .new_index_file(BTREE_LOOKUP_NAME, lookup_schema_with_metadata)
+        .await?;
+    lookup_file.write_record_batch(lookup_batch).await?;
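+    // Illustrative: with the default batch size of 4096, a fully packed first
+    // merged page holding the values 0..=4095 and no nulls would produce the
+    // lookup row (min: 0, max: 4095, null_count: 0, page_idx: 0).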
+    lookup_file.finish().await?;
+
+    // After successfully writing the merged files, delete all partition files.
+    // Deletion is only performed once the merged files are written, so debug
+    // information is not lost if the merge fails.
+    cleanup_partition_files(&store, part_lookup_files, part_page_files).await;
+
+    Ok(())
+}
-impl Stream for IndexReaderStream {
-    type Item = BoxFuture<'static, Result<RecordBatch>>;
+/// Clean up partition files after a successful merge
+///
+/// This function safely deletes partition lookup and page files after a successful
+/// merge operation. File deletion failures are logged but do not affect the
+/// overall success of the merge operation.
+async fn cleanup_partition_files(
+    store: &Arc<dyn IndexStore>,
+    part_lookup_files: &[String],
+    part_page_files: &[String],
+) {
+    // Clean up partition lookup files
+    for file_name in part_lookup_files {
+        cleanup_single_file(
+            store,
+            file_name,
+            "part_",
+            "_page_lookup.lance",
+            "partition lookup",
+        )
+        .await;
+    }
-    fn poll_next(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
-        let this = self.get_mut();
-        if this.batch_idx >= this.num_batches {
-            return std::task::Poll::Ready(None);
-        }
-        let batch_num = this.batch_idx;
-        this.batch_idx += 1;
-        let reader_copy = this.reader.clone();
-        let batch_size = this.batch_size;
-        let read_task = async move {
-            reader_copy
-                .read_record_batch(batch_num as u64, batch_size)
-                .await
-        }
-        .boxed();
-        std::task::Poll::Ready(Some(read_task))
+    // Clean up partition page files
+    for file_name in part_page_files {
+        cleanup_single_file(
+            store,
+            file_name,
+            "part_",
+            "_page_data.lance",
+            "partition page",
+        )
+        .await;
     }
 }
-/// Parameters for a btree index
-#[derive(Debug, Serialize, Deserialize)]
-pub struct BTreeParameters {
-    /// The number of rows to include in each zone
-    pub zone_size: Option<u64>,
+/// Helper function to clean up a single partition file
+///
+/// Performs safety checks on the filename pattern before attempting deletion.
+async fn cleanup_single_file(
+    store: &Arc<dyn IndexStore>,
+    file_name: &str,
+    expected_prefix: &str,
+    expected_suffix: &str,
+    file_type: &str,
+) {
+    // Ensure we only delete files that match the expected pattern (safety check)
+    if file_name.starts_with(expected_prefix) && file_name.ends_with(expected_suffix) {
+        match store.delete_index_file(file_name).await {
+            Ok(()) => {
+                debug!("Successfully deleted {} file: {}", file_type, file_name);
+            }
+            Err(e) => {
+                // A deletion failure should not fail the merge; log the error and
+                // continue processing the remaining files.
+                warn!(
+                    "Failed to delete {} file '{}': {}.
\ + This does not affect the merge operation, but may leave \ + partition files that should be cleaned up manually.", + file_type, file_name, e + ); + } + } + } else { + // If the filename doesn't match the expected format, log a warning but don't attempt deletion + warn!( + "Skipping deletion of file '{}' as it does not match the expected \ + {} file pattern ({}*{})", + file_name, file_type, expected_prefix, expected_suffix + ); + } } -struct BTreeTrainingRequest { - parameters: BTreeParameters, - criteria: TrainingCriteria, +/// Prefetch configuration for partition iterators +#[derive(Debug, Clone)] +pub struct PrefetchConfig { + /// Number of batches to prefetch ahead (0 means no prefetching) + pub prefetch_batches: usize, } -impl BTreeTrainingRequest { - pub fn new(parameters: BTreeParameters) -> Self { +impl Default for PrefetchConfig { + fn default() -> Self { Self { - parameters, - // BTree indexes need data sorted by the value column - criteria: TrainingCriteria::new(TrainingOrdering::Values).with_row_id(), + prefetch_batches: 1, } } } -impl TrainingRequest for BTreeTrainingRequest { - fn as_any(&self) -> &dyn std::any::Any { - self +impl PrefetchConfig { + /// Set the prefetch batch count + pub fn with_prefetch_batch(&self, batch_count: usize) -> Self { + Self { + prefetch_batches: batch_count, + } } +} - fn criteria(&self) -> &TrainingCriteria { - &self.criteria - } +/// Buffer entry for prefetch queue +#[derive(Debug)] +struct BufferEntry { + batch: RecordBatch, + start_row: usize, + end_row: usize, } -#[derive(Debug, Default)] -pub struct BTreeIndexPlugin; +/// Running prefetch task information +#[derive(Debug)] +struct RunningPrefetchTask { + /// Task handle + handle: tokio::task::JoinHandle<()>, + /// Range being prefetched + range: std::ops::Range, +} -#[async_trait] -impl ScalarIndexPlugin for BTreeIndexPlugin { - fn new_training_request( - &self, - params: &str, - field: &Field, - ) -> Result> { - if field.data_type().is_nested() { - return Err(Error::InvalidInput { - source: "A btree index can only be created on a non-nested field.".into(), - location: location!(), - }); - } +/// Check if two ranges overlap +fn ranges_overlap(range1: &std::ops::Range, range2: &std::ops::Range) -> bool { + range1.start < range2.end && range2.start < range1.end +} - let params = serde_json::from_str::(params)?; - Ok(Box::new(BTreeTrainingRequest::new(params))) - } +/// Prefetch state for a partition using task-based prefetching +struct PartitionPrefetchState { + /// Queue of prefetched data + buffer: Arc>>, + /// Reader for this partition + reader: Arc, + /// Total rows in this partition + total_rows: usize, + /// Queue of running prefetch tasks with their ranges + running_tasks: Arc>>, + /// Next position to schedule for prefetch + next_prefetch_position: Arc>, +} - fn provides_exact_answer(&self) -> bool { - true - } +/// Manager for coordinating task-based prefetch across multiple partitions +pub struct PrefetchManager { + /// Prefetch state per partition + partition_states: HashMap, + /// Prefetch configuration + config: PrefetchConfig, +} - fn version(&self) -> u32 { - 0 +impl PrefetchManager { + /// Create a new prefetch manager + pub fn new(config: PrefetchConfig) -> Self { + Self { + partition_states: HashMap::new(), + config, + } } - fn new_query_parser( - &self, - index_name: String, - _index_details: &prost_types::Any, - ) -> Option> { - Some(Box::new(SargableQueryParser::new(index_name, false))) - } + /// Initialize a partition for task-based prefetching + pub fn 
initialize_partition(&mut self, partition_id: u64, reader: Arc) { + let total_rows = reader.num_rows(); + let buffer = Arc::new(tokio::sync::Mutex::new(VecDeque::new())); + let running_tasks = Arc::new(tokio::sync::Mutex::new(VecDeque::new())); + let next_prefetch_position = Arc::new(tokio::sync::Mutex::new(0)); - async fn train_index( - &self, - data: SendableRecordBatchStream, - index_store: &dyn IndexStore, - request: Box, - ) -> Result { - let request = request - .as_any() - .downcast_ref::() - .unwrap(); - let value_type = data - .schema() - .field_with_name(VALUE_COLUMN_NAME)? - .data_type() - .clone(); - let flat_index_trainer = FlatIndexMetadata::new(value_type); - train_btree_index( - data, - &flat_index_trainer, - index_store, - request - .parameters - .zone_size - .unwrap_or(DEFAULT_BTREE_BATCH_SIZE), - ) - .await?; - Ok(CreatedIndex { - index_details: prost_types::Any::from_msg(&pb::BTreeIndexDetails::default()).unwrap(), - index_version: BTREE_INDEX_VERSION, - }) - } + let state = PartitionPrefetchState { + buffer, + reader, + total_rows, + running_tasks, + next_prefetch_position, + }; - async fn load_index( - &self, - index_store: Arc, - _index_details: &prost_types::Any, - frag_reuse_index: Option>, - cache: LanceCache, - ) -> Result> { - Ok(BTreeIndex::load(index_store, frag_reuse_index, cache).await? as Arc) + self.partition_states.insert(partition_id, state); + debug!( + "Initialized partition {} for task-based prefetching", + partition_id + ); } -} -#[cfg(test)] -mod tests { - use std::sync::atomic::Ordering; - use std::{collections::HashMap, sync::Arc}; + /// Submit a prefetch task for a partition to the thread pool + pub async fn submit_prefetch_task(&self, partition_id: u64, batch_size: usize) -> Result<()> { + if self.config.prefetch_batches == 0 { + return Ok(()); + } - use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type}; - use arrow_array::FixedSizeListArray; - use arrow_schema::DataType; - use datafusion::{ - execution::{SendableRecordBatchStream, TaskContext}, - physical_plan::{sorts::sort::SortExec, stream::RecordBatchStreamAdapter, ExecutionPlan}, - }; - use datafusion_common::{DataFusionError, ScalarValue}; - use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; - use deepsize::DeepSizeOf; - use futures::TryStreamExt; - use lance_core::{cache::LanceCache, utils::mask::RowIdTreeMap}; - use lance_datafusion::{chunker::break_stream, datagen::DatafusionDatagenExt}; - use lance_datagen::{array, gen_batch, ArrayGeneratorExt, BatchCount, RowCount}; - use lance_io::object_store::ObjectStore; - use object_store::path::Path; - use tempfile::tempdir; + let Some(state) = self.partition_states.get(&partition_id) else { + return Ok(()); + }; - use crate::metrics::LocalMetricsCollector; - use crate::{ - metrics::NoOpMetricsCollector, - scalar::{ - btree::{BTreeIndex, BTREE_PAGES_NAME}, - flat::FlatIndexMetadata, - lance_format::LanceIndexStore, - IndexStore, SargableQuery, ScalarIndex, SearchResult, - }, - }; + let reader = state.reader.clone(); + let buffer = state.buffer.clone(); + let running_tasks = state.running_tasks.clone(); + let next_prefetch_position = state.next_prefetch_position.clone(); + let total_rows = state.total_rows; + let effective_batch_size = self.config.prefetch_batches * batch_size; - use super::{train_btree_index, OrderableScalarValue}; + const MAX_BUFFER_SIZE: usize = 4; + const MAX_RUNNING_TASKS: usize = 2; - #[test] - fn test_scalar_value_size() { - let size_of_i32 = 
OrderableScalarValue(ScalarValue::Int32(Some(0))).deep_size_of(); - let size_of_many_i32 = OrderableScalarValue(ScalarValue::FixedSizeList(Arc::new( - FixedSizeListArray::from_iter_primitive::( - vec![Some(vec![Some(0); 128])], - 128, - ), - ))) - .deep_size_of(); + // Clean up completed tasks and check limits + { + let mut tasks_guard = running_tasks.lock().await; - // deep_size_of should account for the rust type overhead - assert!(size_of_i32 > 4); - assert!(size_of_many_i32 > 128 * 4); + // Remove completed tasks from the front + while let Some(task) = tasks_guard.front() { + if task.handle.is_finished() { + tasks_guard.pop_front(); + } else { + break; + } + } + + // Check if we have too many running tasks + if tasks_guard.len() >= MAX_RUNNING_TASKS { + debug!( + "Skipping prefetch for partition {} - too many running tasks ({})", + partition_id, + tasks_guard.len() + ); + return Ok(()); + } + + // Check if any running task already covers to the end of file + for task in tasks_guard.iter() { + if task.range.end >= total_rows { + debug!( + "Skipping prefetch for partition {} - task already covers to EOF (range {}..{})", + partition_id, task.range.start, task.range.end + ); + return Ok(()); + } + } + } + + // Check if buffer is full + { + let buffer_guard = buffer.lock().await; + if buffer_guard.len() >= MAX_BUFFER_SIZE { + debug!( + "Skipping prefetch for partition {} - buffer full", + partition_id + ); + return Ok(()); + } + } + + // Determine the next range to prefetch + let next_range = { + let mut pos_guard = next_prefetch_position.lock().await; + let start_pos = *pos_guard; + + if start_pos >= total_rows { + debug!( + "Skipping prefetch for partition {} - no more data to prefetch", + partition_id + ); + return Ok(()); + } + + let end_pos = std::cmp::min(start_pos + effective_batch_size, total_rows); + *pos_guard = end_pos; // Update next prefetch position + start_pos..end_pos + }; + + // Check if this range is already being prefetched + { + let tasks_guard = running_tasks.lock().await; + + // Check for range overlap + for task in tasks_guard.iter() { + if ranges_overlap(&task.range, &next_range) { + debug!( + "Skipping prefetch for partition {} - range {}..{} overlaps with running task {}..{}", + partition_id, next_range.start, next_range.end, task.range.start, task.range.end + ); + return Ok(()); + } + } + } + + // All checks passed, create the actual prefetch task (only this part is async) + let range_clone = next_range.clone(); + let running_tasks_for_cleanup = running_tasks.clone(); + + let prefetch_task = spawn_btree_prefetch(async move { + // Perform the actual read + match reader.read_range(range_clone.clone(), None).await { + Ok(batch) => { + let entry = BufferEntry { + batch, + start_row: range_clone.start, + end_row: range_clone.end, + }; + + // Add to buffer + { + let mut buffer_guard = buffer.lock().await; + buffer_guard.push_back(entry); + } + + debug!( + "Prefetched {} rows ({}..{}) for partition {}", + range_clone.end - range_clone.start, + range_clone.start, + range_clone.end, + partition_id + ); + } + Err(err) => { + warn!( + "Prefetch task failed for partition {} range {}..{}: {}", + partition_id, range_clone.start, range_clone.end, err + ); + } + } + + // Remove this task from running tasks when completed + { + let mut tasks_guard = running_tasks_for_cleanup.lock().await; + tasks_guard.retain(|task| !task.handle.is_finished()); + } + }); + + // Add the task to running tasks + { + let mut tasks_guard = running_tasks.lock().await; + 
tasks_guard.push_back(RunningPrefetchTask { + handle: prefetch_task, + range: next_range.clone(), + }); + } + + debug!( + "Submitted prefetch task for partition {} range {}..{}", + partition_id, next_range.start, next_range.end + ); + + Ok(()) + } + + /// Get data from buffer or fallback to direct read + pub async fn get_data_with_fallback( + &self, + partition_id: u64, + start_row: usize, + end_row: usize, + ) -> Result { + if let Some(state) = self.partition_states.get(&partition_id) { + // First try to get from buffer + { + let mut buffer_guard = state.buffer.lock().await; + + // Remove outdated entries from the front + while let Some(entry) = buffer_guard.front() { + if entry.end_row <= start_row { + buffer_guard.pop_front(); + } else { + break; + } + } + + // Check if we have suitable data in buffer + if let Some(entry) = buffer_guard.front() { + if entry.start_row <= start_row && entry.end_row >= end_row { + // Found matching data, extract it + let entry = buffer_guard.pop_front().unwrap(); + drop(buffer_guard); + + let slice_start = start_row - entry.start_row; + let slice_len = end_row - start_row; + + debug!( + "Using buffered data for partition {} ({}..{})", + partition_id, start_row, end_row + ); + + return Ok(entry.batch.slice(slice_start, slice_len)); + } + } + } + + // Fallback to direct read + debug!( + "Direct read fallback for partition {} ({}..{})", + partition_id, start_row, end_row + ); + + state.reader.read_range(start_row..end_row, None).await + } else { + Err(Error::Internal { + message: format!("Partition {} not found in prefetch manager", partition_id), + location: location!(), + }) + } + } +} + +/// Simplified partition iterator with immediate loading since all partitions need to be accessed +struct PartitionIterator { + reader: Arc, + current_batch: Option, + current_position: usize, + rows_read: usize, + partition_id: u64, + batch_size: u64, +} + +impl PartitionIterator { + async fn new( + store: Arc, + page_file_name: String, + partition_id: u64, + batch_size: u64, + ) -> Result { + let reader = store.open_index_file(&page_file_name).await?; + Ok(Self { + reader, + current_batch: None, + current_position: 0, + rows_read: 0, + partition_id, + batch_size, + }) + } + + /// Get the next element, working with the prefetch manager + async fn next( + &mut self, + prefetch_manager: &PrefetchManager, + ) -> Result> { + // Load new batch if current one is exhausted + if self.needs_new_batch() { + if self.rows_read >= self.reader.num_rows() { + return Ok(None); + } + self.load_next_batch(prefetch_manager).await?; + + // Submit next prefetch task + if let Err(err) = prefetch_manager + .submit_prefetch_task(self.partition_id, self.batch_size as usize) + .await + { + warn!( + "Failed to submit prefetch task for partition {}: {}", + self.partition_id, err + ); + } + } else { + // Check if we've read half of the current batch, submit next prefetch task + let batch_half = self.current_batch.as_ref().unwrap().num_rows() / 2; + if self.current_position == batch_half && batch_half > 0 { + if let Err(err) = prefetch_manager + .submit_prefetch_task(self.partition_id, self.batch_size as usize) + .await + { + warn!( + "Failed to submit prefetch task for partition {}: {}", + self.partition_id, err + ); + } + } + } + + // Extract next value from current batch + if let Some(batch) = &self.current_batch { + let value = ScalarValue::try_from_array(batch.column(0), self.current_position)?; + let row_id = ScalarValue::try_from_array(batch.column(1), self.current_position)?; + 
self.current_position += 1; + self.rows_read += 1; + Ok(Some((value, row_id))) + } else { + Ok(None) + } + } + + /// Check if we need to load a new batch + fn needs_new_batch(&self) -> bool { + self.current_batch.is_none() + || self.current_position >= self.current_batch.as_ref().unwrap().num_rows() + } + + async fn load_next_batch(&mut self, prefetch_manager: &PrefetchManager) -> Result<()> { + let remaining_rows = self.reader.num_rows() - self.rows_read; + if remaining_rows == 0 { + self.current_batch = None; + return Ok(()); + } + + let rows_to_read = std::cmp::min(self.batch_size as usize, remaining_rows); + let end_row = self.rows_read + rows_to_read; + + // Use the new fallback mechanism - try buffer first, then direct read + let batch = prefetch_manager + .get_data_with_fallback(self.partition_id, self.rows_read, end_row) + .await?; + + self.current_batch = Some(batch); + self.current_position = 0; + + Ok(()) + } + + fn get_reader(&self) -> Arc { + self.reader.clone() + } +} + +/// Heap elements, used for priority queues in multi-way merging +#[derive(Debug)] +struct HeapElement { + value: ScalarValue, + row_id: ScalarValue, + partition_id: u64, +} + +impl PartialEq for HeapElement { + fn eq(&self, other: &Self) -> bool { + self.value.eq(&other.value) + } +} + +impl Eq for HeapElement {} + +impl PartialOrd for HeapElement { + fn partial_cmp(&self, other: &Self) -> Option { + // Note: BinaryHeap is a maximum heap, we need a minimum heap, + // so reverse the comparison result + other.value.partial_cmp(&self.value) + } +} + +impl Ord for HeapElement { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap_or(Ordering::Equal) + } +} + +async fn merge_page( + part_lookup_files: &[String], + page_files_map: &HashMap, + store: &Arc, + batch_size: u64, + page_file: &mut Box, + arrow_schema: Arc, + prefetch_config: PrefetchConfig, +) -> Result> { + let mut lookup_entries = Vec::new(); + let mut page_idx = 0u32; + + debug!( + "Starting multi-way merge with {} partitions using prefetch manager", + part_lookup_files.len() + ); + + // Create prefetch manager + let mut prefetch_manager = PrefetchManager::new(prefetch_config.clone()); + + // Directly create iterators and read first element + let mut partition_map = HashMap::new(); + let mut heap = BinaryHeap::new(); + + debug!("Initializing {} partitions", part_lookup_files.len()); + + // Initialize all partitions + for lookup_file in part_lookup_files { + let partition_id = extract_partition_id(lookup_file)?; + let page_file_name = page_files_map + .get(&partition_id) + .ok_or_else(|| Error::Internal { + message: format!("Page file not found for partition ID: {}", partition_id), + location: location!(), + })? 
+ .to_string(); + + let mut iterator = + PartitionIterator::new(store.clone(), page_file_name, partition_id, batch_size).await?; + + // Initialize partition in prefetch manager + let reader = iterator.get_reader(); + prefetch_manager.initialize_partition(partition_id, reader); + + // Submit initial prefetch task + if let Err(err) = prefetch_manager + .submit_prefetch_task(partition_id, batch_size as usize) + .await + { + warn!( + "Failed to submit prefetch task for partition {}: {}", + partition_id, err + ); + } + + let first_element = iterator.next(&prefetch_manager).await?; + + if let Some((value, row_id)) = first_element { + // Put the first element into the heap + heap.push(HeapElement { + value, + row_id, + partition_id, + }); + } + + partition_map.insert(partition_id, iterator); + } + + debug!( + "Initialized {} partitions, heap size: {}", + partition_map.len(), + heap.len() + ); + + let mut current_batch_rows = Vec::with_capacity(batch_size as usize); + let mut total_merged = 0usize; + + // Multi-way merge main loop + while let Some(min_element) = heap.pop() { + // Add current minimum element to batch + current_batch_rows.push((min_element.value, min_element.row_id)); + total_merged += 1; + + // Read next element from corresponding partition + if let Some(iterator) = partition_map.get_mut(&min_element.partition_id) { + if let Some((next_value, next_row_id)) = iterator.next(&prefetch_manager).await? { + heap.push(HeapElement { + value: next_value, + row_id: next_row_id, + partition_id: min_element.partition_id, + }); + } + } + + // Write when batch reaches specified size + if current_batch_rows.len() >= batch_size as usize { + write_batch_and_lookup_entry( + &mut current_batch_rows, + page_file, + &arrow_schema, + &mut lookup_entries, + &mut page_idx, + ) + .await?; + } + } + + // Write the remaining data + if !current_batch_rows.is_empty() { + write_batch_and_lookup_entry( + &mut current_batch_rows, + page_file, + &arrow_schema, + &mut lookup_entries, + &mut page_idx, + ) + .await?; + } + + debug!( + "Completed multi-way merge: merged {} rows into {} lookup entries", + total_merged, + lookup_entries.len() + ); + Ok(lookup_entries) +} + +/// Helper function to prepare batch data in parallel +async fn prepare_batch_data( + batch_rows: Vec<(ScalarValue, ScalarValue)>, + arrow_schema: Arc, + page_idx: u32, +) -> Result<(RecordBatch, (ScalarValue, ScalarValue, u32, u32))> { + if batch_rows.is_empty() { + return Err(Error::Internal { + message: "Cannot prepare empty batch".to_string(), + location: location!(), + }); + } + + // Parallelize data preparation + let (values, row_ids): (Vec<_>, Vec<_>) = batch_rows.into_iter().unzip(); + + // Convert to arrays in parallel using rayon or manually spawn tasks + let values_array = ScalarValue::iter_to_array(values.into_iter())?; + let row_ids_array = ScalarValue::iter_to_array(row_ids.into_iter())?; + + let batch = RecordBatch::try_new(arrow_schema, vec![values_array, row_ids_array])?; + + // Calculate min/max/null_count for lookup entry + let min_val = ScalarValue::try_from_array(batch.column(0), 0)?; + let max_val = ScalarValue::try_from_array(batch.column(0), batch.num_rows() - 1)?; + let null_count = batch.column(0).null_count() as u32; + + let lookup_entry = (min_val, max_val, null_count, page_idx); + + Ok((batch, lookup_entry)) +} + +/// Helper function to write a batch and create lookup entry +async fn write_batch_and_lookup_entry( + batch_rows: &mut Vec<(ScalarValue, ScalarValue)>, + page_file: &mut Box, + arrow_schema: &Arc, + 
+
+/// Helper function to write a batch and create lookup entry
+async fn write_batch_and_lookup_entry(
+    batch_rows: &mut Vec<(ScalarValue, ScalarValue)>,
+    page_file: &mut Box<dyn IndexWriter>,
+    arrow_schema: &Arc<Schema>,
+    lookup_entries: &mut Vec<(ScalarValue, ScalarValue, u32, u32)>,
+    page_idx: &mut u32,
+) -> Result<()> {
+    if batch_rows.is_empty() {
+        return Ok(());
+    }
+
+    // Take ownership of the batch data
+    let batch_data = std::mem::take(batch_rows);
+    let current_page_idx = *page_idx;
+
+    // Prepare batch data
+    let (batch, lookup_entry) =
+        prepare_batch_data(batch_data, arrow_schema.clone(), current_page_idx).await?;
+
+    lookup_entries.push(lookup_entry);
+    page_file.write_record_batch(batch).await?;
+    *page_idx += 1;
+
+    Ok(())
+}
+
+pub(crate) fn part_page_data_file_path(partition_id: u64) -> String {
+    format!("part_{}_{}", partition_id, BTREE_PAGES_NAME)
+}
+
+pub(crate) fn part_lookup_file_path(partition_id: u64) -> String {
+    format!("part_{}_{}", partition_id, BTREE_LOOKUP_NAME)
+}
+
+/// A stream that reads the original training data back out of the index
+///
+/// This is used for updating the index
+struct IndexReaderStream {
+    reader: Arc<dyn IndexReader>,
+    batch_size: u64,
+    num_batches: u32,
+    batch_idx: u32,
+}
+
+impl IndexReaderStream {
+    async fn new(reader: Arc<dyn IndexReader>, batch_size: u64) -> Self {
+        let num_batches = reader.num_batches(batch_size).await;
+        Self {
+            reader,
+            batch_size,
+            num_batches,
+            batch_idx: 0,
+        }
+    }
+}
+
+impl Stream for IndexReaderStream {
+    type Item = BoxFuture<'static, Result<RecordBatch>>;
+
+    fn poll_next(
+        self: std::pin::Pin<&mut Self>,
+        _cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        let this = self.get_mut();
+        if this.batch_idx >= this.num_batches {
+            return std::task::Poll::Ready(None);
+        }
+        let batch_num = this.batch_idx;
+        this.batch_idx += 1;
+        let reader_copy = this.reader.clone();
+        let batch_size = this.batch_size;
+        let read_task = async move {
+            reader_copy
+                .read_record_batch(batch_num as u64, batch_size)
+                .await
+        }
+        .boxed();
+        std::task::Poll::Ready(Some(read_task))
+    }
+}
+
+/// Parameters for a btree index
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BTreeParameters {
+    /// The number of rows to include in each page
+    pub zone_size: Option<u64>,
+}
+
+struct BTreeTrainingRequest {
+    parameters: BTreeParameters,
+    criteria: TrainingCriteria,
+}
+
+impl BTreeTrainingRequest {
+    pub fn new(parameters: BTreeParameters) -> Self {
+        Self {
+            parameters,
+            // BTree indexes need data sorted by the value column
+            criteria: TrainingCriteria::new(TrainingOrdering::Values).with_row_id(),
+        }
+    }
+}
+
+impl TrainingRequest for BTreeTrainingRequest {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn criteria(&self) -> &TrainingCriteria {
+        &self.criteria
+    }
+}
+
+#[derive(Debug, Default)]
+pub struct BTreeIndexPlugin;
+
+#[async_trait]
+impl ScalarIndexPlugin for BTreeIndexPlugin {
+    fn new_training_request(
+        &self,
+        params: &str,
+        field: &Field,
+    ) -> Result<Box<dyn TrainingRequest>> {
+        if field.data_type().is_nested() {
+            return Err(Error::InvalidInput {
+                source: "A btree index can only be created on a non-nested field.".into(),
+                location: location!(),
+            });
+        }
+
+        let params = serde_json::from_str::<BTreeParameters>(params)?;
+        Ok(Box::new(BTreeTrainingRequest::new(params)))
+    }
+
+    fn provides_exact_answer(&self) -> bool {
+        true
+    }
+
+    fn version(&self) -> u32 {
+        0
+    }
+
+    fn new_query_parser(
+        &self,
+        index_name: String,
+        _index_details: &prost_types::Any,
+    ) -> Option<Box<dyn QueryParser>> {
+        Some(Box::new(SargableQueryParser::new(index_name, false)))
+    }
+
+    async fn train_index(
+        &self,
+        data: SendableRecordBatchStream,
+        index_store: &dyn IndexStore,
+        request: Box<dyn TrainingRequest>,
+        fragment_ids: Option<Vec<u32>>,
+    ) -> Result<CreatedIndex> {
+        let request = request
+            .as_any()
+            .downcast_ref::<BTreeTrainingRequest>()
+            .unwrap();
+
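+        // The stream handed to train_index is assumed to satisfy the criteria
+        // declared by BTreeTrainingRequest above: a VALUE_COLUMN_NAME column
+        // sorted by value (TrainingOrdering::Values) plus a row id column
+        // (with_row_id()); train_btree_index below chunks that sorted stream
+        // into pages of zone_size rows.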
+        let value_type = data
+            .schema()
+            .field_with_name(VALUE_COLUMN_NAME)?
+            .data_type()
+            .clone();
+        let flat_index_trainer = FlatIndexMetadata::new(value_type);
+        train_btree_index(
+            data,
+            &flat_index_trainer,
+            index_store,
+            request
+                .parameters
+                .zone_size
+                .unwrap_or(DEFAULT_BTREE_BATCH_SIZE),
+            fragment_ids,
+        )
+        .await?;
+        Ok(CreatedIndex {
+            index_details: prost_types::Any::from_msg(&pb::BTreeIndexDetails::default()).unwrap(),
+            index_version: BTREE_INDEX_VERSION,
+        })
+    }
+
+    async fn load_index(
+        &self,
+        index_store: Arc<dyn IndexStore>,
+        _index_details: &prost_types::Any,
+        frag_reuse_index: Option<Arc<FragReuseIndex>>,
+        cache: LanceCache,
+    ) -> Result<Arc<dyn ScalarIndex>> {
+        Ok(BTreeIndex::load(index_store, frag_reuse_index, cache).await? as Arc<dyn ScalarIndex>)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::atomic::Ordering;
+    use std::{collections::HashMap, sync::Arc};
+
+    use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type};
+    use arrow_array::FixedSizeListArray;
+    use arrow_schema::DataType;
+    use datafusion::{
+        execution::{SendableRecordBatchStream, TaskContext},
+        physical_plan::{sorts::sort::SortExec, stream::RecordBatchStreamAdapter, ExecutionPlan},
+    };
+    use datafusion_common::{DataFusionError, ScalarValue};
+    use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr};
+    use deepsize::DeepSizeOf;
+    use futures::TryStreamExt;
+    use lance_core::{cache::LanceCache, utils::mask::RowIdTreeMap};
+    use lance_datafusion::{chunker::break_stream, datagen::DatafusionDatagenExt};
+    use lance_datagen::{array, gen_batch, ArrayGeneratorExt, BatchCount, RowCount};
+    use lance_io::object_store::ObjectStore;
+    use object_store::path::Path;
+    use tempfile::tempdir;
+
+    use crate::metrics::LocalMetricsCollector;
+    use crate::{
+        metrics::NoOpMetricsCollector,
+        scalar::{
+            btree::{BTreeIndex, BTREE_PAGES_NAME},
+            flat::FlatIndexMetadata,
+            lance_format::LanceIndexStore,
+            IndexStore, SargableQuery, ScalarIndex, SearchResult,
+        },
+    };
+
+    use super::{
+        part_lookup_file_path, part_page_data_file_path, train_btree_index, OrderableScalarValue,
+        DEFAULT_BTREE_BATCH_SIZE,
+    };
+
+    #[test]
+    fn test_scalar_value_size() {
+        let size_of_i32 = OrderableScalarValue(ScalarValue::Int32(Some(0))).deep_size_of();
+        let size_of_many_i32 = OrderableScalarValue(ScalarValue::FixedSizeList(Arc::new(
+            FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
+                vec![Some(vec![Some(0); 128])],
+                128,
+            ),
+        )))
+        .deep_size_of();
+
+        // deep_size_of should account for the rust type overhead
+        assert!(size_of_i32 > 4);
+        assert!(size_of_many_i32 > 128 * 4);
+    }
+
+    #[tokio::test]
+    async fn test_null_ids() {
+        let tmpdir = Arc::new(tempdir().unwrap());
+        let test_store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            Path::from_filesystem_path(tmpdir.path()).unwrap(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        // Generate 50,000 rows of random data with 80% nulls
+        let stream = gen_batch()
+            .col(
+                "value",
+                array::rand::<Float32Type>().with_nulls(&[true, false, false, false, false]),
+            )
+            .col("_rowid", array::step::<UInt64Type>())
+            .into_df_stream(RowCount::from(5000), BatchCount::from(10));
+        let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32);
+
+        train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 5000, None)
+            .await
+            .unwrap();
+
+        let index = BTreeIndex::load(test_store.clone(), None, LanceCache::no_cache())
+            .await
+            .unwrap();
+
+        assert_eq!(index.page_lookup.null_pages.len(), 10);
+
+        let remap_dir = Arc::new(tempdir().unwrap());
+        let remap_store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            Path::from_filesystem_path(remap_dir.path()).unwrap(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        // Remap with a no-op mapping. The remapped index should be identical to the original
+        index
+            .remap(&HashMap::default(), remap_store.as_ref())
+            .await
+            .unwrap();
+
+        let remap_index = BTreeIndex::load(remap_store.clone(), None, LanceCache::no_cache())
+            .await
+            .unwrap();
+
+        assert_eq!(remap_index.page_lookup, index.page_lookup);
+
+        let original_pages = test_store.open_index_file(BTREE_PAGES_NAME).await.unwrap();
+        let remapped_pages = remap_store.open_index_file(BTREE_PAGES_NAME).await.unwrap();
+
+        assert_eq!(original_pages.num_rows(), remapped_pages.num_rows());
+
+        let original_data = original_pages
+            .read_record_batch(0, original_pages.num_rows() as u64)
+            .await
+            .unwrap();
+        let remapped_data = remapped_pages
+            .read_record_batch(0, remapped_pages.num_rows() as u64)
+            .await
+            .unwrap();
+
+        assert_eq!(original_data, remapped_data);
+    }
+
+    #[tokio::test]
+    async fn test_nan_ordering() {
+        let tmpdir = Arc::new(tempdir().unwrap());
+        let test_store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            Path::from_filesystem_path(tmpdir.path()).unwrap(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        let values = vec![
+            0.0,
+            1.0,
+            2.0,
+            3.0,
+            f64::NAN,
+            f64::NEG_INFINITY,
+            f64::INFINITY,
+        ];
+
+        // This is a bit overkill but we've had bugs in the past where DF's sort
+        // didn't agree with Arrow's sort so we do an end-to-end test here
+        // and use DF to sort the data like we would in a real dataset.
+        let data = gen_batch()
+            .col("value", array::cycle::<Float64Type>(values.clone()))
+            .col("_rowid", array::step::<UInt64Type>())
+            .into_df_exec(RowCount::from(10), BatchCount::from(100));
+        let schema = data.schema();
+        let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap());
+        let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data));
+        let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap();
+        let stream = break_stream(stream, 64);
+        let stream = stream.map_err(DataFusionError::from);
+        let stream =
+            Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream;
+
+        let sub_index_trainer = FlatIndexMetadata::new(DataType::Float64);
+
+        train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64, None)
+            .await
+            .unwrap();
+
+        let index = BTreeIndex::load(test_store, None, LanceCache::no_cache())
+            .await
+            .unwrap();
+
+        for (idx, value) in values.into_iter().enumerate() {
+            let query = SargableQuery::Equals(ScalarValue::Float64(Some(value)));
+            let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
+            assert_eq!(
+                result,
+                SearchResult::Exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7)))
+            );
+        }
+    }
+
+    #[tokio::test]
+    async fn test_page_cache() {
+        let tmpdir = Arc::new(tempdir().unwrap());
+        let test_store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            Path::from_filesystem_path(tmpdir.path()).unwrap(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        let data = gen_batch()
+            .col("value", array::step::<Float32Type>())
+            .col("_rowid", array::step::<UInt64Type>())
+            .into_df_exec(RowCount::from(1000), BatchCount::from(10));
+        let schema = data.schema();
+        let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap());
+        let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data));
+        let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap();
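+        // This test expects the two concurrent searches issued below to share
+        // a single page load (parts_loaded == 1): the second lookup should hit
+        // the page cache rather than re-read the page file.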
+        let stream = break_stream(stream, 64);
+        let stream = stream.map_err(DataFusionError::from);
+        let stream =
+            Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream;
+        let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32);
+
+        train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64, None)
+            .await
+            .unwrap();
+
+        let index = BTreeIndex::load(
+            test_store,
+            None,
+            LanceCache::with_capacity(100 * 1024 * 1024),
+        )
+        .await
+        .unwrap();
+
+        let query = SargableQuery::Equals(ScalarValue::Float32(Some(0.0)));
+        let metrics = LocalMetricsCollector::default();
+        let query1 = index.search(&query, &metrics);
+        let query2 = index.search(&query, &metrics);
+        tokio::join!(query1, query2).0.unwrap();
+        assert_eq!(metrics.parts_loaded.load(Ordering::Relaxed), 1);
+    }
+
+    /// Test that fragment-based btree index construction produces exactly the same results as building a complete index
+    #[tokio::test]
+    async fn test_fragment_btree_index_consistency() {
+        // Setup stores for both indexes
+        let full_tmpdir = Arc::new(tempdir().unwrap());
+        let full_store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            Path::from_filesystem_path(full_tmpdir.path()).unwrap(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        let fragment_tmpdir = Arc::new(tempdir().unwrap());
+        let fragment_store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            Path::from_filesystem_path(fragment_tmpdir.path()).unwrap(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        let sub_index_trainer = FlatIndexMetadata::new(DataType::Int32);
+
+        // Method 1: Build complete index directly using the same data
+        // Create deterministic data for comparison - use 2 * DEFAULT_BTREE_BATCH_SIZE for testing
+        let total_count = (2 * DEFAULT_BTREE_BATCH_SIZE) as u64;
+        let full_data_gen = gen_batch()
+            .col("value", array::step::<Int32Type>())
+            .col("_rowid", array::step::<UInt64Type>())
+            .into_df_stream(RowCount::from(total_count / 2), BatchCount::from(2));
+        let full_data_source = Box::pin(RecordBatchStreamAdapter::new(
+            full_data_gen.schema(),
+            full_data_gen,
+        ));
+
+        train_btree_index(
+            full_data_source,
+            &sub_index_trainer,
+            full_store.as_ref(),
+            DEFAULT_BTREE_BATCH_SIZE,
+            None,
+        )
+        .await
+        .unwrap();
+
+        // Method 2: Build fragment-based index using the same data split into fragments
+        // Create fragment 1 index - first half of the data (0 to DEFAULT_BTREE_BATCH_SIZE-1)
+        let half_count = DEFAULT_BTREE_BATCH_SIZE;
+        let fragment1_gen = gen_batch()
+            .col("value", array::step::<Int32Type>())
+            .col("_rowid", array::step::<UInt64Type>())
+            .into_df_stream(RowCount::from(half_count), BatchCount::from(1));
+        let fragment1_data_source = Box::pin(RecordBatchStreamAdapter::new(
+            fragment1_gen.schema(),
+            fragment1_gen,
+        ));
+
+        train_btree_index(
+            fragment1_data_source,
+            &sub_index_trainer,
+            fragment_store.as_ref(),
+            DEFAULT_BTREE_BATCH_SIZE,
+            Some(vec![1]), // fragment_id = 1
+        )
+        .await
+        .unwrap();
+
+        // Create fragment 2 index - second half of the data (DEFAULT_BTREE_BATCH_SIZE to 2*DEFAULT_BTREE_BATCH_SIZE-1)
+        let start_val = DEFAULT_BTREE_BATCH_SIZE as i32;
+        let end_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32;
+        let values_second_half: Vec<i32> = (start_val..end_val).collect();
+        let row_ids_second_half: Vec<u64> = (start_val as u64..end_val as u64).collect();
+        let fragment2_gen = gen_batch()
+            .col("value", array::cycle::<Int32Type>(values_second_half))
+            .col("_rowid", array::cycle::<UInt64Type>(row_ids_second_half))
+            .into_df_stream(RowCount::from(half_count), BatchCount::from(1));
+        let fragment2_data_source =
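+            // The merge step below looks up fragment partition files by ids of
+            // the form (fragment_id << 32), which appears to pack the fragment
+            // id into the upper 32 bits the way row addresses do; a sketch of
+            // the assumed mapping (illustration only):
+            //
+            //   let partition_id = (fragment_id as u64) << 32;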
Box::pin(RecordBatchStreamAdapter::new( + fragment2_gen.schema(), + fragment2_gen, + )); + + train_btree_index( + fragment2_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![2]), // fragment_id = 2 + ) + .await + .unwrap(); + + // Merge the fragment files + let part_page_files = vec![ + part_page_data_file_path(1 << 32), + part_page_data_file_path(2 << 32), + ]; + + let part_lookup_files = vec![ + part_lookup_file_path(1 << 32), + part_lookup_file_path(2 << 32), + ]; + + super::merge_metadata_files( + fragment_store.clone(), + &part_page_files, + &part_lookup_files, + Option::from(1usize), + ) + .await + .unwrap(); + + // Load both indexes + let full_index = BTreeIndex::load(full_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + let merged_index = BTreeIndex::load(fragment_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + // Test queries one by one to identify the exact problem + + // Test 1: Query for value 0 (should be in first page) + let query_0 = SargableQuery::Equals(ScalarValue::Int32(Some(0))); + let full_result_0 = full_index + .search(&query_0, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_0 = merged_index + .search(&query_0, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!(full_result_0, merged_result_0, "Query for value 0 failed"); + + // Test 2: Query for value in middle of first batch (should be in first page) + let mid_first_batch = (DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_mid_first = SargableQuery::Equals(ScalarValue::Int32(Some(mid_first_batch))); + let full_result_mid_first = full_index + .search(&query_mid_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_mid_first = merged_index + .search(&query_mid_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_mid_first, merged_result_mid_first, + "Query for value {} failed", + mid_first_batch + ); + + // Test 3: Query for first value in second batch (should be in second page) + let first_second_batch = DEFAULT_BTREE_BATCH_SIZE as i32; + let query_first_second = + SargableQuery::Equals(ScalarValue::Int32(Some(first_second_batch))); + let full_result_first_second = full_index + .search(&query_first_second, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_first_second = merged_index + .search(&query_first_second, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_first_second, merged_result_first_second, + "Query for value {} failed", + first_second_batch + ); + + // Test 4: Query for value in middle of second batch (should be in second page) + let mid_second_batch = (DEFAULT_BTREE_BATCH_SIZE + DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_mid_second = SargableQuery::Equals(ScalarValue::Int32(Some(mid_second_batch))); + + let full_result_mid_second = full_index + .search(&query_mid_second, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_mid_second = merged_index + .search(&query_mid_second, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_mid_second, merged_result_mid_second, + "Query for value {} failed", + mid_second_batch + ); } #[tokio::test] - async fn test_null_ids() { - let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( + async fn test_fragment_btree_index_boundary_queries() { + // Setup stores for both indexes + let full_tmpdir = Arc::new(tempdir().unwrap()); + let full_store = Arc::new(LanceIndexStore::new( 
Arc::new(ObjectStore::local()), - Path::from_filesystem_path(tmpdir.path()).unwrap(), + Path::from_filesystem_path(full_tmpdir.path()).unwrap(), Arc::new(LanceCache::no_cache()), )); - // Generate 50,000 rows of random data with 80% nulls - let stream = gen_batch() - .col( - "value", - array::rand::().with_nulls(&[true, false, false, false, false]), - ) + let fragment_tmpdir = Arc::new(tempdir().unwrap()); + let fragment_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(fragment_tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let sub_index_trainer = FlatIndexMetadata::new(DataType::Int32); + + // Use 3 * DEFAULT_BTREE_BATCH_SIZE for more comprehensive boundary testing + let total_count = (3 * DEFAULT_BTREE_BATCH_SIZE) as u64; + + // Method 1: Build complete index directly + let full_data_gen = gen_batch() + .col("value", array::step::()) .col("_rowid", array::step::()) - .into_df_stream(RowCount::from(5000), BatchCount::from(10)); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + .into_df_stream(RowCount::from(total_count / 3), BatchCount::from(3)); + let full_data_source = Box::pin(RecordBatchStreamAdapter::new( + full_data_gen.schema(), + full_data_gen, + )); + + train_btree_index( + full_data_source, + &sub_index_trainer, + full_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + None, + ) + .await + .unwrap(); + + // Method 2: Build fragment-based index using 3 fragments + // Fragment 1: 0 to DEFAULT_BTREE_BATCH_SIZE-1 + let fragment_size = DEFAULT_BTREE_BATCH_SIZE; + let fragment1_gen = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment1_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment1_gen.schema(), + fragment1_gen, + )); + + train_btree_index( + fragment1_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![1]), + ) + .await + .unwrap(); + + // Fragment 2: DEFAULT_BTREE_BATCH_SIZE to 2*DEFAULT_BTREE_BATCH_SIZE-1 + let start_val2 = DEFAULT_BTREE_BATCH_SIZE as i32; + let end_val2 = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let values_fragment2: Vec = (start_val2..end_val2).collect(); + let row_ids_fragment2: Vec = (start_val2 as u64..end_val2 as u64).collect(); + let fragment2_gen = gen_batch() + .col("value", array::cycle::(values_fragment2)) + .col("_rowid", array::cycle::(row_ids_fragment2)) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment2_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment2_gen.schema(), + fragment2_gen, + )); + + train_btree_index( + fragment2_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![2]), + ) + .await + .unwrap(); - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 5000) + // Fragment 3: 2*DEFAULT_BTREE_BATCH_SIZE to 3*DEFAULT_BTREE_BATCH_SIZE-1 + let start_val3 = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let end_val3 = (3 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let values_fragment3: Vec = (start_val3..end_val3).collect(); + let row_ids_fragment3: Vec = (start_val3 as u64..end_val3 as u64).collect(); + let fragment3_gen = gen_batch() + .col("value", array::cycle::(values_fragment3)) + .col("_rowid", array::cycle::(row_ids_fragment3)) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment3_data_source = Box::pin(RecordBatchStreamAdapter::new( + 
fragment3_gen.schema(), + fragment3_gen, + )); + + train_btree_index( + fragment3_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![3]), + ) + .await + .unwrap(); + + // Merge all fragment files + let part_page_files = vec![ + part_page_data_file_path(1 << 32), + part_page_data_file_path(2 << 32), + part_page_data_file_path(3 << 32), + ]; + + let part_lookup_files = vec![ + part_lookup_file_path(1 << 32), + part_lookup_file_path(2 << 32), + part_lookup_file_path(3 << 32), + ]; + + super::merge_metadata_files( + fragment_store.clone(), + &part_page_files, + &part_lookup_files, + Option::from(1usize), + ) + .await + .unwrap(); + + // Load both indexes + let full_index = BTreeIndex::load(full_store.clone(), None, LanceCache::no_cache()) .await .unwrap(); - let index = BTreeIndex::load(test_store.clone(), None, LanceCache::no_cache()) + let merged_index = BTreeIndex::load(fragment_store.clone(), None, LanceCache::no_cache()) .await .unwrap(); - assert_eq!(index.page_lookup.null_pages.len(), 10); + // === Boundary Value Tests === - let remap_dir = Arc::new(tempdir().unwrap()); - let remap_store = Arc::new(LanceIndexStore::new( - Arc::new(ObjectStore::local()), - Path::from_filesystem_path(remap_dir.path()).unwrap(), - Arc::new(LanceCache::no_cache()), - )); + // Test 1: Query minimum value (boundary: data start) + let query_min = SargableQuery::Equals(ScalarValue::Int32(Some(0))); + let full_result_min = full_index + .search(&query_min, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_min = merged_index + .search(&query_min, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_min, merged_result_min, + "Query for minimum value 0 failed" + ); - // Remap with a no-op mapping. 
The remapped index should be identical to the original - index - .remap(&HashMap::default(), remap_store.as_ref()) + // Test 2: Query maximum value (boundary: data end) + let max_val = (3 * DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val))); + let full_result_max = full_index + .search(&query_max, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_max = merged_index + .search(&query_max, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_max, merged_result_max, + "Query for maximum value {} failed", + max_val + ); - let remap_index = BTreeIndex::load(remap_store.clone(), None, LanceCache::no_cache()) + // Test 3: Query fragment boundary value (last value of first fragment) + let fragment1_last = (DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_frag1_last = SargableQuery::Equals(ScalarValue::Int32(Some(fragment1_last))); + let full_result_frag1_last = full_index + .search(&query_frag1_last, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_frag1_last = merged_index + .search(&query_frag1_last, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag1_last, merged_result_frag1_last, + "Query for fragment 1 last value {} failed", + fragment1_last + ); - assert_eq!(remap_index.page_lookup, index.page_lookup); + // Test 4: Query fragment boundary value (first value of second fragment) + let fragment2_first = DEFAULT_BTREE_BATCH_SIZE as i32; + let query_frag2_first = SargableQuery::Equals(ScalarValue::Int32(Some(fragment2_first))); + let full_result_frag2_first = full_index + .search(&query_frag2_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag2_first = merged_index + .search(&query_frag2_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag2_first, merged_result_frag2_first, + "Query for fragment 2 first value {} failed", + fragment2_first + ); - let original_pages = test_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); - let remapped_pages = remap_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); + // Test 5: Query fragment boundary value (last value of second fragment) + let fragment2_last = (2 * DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_frag2_last = SargableQuery::Equals(ScalarValue::Int32(Some(fragment2_last))); + let full_result_frag2_last = full_index + .search(&query_frag2_last, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag2_last = merged_index + .search(&query_frag2_last, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag2_last, merged_result_frag2_last, + "Query for fragment 2 last value {} failed", + fragment2_last + ); - assert_eq!(original_pages.num_rows(), remapped_pages.num_rows()); + // Test 6: Query fragment boundary value (first value of third fragment) + let fragment3_first = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_frag3_first = SargableQuery::Equals(ScalarValue::Int32(Some(fragment3_first))); + let full_result_frag3_first = full_index + .search(&query_frag3_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag3_first = merged_index + .search(&query_frag3_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag3_first, merged_result_frag3_first, + "Query for fragment 3 first value {} failed", + fragment3_first + ); - let original_data = original_pages - .read_record_batch(0, original_pages.num_rows() as u64) + // === Non-existent Value Tests === + + // Test 7: 
Query value below minimum + let query_below_min = SargableQuery::Equals(ScalarValue::Int32(Some(-1))); + let full_result_below = full_index + .search(&query_below_min, &NoOpMetricsCollector) .await .unwrap(); - let remapped_data = remapped_pages - .read_record_batch(0, remapped_pages.num_rows() as u64) + let merged_result_below = merged_index + .search(&query_below_min, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_below, merged_result_below, + "Query for value below minimum (-1) failed" + ); + + // Test 8: Query value above maximum + let query_above_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val + 1))); + let full_result_above = full_index + .search(&query_above_max, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_above = merged_index + .search(&query_above_max, &NoOpMetricsCollector) .await .unwrap(); + assert_eq!( + full_result_above, + merged_result_above, + "Query for value above maximum ({}) failed", + max_val + 1 + ); - assert_eq!(original_data, remapped_data); - } + // === Range Query Tests === - #[tokio::test] - async fn test_nan_ordering() { - let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( - Arc::new(ObjectStore::local()), - Path::from_filesystem_path(tmpdir.path()).unwrap(), - Arc::new(LanceCache::no_cache()), - )); + // Test 9: Cross-fragment range query (from first fragment to second fragment) + let range_start = (DEFAULT_BTREE_BATCH_SIZE - 100) as i32; + let range_end = (DEFAULT_BTREE_BATCH_SIZE + 100) as i32; + let query_cross_frag = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(range_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(range_end))), + ); + let full_result_cross = full_index + .search(&query_cross_frag, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_cross = merged_index + .search(&query_cross_frag, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_cross, merged_result_cross, + "Cross-fragment range query [{}, {}] failed", + range_start, range_end + ); - let values = vec![ - 0.0, - 1.0, - 2.0, - 3.0, - f64::NAN, - f64::NEG_INFINITY, - f64::INFINITY, - ]; + // Test 10: Range query within single fragment + let single_frag_start = 100i32; + let single_frag_end = 200i32; + let query_single_frag = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(single_frag_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(single_frag_end))), + ); + let full_result_single = full_index + .search(&query_single_frag, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_single = merged_index + .search(&query_single_frag, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_single, merged_result_single, + "Single fragment range query [{}, {}] failed", + single_frag_start, single_frag_end + ); - // This is a bit overkill but we've had bugs in the past where DF's sort - // didn't agree with Arrow's sort so we do an end-to-end test here - // and use DF to sort the data like we would in a real dataset. 
- let data = gen_batch() - .col("value", array::cycle::(values.clone())) - .col("_rowid", array::step::()) - .into_df_exec(RowCount::from(10), BatchCount::from(100)); - let schema = data.schema(); - let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); - let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); - let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); - let stream = break_stream(stream, 64); - let stream = stream.map_err(DataFusionError::from); - let stream = - Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; + // Test 11: Large range query spanning all fragments + let large_range_start = 100i32; + let large_range_end = (3 * DEFAULT_BTREE_BATCH_SIZE - 100) as i32; + let query_large_range = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(large_range_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(large_range_end))), + ); + let full_result_large = full_index + .search(&query_large_range, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_large = merged_index + .search(&query_large_range, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_large, merged_result_large, + "Large range query [{}, {}] failed", + large_range_start, large_range_end + ); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float64); + // === Range Boundary Query Tests === - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64) + // Test 12: Less than query (implemented using range query, from minimum to specified value) + let lt_val = (DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_lt = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(0))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(lt_val))), + ); + let full_result_lt = full_index + .search(&query_lt, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_lt = merged_index + .search(&query_lt, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_lt, merged_result_lt, + "Less than query (<{}) failed", + lt_val + ); - let index = BTreeIndex::load(test_store, None, LanceCache::no_cache()) + // Test 13: Greater than query (implemented using range query, from specified value to maximum) + let gt_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let max_range_val = (3 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_gt = SargableQuery::Range( + std::collections::Bound::Excluded(ScalarValue::Int32(Some(gt_val))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(max_range_val))), + ); + let full_result_gt = full_index + .search(&query_gt, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_gt = merged_index + .search(&query_gt, &NoOpMetricsCollector) .await .unwrap(); + assert_eq!( + full_result_gt, merged_result_gt, + "Greater than query (>{}) failed", + gt_val + ); - for (idx, value) in values.into_iter().enumerate() { - let query = SargableQuery::Equals(ScalarValue::Float64(Some(value))); - let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::Exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7))) - ); - } + // Test 14: Less than or equal query (implemented using range query, including boundary value) + let lte_val = (DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_lte = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(0))), + 
std::collections::Bound::Included(ScalarValue::Int32(Some(lte_val))), + ); + let full_result_lte = full_index + .search(&query_lte, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_lte = merged_index + .search(&query_lte, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_lte, merged_result_lte, + "Less than or equal query (<={}) failed", + lte_val + ); + + // Test 15: Greater than or equal query (implemented using range query, including boundary value) + let gte_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_gte = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(gte_val))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(max_range_val))), + ); + let full_result_gte = full_index + .search(&query_gte, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_gte = merged_index + .search(&query_gte, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_gte, merged_result_gte, + "Greater than or equal query (>={}) failed", + gte_val + ); + } + + #[test] + fn test_extract_partition_id() { + // Test valid partition file names + assert_eq!( + super::extract_partition_id("part_123_page_data.lance").unwrap(), + 123 + ); + assert_eq!( + super::extract_partition_id("part_456_page_lookup.lance").unwrap(), + 456 + ); + assert_eq!( + super::extract_partition_id("part_4294967296_page_data.lance").unwrap(), + 4294967296 + ); + + // Test invalid file names + assert!(super::extract_partition_id("invalid_filename.lance").is_err()); + assert!(super::extract_partition_id("part_abc_page_data.lance").is_err()); + assert!(super::extract_partition_id("part_123").is_err()); + assert!(super::extract_partition_id("part_").is_err()); } #[tokio::test] - async fn test_page_cache() { + async fn test_cleanup_partition_files() { + use crate::scalar::lance_format::LanceIndexStore; + use lance_core::cache::LanceCache; + use lance_io::object_store::ObjectStore; + use object_store::path::Path; + use std::sync::Arc; + use tempfile::tempdir; + + // Create a test store let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( + let test_store: Arc = Arc::new(LanceIndexStore::new( Arc::new(ObjectStore::local()), Path::from_filesystem_path(tmpdir.path()).unwrap(), Arc::new(LanceCache::no_cache()), )); - let data = gen_batch() - .col("value", array::step::()) - .col("_rowid", array::step::()) - .into_df_exec(RowCount::from(1000), BatchCount::from(10)); - let schema = data.schema(); - let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); - let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); - let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); - let stream = break_stream(stream, 64); - let stream = stream.map_err(DataFusionError::from); - let stream = - Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + // Test files with different patterns + let lookup_files = vec![ + "part_123_page_lookup.lance".to_string(), + "invalid_lookup_file.lance".to_string(), + "part_456_page_lookup.lance".to_string(), + ]; - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64) - .await - .unwrap(); + let page_files = vec![ + "part_123_page_data.lance".to_string(), + "invalid_page_file.lance".to_string(), + "part_456_page_data.lance".to_string(), + ]; - let index = BTreeIndex::load( - test_store, - None, - 
LanceCache::with_capacity(100 * 1024 * 1024),
-        )
-        .await
-        .unwrap();
+        // The cleanup function should handle both valid and invalid file patterns gracefully
+        // This test mainly verifies that the function doesn't panic and handles edge cases
+        super::cleanup_partition_files(&test_store, &lookup_files, &page_files).await;
-        let query = SargableQuery::Equals(ScalarValue::Float32(Some(0.0)));
-        let metrics = LocalMetricsCollector::default();
-        let query1 = index.search(&query, &metrics);
-        let query2 = index.search(&query, &metrics);
-        tokio::join!(query1, query2).0.unwrap();
-        assert_eq!(metrics.parts_loaded.load(Ordering::Relaxed), 1);
+        // Reaching this point without a panic means the cleanup function
+        // handled both the valid and the invalid file name patterns
     }
 }
diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs
index 4edb1cb6a0a..a4506020782 100644
--- a/rust/lance-index/src/scalar/inverted.rs
+++ b/rust/lance-index/src/scalar/inverted.rs
@@ -163,6 +163,7 @@ impl ScalarIndexPlugin for InvertedIndexPlugin {
         data: SendableRecordBatchStream,
         index_store: &dyn IndexStore,
         request: Box<dyn TrainingRequest>,
+        fragment_ids: Option<Vec<u32>>,
     ) -> Result<CreatedIndex> {
         let request = (request as Box<dyn std::any::Any>)
             .downcast::()
             .map_err(|_| Error::InvalidInput {
                 source: "must provide training request created by new_training_request".into(),
                 location: location!(),
             })?;
-        Self::train_inverted_index(data, index_store, request.parameters.clone(), None).await
+        Self::train_inverted_index(data, index_store, request.parameters.clone(), fragment_ids)
+            .await
     }
 
     /// Load an index from storage
diff --git a/rust/lance-index/src/scalar/json.rs b/rust/lance-index/src/scalar/json.rs
index 0b8a43efbe7..e36feaacfc7 100644
--- a/rust/lance-index/src/scalar/json.rs
+++ b/rust/lance-index/src/scalar/json.rs
@@ -768,6 +768,7 @@ impl ScalarIndexPlugin for JsonIndexPlugin {
         data: SendableRecordBatchStream,
         index_store: &dyn IndexStore,
         request: Box<dyn TrainingRequest>,
+        fragment_ids: Option<Vec<u32>>,
     ) -> Result<CreatedIndex> {
         let request = (request as Box<dyn std::any::Any>)
             .downcast::()
@@ -797,7 +798,7 @@
         )?;
 
         let target_index = target_plugin
-            .train_index(converted_stream, index_store, target_request)
+            .train_index(converted_stream, index_store, target_request, fragment_ids)
             .await?;
 
         let index_details = crate::pb::JsonIndexDetails {
diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs
index 542aa2bc97a..64e932c47c5 100644
--- a/rust/lance-index/src/scalar/label_list.rs
+++ b/rust/lance-index/src/scalar/label_list.rs
@@ -398,6 +398,7 @@ impl ScalarIndexPlugin for LabelListIndexPlugin {
         data: SendableRecordBatchStream,
         index_store: &dyn IndexStore,
         request: Box<dyn TrainingRequest>,
+        fragment_ids: Option<Vec<u32>>,
     ) -> Result<CreatedIndex> {
         let schema = data.schema();
         let field = schema
@@ -427,7 +428,7 @@
         let data = unnest_chunks(data)?;
         let bitmap_plugin = BitmapIndexPlugin;
         bitmap_plugin
-            .train_index(data, index_store, request)
+            .train_index(data, index_store, request, fragment_ids)
             .await?;
         Ok(CreatedIndex {
             index_details: prost_types::Any::from_msg(&pb::LabelListIndexDetails::default())
diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs
index d8a95de1eeb..4df502ead09 100644
--- a/rust/lance-index/src/scalar/lance_format.rs
+++ b/rust/lance-index/src/scalar/lance_format.rs
@@ -368,7 +368,7 @@ pub mod tests {
         )
         .unwrap();
         btree_plugin
-            .train_index(data, index_store.as_ref(), request)
+            .train_index(data,
index_store.as_ref(), request, None)
             .await
             .unwrap();
     }
@@ -866,6 +866,7 @@ pub mod tests {
             &sub_index_trainer,
             index_store.as_ref(),
             DEFAULT_BTREE_BATCH_SIZE,
+            None,
         )
         .await
         .unwrap();
@@ -911,7 +912,7 @@ pub mod tests {
             .new_training_request("{}", &Field::new(VALUE_COLUMN_NAME, DataType::Int32, false))
             .unwrap();
         BitmapIndexPlugin
-            .train_index(data, index_store.as_ref(), request)
+            .train_index(data, index_store.as_ref(), request, None)
             .await
             .unwrap();
     }
@@ -1399,7 +1400,7 @@ pub mod tests {
         )
         .unwrap();
         LabelListIndexPlugin
-            .train_index(data, index_store.as_ref(), request)
+            .train_index(data, index_store.as_ref(), request, None)
             .await
             .unwrap();
     }
diff --git a/rust/lance-index/src/scalar/ngram.rs b/rust/lance-index/src/scalar/ngram.rs
index ff559dd9292..586b0a4da9a 100644
--- a/rust/lance-index/src/scalar/ngram.rs
+++ b/rust/lance-index/src/scalar/ngram.rs
@@ -1285,6 +1285,7 @@ impl ScalarIndexPlugin for NGramIndexPlugin {
         data: SendableRecordBatchStream,
         index_store: &dyn IndexStore,
         _request: Box<dyn TrainingRequest>,
+        _fragment_ids: Option<Vec<u32>>,
     ) -> Result<CreatedIndex> {
         Self::train_ngram_index(data, index_store).await?;
         Ok(CreatedIndex {
diff --git a/rust/lance-index/src/scalar/registry.rs b/rust/lance-index/src/scalar/registry.rs
index 022da729f0c..3880aad4dbe 100644
--- a/rust/lance-index/src/scalar/registry.rs
+++ b/rust/lance-index/src/scalar/registry.rs
@@ -119,6 +119,7 @@ pub trait ScalarIndexPlugin: Send + Sync + std::fmt::Debug {
         data: SendableRecordBatchStream,
         index_store: &dyn IndexStore,
         request: Box<dyn TrainingRequest>,
+        fragment_ids: Option<Vec<u32>>,
     ) -> Result<CreatedIndex>;
 
     /// Returns true if the index returns an exact answer (e.g. not AtMost)
diff --git a/rust/lance-index/src/scalar/zonemap.rs b/rust/lance-index/src/scalar/zonemap.rs
index 748ab003863..c9097cbb8cc 100644
--- a/rust/lance-index/src/scalar/zonemap.rs
+++ b/rust/lance-index/src/scalar/zonemap.rs
@@ -961,6 +961,7 @@ impl ScalarIndexPlugin for ZoneMapIndexPlugin {
         data: SendableRecordBatchStream,
         index_store: &dyn IndexStore,
         request: Box<dyn TrainingRequest>,
+        _fragment_ids: Option<Vec<u32>>,
     ) -> Result<CreatedIndex> {
         let request = (request as Box<dyn std::any::Any>)
             .downcast::()
diff --git a/rust/lance/benches/scalar_index.rs b/rust/lance/benches/scalar_index.rs
index 0742ff7f878..58b94f56318 100644
--- a/rust/lance/benches/scalar_index.rs
+++ b/rust/lance/benches/scalar_index.rs
@@ -71,6 +71,7 @@ impl BenchmarkFixture {
             &sub_index_trainer,
             index_store.as_ref(),
             DEFAULT_BTREE_BATCH_SIZE,
+            None,
         )
         .await
         .unwrap();
diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs
index cdebc399547..ccae72a4865 100644
--- a/rust/lance/src/index/scalar.rs
+++ b/rust/lance/src/index/scalar.rs
@@ -284,12 +284,11 @@ pub(super) async fn build_scalar_index(
         training_request.criteria(),
         None,
         train,
-        fragment_ids,
+        fragment_ids.clone(),
     )
     .await?;
-
     plugin
-        .train_index(training_data, &index_store, training_request)
+        .train_index(training_data, &index_store, training_request, fragment_ids)
         .await
 }

From acbb5a698cd8c9c53182db83d67b7a55d9848729 Mon Sep 17 00:00:00 2001
From: xloya
Date: Fri, 5 Sep 2025 12:48:21 +0800
Subject: [PATCH 05/13] support btree distributely

---
 java/.gitignore | 3 +
 java/core/.classpath | 50 +
 java/core/.project | 23 +
 .../org.eclipse.core.resources.prefs | 5 +
 .../.settings/org.eclipse.jdt.apt.core.prefs | 2 +
 .../core/.settings/org.eclipse.jdt.core.prefs | 9 +
 .../core/.settings/org.eclipse.m2e.core.prefs | 4 +
 python/python/lance/dataset.py | 36 +-
 python/python/lance/lance/__init__.pyi | 4 +-
 python/python/tests/test_scalar_index.py | 416 +++-
 python/src/dataset.rs | 126 +-
rust/lance-index/src/scalar/bitmap.rs | 1 + rust/lance-index/src/scalar/btree.rs | 2185 ++++++++++++++--- rust/lance-index/src/scalar/inverted.rs | 4 +- rust/lance-index/src/scalar/json.rs | 3 +- rust/lance-index/src/scalar/label_list.rs | 3 +- rust/lance-index/src/scalar/lance_format.rs | 7 +- rust/lance-index/src/scalar/ngram.rs | 1 + rust/lance-index/src/scalar/registry.rs | 1 + rust/lance-index/src/scalar/zonemap.rs | 1 + rust/lance/benches/scalar_index.rs | 1 + rust/lance/src/index/scalar.rs | 5 +- 22 files changed, 2555 insertions(+), 335 deletions(-) create mode 100644 java/core/.classpath create mode 100644 java/core/.project create mode 100644 java/core/.settings/org.eclipse.core.resources.prefs create mode 100644 java/core/.settings/org.eclipse.jdt.apt.core.prefs create mode 100644 java/core/.settings/org.eclipse.jdt.core.prefs create mode 100644 java/core/.settings/org.eclipse.m2e.core.prefs diff --git a/java/.gitignore b/java/.gitignore index d9074bd2835..228f8fb5922 100644 --- a/java/.gitignore +++ b/java/.gitignore @@ -1,2 +1,5 @@ *.iml .java-version +.classpath +.project +.settings \ No newline at end of file diff --git a/java/core/.classpath b/java/core/.classpath new file mode 100644 index 00000000000..5c8072ecc61 --- /dev/null +++ b/java/core/.classpath @@ -0,0 +1,50 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/java/core/.project b/java/core/.project new file mode 100644 index 00000000000..4a9eedb6505 --- /dev/null +++ b/java/core/.project @@ -0,0 +1,23 @@ + + + lance-core + + + + + + org.eclipse.jdt.core.javabuilder + + + + + org.eclipse.m2e.core.maven2Builder + + + + + + org.eclipse.jdt.core.javanature + org.eclipse.m2e.core.maven2Nature + + diff --git a/java/core/.settings/org.eclipse.core.resources.prefs b/java/core/.settings/org.eclipse.core.resources.prefs new file mode 100644 index 00000000000..cdfe4f1b669 --- /dev/null +++ b/java/core/.settings/org.eclipse.core.resources.prefs @@ -0,0 +1,5 @@ +eclipse.preferences.version=1 +encoding//src/main/java=UTF-8 +encoding//src/test/java=UTF-8 +encoding//src/test/resources=UTF-8 +encoding/=UTF-8 diff --git a/java/core/.settings/org.eclipse.jdt.apt.core.prefs b/java/core/.settings/org.eclipse.jdt.apt.core.prefs new file mode 100644 index 00000000000..d4313d4b25e --- /dev/null +++ b/java/core/.settings/org.eclipse.jdt.apt.core.prefs @@ -0,0 +1,2 @@ +eclipse.preferences.version=1 +org.eclipse.jdt.apt.aptEnabled=false diff --git a/java/core/.settings/org.eclipse.jdt.core.prefs b/java/core/.settings/org.eclipse.jdt.core.prefs new file mode 100644 index 00000000000..1b6e1ef22f9 --- /dev/null +++ b/java/core/.settings/org.eclipse.jdt.core.prefs @@ -0,0 +1,9 @@ +eclipse.preferences.version=1 +org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8 +org.eclipse.jdt.core.compiler.compliance=1.8 +org.eclipse.jdt.core.compiler.problem.enablePreviewFeatures=disabled +org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning +org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=ignore +org.eclipse.jdt.core.compiler.processAnnotations=disabled +org.eclipse.jdt.core.compiler.release=disabled +org.eclipse.jdt.core.compiler.source=1.8 diff --git a/java/core/.settings/org.eclipse.m2e.core.prefs b/java/core/.settings/org.eclipse.m2e.core.prefs new file mode 100644 index 00000000000..f897a7f1cb2 --- /dev/null +++ b/java/core/.settings/org.eclipse.m2e.core.prefs @@ -0,0 +1,4 @@ +activeProfiles= +eclipse.preferences.version=1 +resolveWorkspaceProjects=true 
+version=1
diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index 376476e43af..54bd432201a 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -2731,8 +2731,40 @@ def prewarm_index(self, name: str):
         """
         return self._ds.prewarm_index(name)
 
-    def merge_index_metadata(self, index_uuid: str):
-        return self._ds.merge_index_metadata(index_uuid)
+    def merge_index_metadata(
+        self,
+        index_uuid: str,
+        index_type: Union[
+            Literal["BTREE"],
+            Literal["INVERTED"],
+        ],
+        prefetch_batch: Optional[int] = None,
+    ):
+        """
+        Merge the partial metadata of an index that has not been committed yet.
+
+        Parameters
+        ----------
+        index_uuid: str
+            The uuid of the index to merge.
+        index_type: Literal["BTREE", "INVERTED"]
+            The type of the index.
+        prefetch_batch: int, optional
+            The number of batches to prefetch from sub-page files while merging.
+            Defaults to 1.
+        """
+        index_type = index_type.upper()
+        if index_type not in [
+            "BTREE",
+            "INVERTED",
+        ]:
+            raise NotImplementedError(
+                (
+                    'Only "BTREE" or "INVERTED" are supported for '
+                    f"merge index metadata. Received {index_type}"
+                )
+            )
+        return self._ds.merge_index_metadata(index_uuid, index_type, prefetch_batch)
 
     def session(self) -> Session:
         """
diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi
index c2a72b7b1b5..0bae8e2f1aa 100644
--- a/python/python/lance/lance/__init__.pyi
+++ b/python/python/lance/lance/__init__.pyi
@@ -282,7 +282,9 @@ class _Dataset:
     ): ...
     def drop_index(self, name: str): ...
     def prewarm_index(self, name: str): ...
-    def merge_index_metadata(self, index_uuid: str): ...
+    def merge_index_metadata(
+        self, index_uuid: str, index_type: str, prefetch_batch: Optional[int] = None
+    ): ...
     def count_fragments(self) -> int: ...
     def num_small_files(self, max_rows_per_group: int) -> int: ...
     def get_fragments(self) -> List[_Fragment]: ...
diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py
index c2370a17a9e..5f92a1e4d11 100644
--- a/python/python/tests/test_scalar_index.py
+++ b/python/python/tests/test_scalar_index.py
@@ -1982,7 +1982,7 @@ def build_distributed_fts_index(
     )
 
     # Merge the inverted index metadata
-    dataset.merge_index_metadata(index_id)
+    dataset.merge_index_metadata(index_id, index_type="INVERTED")
 
     # Create Index object for commit
     field_id = dataset.schema.get_field_index(column)
@@ -2856,7 +2856,7 @@ def test_distribute_fts_index_build(tmp_path):
         print(f"Fragment {fragment_id} index created successfully")
 
     # Merge the inverted index metadata
-    ds.merge_index_metadata(index_id)
+    ds.merge_index_metadata(index_id, index_type="INVERTED")
 
     # Create an Index object using the new dataclass format
     from lance.dataset import Index
@@ -2983,3 +2983,415 @@ def test_backward_compatibility_no_fragment_ids(tmp_path):
     results = ds.scanner(full_text_query=search_word).to_table()
 
     assert results.num_rows > 0
+
+
+def test_distribute_btree_index_build(tmp_path):
+    """
+    Test distributed B-tree index build similar to test_distribute_fts_index_build.
+    This test creates B-tree indices on individual fragments and then
+    commits them as a single index.
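+    The flow: build per-fragment indices under a shared index uuid, merge
+    their metadata with merge_index_metadata, then commit a CreateIndex
+    operation.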
+ """ + # Generate test dataset with multiple fragments + ds = generate_multi_fragment_dataset( + tmp_path, num_fragments=4, rows_per_fragment=10000 + ) + + import uuid + + index_id = str(uuid.uuid4()) + print(f"Using index ID: {index_id}") + index_name = "btree_multiple_fragment_idx" + + fragments = ds.get_fragments() + fragment_ids = [fragment.fragment_id for fragment in fragments] + print(f"Fragment IDs: {fragment_ids}") + + for fragment in ds.get_fragments(): + fragment_id = fragment.fragment_id + print(f"Creating B-tree index for fragment {fragment_id}") + + # Create B-tree scalar index for each fragment + # Use the same index_name for all fragments (like in FTS test) + ds.create_scalar_index( + column="id", # Use integer column for B-tree + index_type="BTREE", + name=index_name, + replace=False, + fragment_uuid=index_id, + fragment_ids=[fragment_id], + ) + + # For fragment-level indexing, we expect the method to return successfully + # but not commit the index yet + print(f"Fragment {fragment_id} B-tree index created successfully") + + # Merge the B-tree index metadata + ds.merge_index_metadata(index_id, index_type="BTREE") + print(ds.uri) + + # Create an Index object using the new dataclass format + from lance.dataset import Index + + # Get the schema field for the indexed column + field_id = ds.schema.get_field_index("id") + + index = Index( + uuid=index_id, + name=index_name, + fields=[field_id], # Use field index instead of field object + dataset_version=ds.version, + fragment_ids=set(fragment_ids), + index_version=0, + ) + + # Create the index operation + create_index_op = lance.LanceOperation.CreateIndex( + new_indices=[index], + removed_indices=[], + ) + + # Commit the index + ds_committed = lance.LanceDataset.commit( + ds.uri, + create_index_op, + read_version=ds.version, + ) + + print("Successfully committed multiple fragment B-tree index") + + # Verify the index was created and is functional + indices = ds_committed.list_indices() + assert len(indices) > 0, "No indices found after commit" + + # Find our index + our_index = None + for idx in indices: + if idx["name"] == index_name: + our_index = idx + break + + assert our_index is not None, f"Index '{index_name}' not found in indices list" + assert our_index["type"] == "BTree", ( + f"Expected BTree index, got {our_index['type']}" + ) + + # Test that the index works for searching + # Test exact equality queries + test_id = 100 # Should be in first fragment + results = ds_committed.scanner( + filter=f"id = {test_id}", + columns=["id", "text"], + ).to_table() + + print(f"Search for id = {test_id} returned {results.num_rows} results") + assert results.num_rows > 0, f"No results found for id = {test_id}" + + # Test range queries across fragments + results_range = ds_committed.scanner( + filter="id >= 200 AND id < 800", + columns=["id", "text"], + ).to_table() + + print(f"Range query returned {results_range.num_rows} results") + assert results_range.num_rows > 0, "No results found for range query" + + # Compare with complete index results to ensure consistency + # Create a reference dataset with complete index + reference_ds = generate_multi_fragment_dataset( + tmp_path / "reference", num_fragments=4, rows_per_fragment=10000 + ) + + # Create complete B-tree index for comparison + reference_ds.create_scalar_index( + column="id", + index_type="BTREE", + name="reference_btree_idx", + ) + + # Compare exact query results + reference_results = reference_ds.scanner( + filter=f"id = {test_id}", + columns=["id", "text"], + ).to_table() + + 
assert results.num_rows == reference_results.num_rows, ( + f"Distributed index returned {results.num_rows} results, " + f"but complete index returned {reference_results.num_rows} results" + ) + + # Compare range query results + reference_range_results = reference_ds.scanner( + filter="id >= 200 AND id < 800", + columns=["id", "text"], + ).to_table() + + assert results_range.num_rows == reference_range_results.num_rows, ( + f"Distributed index range query returned {results_range.num_rows} results, " + f"but complete index returned {reference_range_results.num_rows} results" + ) + + +def test_btree_precise_query_comparison(tmp_path): + """ + Precise comparison test between fragment-level B-tree index and complete + B-tree index. + This test creates identical datasets and compares query results in detail. + """ + # Test configuration + num_fragments = 3 + rows_per_fragment = 10000 + total_rows = num_fragments * rows_per_fragment + + print( + f"Creating datasets with {num_fragments} fragments," + f" {rows_per_fragment} rows each" + ) + + # Create dataset for fragment-level indexing + fragment_ds = generate_multi_fragment_dataset( + tmp_path / "fragment", + num_fragments=num_fragments, + rows_per_fragment=rows_per_fragment, + ) + + # Create dataset for complete indexing (same data structure) + complete_ds = generate_multi_fragment_dataset( + tmp_path / "complete", + num_fragments=num_fragments, + rows_per_fragment=rows_per_fragment, + ) + + import uuid + + # Build fragment-level B-tree index + fragment_index_id = str(uuid.uuid4()) + fragment_index_name = "fragment_btree_precise_test" + + fragments = fragment_ds.get_fragments() + fragment_ids = [fragment.fragment_id for fragment in fragments] + print(f"Fragment IDs: {fragment_ids}") + + # Create fragment-level indices + for fragment in fragments: + fragment_id = fragment.fragment_id + print(f"Creating B-tree index for fragment {fragment_id}") + + fragment_ds.create_scalar_index( + column="id", + index_type="BTREE", + name=fragment_index_name, + replace=False, + fragment_uuid=fragment_index_id, + fragment_ids=[fragment_id], + ) + + # Merge fragment indices + fragment_ds.merge_index_metadata(fragment_index_id, index_type="BTREE") + + # Create Index object for fragment-based index + from lance.dataset import Index + + field_id = fragment_ds.schema.get_field_index("id") + + fragment_index = Index( + uuid=fragment_index_id, + name=fragment_index_name, + fields=[field_id], + dataset_version=fragment_ds.version, + fragment_ids=set(fragment_ids), + index_version=0, + ) + + # Commit fragment-based index + create_fragment_index_op = lance.LanceOperation.CreateIndex( + new_indices=[fragment_index], + removed_indices=[], + ) + + fragment_ds_committed = lance.LanceDataset.commit( + fragment_ds.uri, + create_fragment_index_op, + read_version=fragment_ds.version, + ) + + # Build complete B-tree index + complete_index_name = "complete_btree_precise_test" + complete_ds.create_scalar_index( + column="id", + index_type="BTREE", + name=complete_index_name, + ) + + print("Both indices created successfully") + + # Detailed query comparison tests + test_cases = [ + # Test 1: Boundary values at fragment edges + {"name": "First value", "filter": "id = 0"}, + {"name": "Fragment 0 last value", "filter": f"id = {rows_per_fragment - 1}"}, + {"name": "Fragment 1 first value", "filter": f"id = {rows_per_fragment}"}, + { + "name": "Fragment 1 last value", + "filter": f"id = {2 * rows_per_fragment - 1}", + }, + {"name": "Fragment 2 first value", "filter": f"id = {2 * 
rows_per_fragment}"}, + {"name": "Last value", "filter": f"id = {total_rows - 1}"}, + # Test 2: Values in the middle of fragments + {"name": "Fragment 0 middle", "filter": f"id = {rows_per_fragment // 2}"}, + { + "name": "Fragment 1 middle", + "filter": f"id = {rows_per_fragment + rows_per_fragment // 2}", + }, + { + "name": "Fragment 2 middle", + "filter": f"id = {2 * rows_per_fragment + rows_per_fragment // 2}", + }, + # Test 3: Range queries within single fragments + {"name": "Range within fragment 0", "filter": "id >= 10 AND id < 20"}, + { + "name": "Range within fragment 1", + "filter": f"id >= {rows_per_fragment + 10}" + f" AND id < {rows_per_fragment + 20}", + }, + { + "name": "Range within fragment 2", + "filter": f"id >= {2 * rows_per_fragment + 10}" + f" AND id < {2 * rows_per_fragment + 20}", + }, + # Test 4: Range queries spanning multiple fragments + { + "name": "Cross fragment 0-1", + "filter": f"id >= {rows_per_fragment - 5} AND id < {rows_per_fragment + 5}", + }, + { + "name": "Cross fragment 1-2", + "filter": f"id >= {2 * rows_per_fragment - 5}" + f" AND id < {2 * rows_per_fragment + 5}", + }, + { + "name": "Cross all fragments", + "filter": f"id >= {rows_per_fragment // 2} AND" + f" id < {2 * rows_per_fragment + rows_per_fragment // 2}", + }, + # Test 5: Edge cases + {"name": "Non-existent small value", "filter": "id = -1"}, + {"name": "Non-existent large value", "filter": f"id = {total_rows + 100}"}, + {"name": "Large range", "filter": f"id >= 0 AND id < {total_rows}"}, + # Test 6: Comparison operators + {"name": "Less than boundary", "filter": f"id < {rows_per_fragment}"}, + { + "name": "Greater than boundary", + "filter": f"id > {2 * rows_per_fragment - 1}", + }, + {"name": "Less than or equal", "filter": f"id <= {rows_per_fragment + 50}"}, + {"name": "Greater than or equal", "filter": f"id >= {rows_per_fragment + 50}"}, + ] + + print(f"\nRunning {len(test_cases)} detailed comparison tests:") + + for i, test_case in enumerate(test_cases, 1): + test_name = test_case["name"] + filter_expr = test_case["filter"] + + print(f" {i:2d}. Testing {test_name}: {filter_expr}") + + # Query fragment-based index + fragment_results = fragment_ds_committed.scanner( + filter=filter_expr, + columns=["id", "text"], + ).to_table() + + # Query complete index + complete_results = complete_ds.scanner( + filter=filter_expr, + columns=["id", "text"], + ).to_table() + + # Compare row counts + assert fragment_results.num_rows == complete_results.num_rows, ( + f"Test '{test_name}' failed: Fragment index " + f"returned {fragment_results.num_rows} rows, " + f"but complete index returned {complete_results.num_rows}" + f" rows for filter: {filter_expr}" + ) + + # Compare actual results if there are any + if fragment_results.num_rows > 0: + # Sort both results by id for comparison + fragment_ids = sorted(fragment_results.column("id").to_pylist()) + complete_ids = sorted(complete_results.column("id").to_pylist()) + + assert fragment_ids == complete_ids, ( + f"Test '{test_name}' failed: Fragment index" + f" returned different IDs than complete index. " + f"Fragment IDs:" + f" {fragment_ids[:10]}{'...' if len(fragment_ids) > 10 else ''}, " + f"Complete IDs:" + f" {complete_ids[:10]}{'...' if len(complete_ids) > 10 else ''}" + ) + + print(f" āœ“ Passed ({fragment_results.num_rows} rows)") + + print(f"\nāœ… All {len(test_cases)} precision tests passed!") + print( + "Fragment-level B-tree index produces identical results" + " to complete B-tree index." 
+    )
+
+
+def test_btree_fragment_ids_parameter_validation(tmp_path):
+    """
+    Test validation of fragment_ids parameter for B-tree indices.
+    """
+    ds = generate_multi_fragment_dataset(
+        tmp_path, num_fragments=2, rows_per_fragment=10000
+    )
+
+    # Test with valid fragment IDs
+    fragments = ds.get_fragments()
+    valid_fragment_id = fragments[0].fragment_id
+
+    # This should work without errors
+    ds.create_scalar_index(
+        column="id",
+        index_type="BTREE",
+        fragment_ids=[valid_fragment_id],
+    )
+
+    # Test with invalid fragment ID (should handle gracefully)
+    try:
+        ds.create_scalar_index(
+            column="id",
+            index_type="BTREE",
+            fragment_ids=[999999],  # Non-existent fragment ID
+        )
+    except Exception as e:
+        # It's acceptable for this to fail with an appropriate error
+        print(f"Expected error for invalid fragment ID: {e}")
+
+
+def test_btree_backward_compatibility_no_fragment_ids(tmp_path):
+    """
+    Test that B-tree indexing remains backward compatible
+    when fragment_ids is not provided.
+    """
+    ds = generate_multi_fragment_dataset(
+        tmp_path, num_fragments=2, rows_per_fragment=10000
+    )
+
+    # This should work exactly as before (full dataset indexing)
+    ds.create_scalar_index(
+        column="id",
+        index_type="BTREE",
+        name="full_dataset_btree_idx",
+    )
+
+    # Verify the index was created
+    indices = ds.list_indices()
+    assert len(indices) == 1
+    assert indices[0]["name"] == "full_dataset_btree_idx"
+    assert indices[0]["type"] == "BTree"
+
+    # Test that the index works
+    results = ds.scanner(filter="id = 50").to_table()
+    assert results.num_rows > 0
diff --git a/python/src/dataset.rs b/python/src/dataset.rs
index 83da7f26dc9..f3d0fd10e83 100644
--- a/python/src/dataset.rs
+++ b/python/src/dataset.rs
@@ -1670,47 +1670,111 @@ impl Dataset {
             .infer_error()
     }
 
-    #[pyo3(signature = (index_uuid))]
-    fn merge_index_metadata(&self, index_uuid: &str) -> PyResult<()> {
+    #[pyo3(signature = (index_uuid, index_type, prefetch_batch))]
+    fn merge_index_metadata(
+        &self,
+        index_uuid: &str,
+        index_type: &str,
+        prefetch_batch: Option<usize>,
+    ) -> PyResult<()> {
         RT.block_on(None, async {
+            let index_type = index_type.to_uppercase();
+            let idx_type = match index_type.as_str() {
+                "BTREE" => IndexType::BTree,
+                "INVERTED" => IndexType::Inverted,
+                _ => {
+                    return Err(Error::InvalidInput {
+                        source: format!("Index type {} is not supported.", index_type).into(),
+                        location: location!(),
+                    });
+                }
+            };
+
             let store = LanceIndexStore::from_dataset_for_new(self.ds.as_ref(), index_uuid)?;
             let index_dir = self.ds.indices_dir().child(index_uuid);
+            if idx_type == IndexType::Inverted {
+                // List all partition metadata files in the index directory
+                let mut part_metadata_files = Vec::new();
+                let mut list_stream = self.ds.object_store().list(Some(index_dir.clone()));
+
+                while let Some(item) = list_stream.next().await {
+                    match item {
+                        Ok(meta) => {
+                            let file_name = meta.location.filename().unwrap_or_default();
+                            // Filter files matching the pattern part_*_metadata.lance
+                            if file_name.starts_with("part_") && file_name.ends_with("_metadata.lance")
+                            {
+                                part_metadata_files.push(file_name.to_string());
+                            }
+                        }
+                        Err(_) => continue,
+                    }
+                }
+
+                if part_metadata_files.is_empty() {
+                    return Err(Error::InvalidInput {
+                        source: format!(
+                            "No partition metadata files found in index directory: {}",
+                            index_dir
+                        )
+                        .into(),
+                        location: location!(),
+                    });
+                }
 
-            // List all partition metadata files in the index directory
-            let mut part_metadata_files = Vec::new();
-            let mut list_stream = self.ds.object_store().list(Some(index_dir.clone()));
-
-            while let Some(item) = list_stream.next().await {
-                match item {
-                    Ok(meta) => {
-                        let file_name = meta.location.filename().unwrap_or_default();
-                        // Filter files matching the pattern part_*_metadata.lance
-                        if file_name.starts_with("part_") && file_name.ends_with("_metadata.lance")
-                        {
-                            part_metadata_files.push(file_name.to_string());
+                // Call merge_metadata_files function for inverted index
+                lance_index::scalar::inverted::builder::merge_metadata_files(
+                    Arc::new(store),
+                    &part_metadata_files,
+                )
+                .await
+            } else {
+                // List all partition page / lookup files in the index directory
+                let mut part_page_files = Vec::new();
+                let mut part_lookup_files = Vec::new();
+                let mut list_stream = self.ds.object_store().list(Some(index_dir.clone()));
+
+                while let Some(item) = list_stream.next().await {
+                    match item {
+                        Ok(meta) => {
+                            let file_name = meta.location.filename().unwrap_or_default();
+                            // Filter files matching the patterns part_*_page_data.lance
+                            // and part_*_page_lookup.lance
+                            if file_name.starts_with("part_") && file_name.ends_with("_page_data.lance")
+                            {
+                                part_page_files.push(file_name.to_string());
+                            }
+                            if file_name.starts_with("part_") && file_name.ends_with("_page_lookup.lance")
+                            {
+                                part_lookup_files.push(file_name.to_string());
+                            }
                         }
+                        Err(_) => continue,
                     }
-                    Err(_) => continue,
                 }
-            }
+                if part_page_files.is_empty() || part_lookup_files.is_empty() {
+                    return Err(Error::InvalidInput {
+                        source: format!(
+                            "No partition page/lookup files found in index directory: {} (page_files: {}, lookup_files: {})",
+                            index_dir, part_page_files.len(), part_lookup_files.len()
+                        )
+                        .into(),
+                        location: location!(),
+                    });
+                }
 
-            if part_metadata_files.is_empty() {
-                return Err(Error::InvalidInput {
-                    source: format!(
-                        "No partition metadata files found in index directory: {}",
-                        index_dir
-                    )
-                    .into(),
-                    location: location!(),
-                });
+                // Call merge_metadata_files function for btree index
+                lance_index::scalar::btree::merge_metadata_files(
+                    Arc::new(store),
+                    &part_page_files,
+                    &part_lookup_files,
+                    prefetch_batch,
+                )
+                .await
             }
-            // Call merge_metadata_files function for inverted index
-            lance_index::scalar::inverted::builder::merge_metadata_files(
-                Arc::new(store),
-                &part_metadata_files,
-            )
-            .await
+        })?
+        .map_err(|err| PyValueError::new_err(err.to_string()))
     }
 
diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs
index 1b5d0d530bd..09dde5297b4 100644
--- a/rust/lance-index/src/scalar/bitmap.rs
+++ b/rust/lance-index/src/scalar/bitmap.rs
@@ -528,6 +528,7 @@ impl ScalarIndexPlugin for BitmapIndexPlugin {
         data: SendableRecordBatchStream,
         index_store: &dyn IndexStore,
         _request: Box<dyn TrainingRequest>,
+        _fragment_ids: Option<Vec<u32>>,
     ) -> Result<CreatedIndex> {
         Self::train_bitmap_index(data, index_store).await?;
         Ok(CreatedIndex {
diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs
index b2760fe214d..9c3e6685703 100644
--- a/rust/lance-index/src/scalar/btree.rs
+++ b/rust/lance-index/src/scalar/btree.rs
@@ -4,10 +4,10 @@ use std::{
     any::Any,
     cmp::Ordering,
-    collections::{BTreeMap, BinaryHeap, HashMap},
+    collections::{BTreeMap, BinaryHeap, HashMap, VecDeque},
     fmt::{Debug, Display},
     ops::Bound,
-    sync::Arc,
+    sync::{Arc, LazyLock},
 };
 
 use super::{
@@ -38,7 +38,7 @@ use deepsize::DeepSizeOf;
 use futures::{
     future::BoxFuture,
     stream::{self},
-    FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
+    Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
 };
 use lance_core::{
     cache::{CacheKey, LanceCache},
@@ -54,16 +54,37 @@ use lance_datafusion::{
     chunker::chunk_concat_stream,
     exec::{execute_plan, LanceExecutionOptions, OneShotExec},
 };
-use log::debug;
+use log::{debug, warn};
 use roaring::RoaringBitmap;
 use serde::{Deserialize, Serialize, Serializer};
 use snafu::location;
+use tokio::runtime::{Builder, Runtime};
 use tracing::info;
 
 const BTREE_LOOKUP_NAME: &str = "page_lookup.lance";
 const BTREE_PAGES_NAME: &str = "page_data.lance";
 pub const DEFAULT_BTREE_BATCH_SIZE: u64 = 4096;
 const BATCH_SIZE_META_KEY: &str = "batch_size";
+
+/// Global thread pool for B-tree prefetch operations
+static BTREE_PREFETCH_RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
+    Builder::new_multi_thread()
+        .worker_threads(get_num_compute_intensive_cpus())
+        .max_blocking_threads(get_num_compute_intensive_cpus())
+        .thread_name("lance-btree-prefetch")
+        .enable_time()
+        .build()
+        .expect("Failed to create B-tree prefetch runtime")
+});
+
+/// Spawn a prefetch task on the B-tree thread pool
+fn spawn_btree_prefetch<F>(future: F) -> tokio::task::JoinHandle<F::Output>
+where
+    F: Future + Send + 'static,
+    F::Output: Send + 'static,
+{
+    BTREE_PREFETCH_RUNTIME.spawn(future)
+}
 const BTREE_INDEX_VERSION: u32 = 0;
 pub(crate) const BTREE_VALUES_COLUMN: &str = "values";
 pub(crate) const BTREE_IDS_COLUMN: &str = "ids";
@@ -1231,6 +1252,7 @@ impl ScalarIndex for BTreeIndex {
             self.sub_index.as_ref(),
             dest_store,
             DEFAULT_BTREE_BATCH_SIZE,
+            None,
         )
         .await?;
 
@@ -1366,10 +1388,33 @@ pub async fn train_btree_index(
     sub_index_trainer: &dyn BTreeSubIndex,
     index_store: &dyn IndexStore,
     batch_size: u64,
+    fragment_ids: Option<Vec<u32>>,
 ) -> Result<()> {
-    let mut sub_index_file = index_store
-        .new_index_file(BTREE_PAGES_NAME, sub_index_trainer.schema().clone())
-        .await?;
+    let fragment_mask = fragment_ids.as_ref().and_then(|frag_ids| {
+        if !frag_ids.is_empty() {
+            // Create a mask with the fragment_id in the high 32 bits for distributed indexing.
+            // The mask identifies the partition files that belong to a specific fragment.
+            // If multiple fragments are processed together, the first fragment_id << 32 is used.
+            Some((frag_ids[0] as u64) << 32)
+        } else {
+            None
+        }
+    });
+
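+    // Illustrative: for `fragment_ids = Some(vec![2])` the mask is
+    // 2u64 << 32 = 8_589_934_592, so the partition files created below are
+    // "part_8589934592_page_data.lance" and "part_8589934592_page_lookup.lance".
+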
+    let mut sub_index_file = if let Some(mask) = fragment_mask {
+        index_store
+            .new_index_file(
+                part_page_data_file_path(mask).as_str(),
+                sub_index_trainer.schema().clone(),
+            )
+            .await?
+    } else {
+        index_store
+            .new_index_file(BTREE_PAGES_NAME, sub_index_trainer.schema().clone())
+            .await?
+    };
+
     let mut encoded_batches = Vec::new();
     let mut batch_idx = 0;
@@ -1393,385 +1438,1945 @@ pub async fn train_btree_index(
     file_schema
         .metadata
         .insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string());
-    let mut btree_index_file = index_store
-        .new_index_file(BTREE_LOOKUP_NAME, Arc::new(file_schema))
-        .await?;
+    let mut btree_index_file = if let Some(mask) = fragment_mask {
+        index_store
+            .new_index_file(
+                part_lookup_file_path(mask).as_str(),
+                Arc::new(file_schema),
+            )
+            .await?
+    } else {
+        index_store
+            .new_index_file(BTREE_LOOKUP_NAME, Arc::new(file_schema))
+            .await?
+    };
     btree_index_file.write_record_batch(record_batch).await?;
     btree_index_file.finish().await?;
     Ok(())
 }
 
-/// A stream that reads the original training data back out of the index
-///
-/// This is used for updating the index
-struct IndexReaderStream {
-    reader: Arc<dyn IndexReader>,
-    batch_size: u64,
-    num_batches: u32,
-    batch_idx: u32,
+/// Extract the partition ID from a partition file name
+/// Expected format: "part_{partition_id}_{suffix}.lance"
+fn extract_partition_id(filename: &str) -> Result<u64> {
+    if !filename.starts_with("part_") {
+        return Err(Error::Internal {
+            message: format!("Invalid partition file name format: {}", filename),
+            location: location!(),
+        });
+    }
+
+    let parts: Vec<&str> = filename.split('_').collect();
+    if parts.len() < 3 {
+        return Err(Error::Internal {
+            message: format!("Invalid partition file name format: {}", filename),
+            location: location!(),
+        });
+    }
+
+    parts[1].parse::<u64>().map_err(|_| Error::Internal {
+        message: format!("Failed to parse partition ID from filename: {}", filename),
+        location: location!(),
+    })
 }
 
-impl IndexReaderStream {
-    async fn new(reader: Arc<dyn IndexReader>, batch_size: u64) -> Self {
-        let num_batches = reader.num_batches(batch_size).await;
-        Self {
-            reader,
-            batch_size,
-            num_batches,
-            batch_idx: 0,
+/// Merge multiple partition page / lookup files into a complete metadata file
+///
+/// In a distributed environment, each worker node writes partition page / lookup
+/// files for the partitions it processes; this function merges those files into
+/// the final metadata files.
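+///
+/// Illustrative call (assuming two workers indexed fragments 1 and 2 into the
+/// same store; the file names come from `part_page_data_file_path` /
+/// `part_lookup_file_path`):
+///
+/// ```ignore
+/// let page_files = vec![
+///     part_page_data_file_path(1 << 32), // "part_4294967296_page_data.lance"
+///     part_page_data_file_path(2 << 32), // "part_8589934592_page_data.lance"
+/// ];
+/// let lookup_files = vec![
+///     part_lookup_file_path(1 << 32),
+///     part_lookup_file_path(2 << 32),
+/// ];
+/// merge_metadata_files(store, &page_files, &lookup_files, Some(1)).await?;
+/// ```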
+pub async fn merge_metadata_files(
+    store: Arc<dyn IndexStore>,
+    part_page_files: &[String],
+    part_lookup_files: &[String],
+    prefetch_batch: Option<usize>,
+) -> Result<()> {
+    if part_lookup_files.is_empty() || part_page_files.is_empty() {
+        return Err(Error::Internal {
+            message: "No partition files provided for merging".to_string(),
+            location: location!(),
+        });
+    }
+
+    // Step 1: Create a lookup map for page files keyed by partition ID
+    let mut page_files_map = HashMap::new();
+    for page_file in part_page_files {
+        let partition_id = extract_partition_id(page_file)?;
+        page_files_map.insert(partition_id, page_file);
+    }
+
+    // Step 2: Validate that all lookup files have corresponding page files
+    for lookup_file in part_lookup_files {
+        let partition_id = extract_partition_id(lookup_file)?;
+        if !page_files_map.contains_key(&partition_id) {
+            return Err(Error::Internal {
+                message: format!(
+                    "No corresponding page file found for lookup file: {} (partition_id: {})",
+                    lookup_file, partition_id
+                ),
+                location: location!(),
+            });
         }
     }
+
+    // Step 3: Extract metadata from the lookup files
+    let first_lookup_reader = store.open_index_file(&part_lookup_files[0]).await?;
+    let batch_size = first_lookup_reader
+        .schema()
+        .metadata
+        .get(BATCH_SIZE_META_KEY)
+        .map(|bs| bs.parse().unwrap_or(DEFAULT_BTREE_BATCH_SIZE))
+        .unwrap_or(DEFAULT_BTREE_BATCH_SIZE);
+
+    // Get the value type from the lookup schema (min column)
+    let lookup_batch = first_lookup_reader.read_range(0..1, None).await?;
+    let value_type = lookup_batch.column(0).data_type().clone();
+
+    // Get the page schema first
+    let partition_id = extract_partition_id(part_lookup_files[0].as_str())?;
+    let page_file = page_files_map.get(&partition_id).unwrap();
+    let page_reader = store.open_index_file(page_file).await?;
+    let page_schema = page_reader.schema().clone();
+
+    let arrow_schema = Arc::new(Schema::from(&page_schema));
+    let mut page_file = store
+        .new_index_file(BTREE_PAGES_NAME, arrow_schema.clone())
+        .await?;
+
+    let mut prefetch_config = PrefetchConfig::default();
+    if let Some(batch_count) = prefetch_batch {
+        prefetch_config = prefetch_config.with_prefetch_batch(batch_count);
+    }
+
+    let lookup_entries = merge_page(
+        part_lookup_files,
+        &page_files_map,
+        &store,
+        batch_size,
+        &mut page_file,
+        arrow_schema.clone(),
+        prefetch_config,
+    )
+    .await?;
+
+    page_file.finish().await?;
+
+    // Step 4: Generate a new lookup file based on the reorganized pages
+    // Add batch_size to the schema metadata
+    let mut metadata = HashMap::new();
+    metadata.insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string());
+
+    let lookup_schema_with_metadata = Arc::new(Schema::new_with_metadata(
+        vec![
+            Field::new("min", value_type.clone(), true),
+            Field::new("max", value_type, true),
+            Field::new("null_count", DataType::UInt32, false),
+            Field::new("page_idx", DataType::UInt32, false),
+        ],
+        metadata,
+    ));
+
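+    // Illustrative contents of the merged lookup file (Int32 values, default
+    // batch_size of 4096, no nulls): one row per page, pages ordered by value.
+    //
+    //     min  | max  | null_count | page_idx
+    //     -----+------+------------+---------
+    //     0    | 4095 | 0          | 0
+    //     4096 | 8191 | 0          | 1
+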
+    let lookup_batch = RecordBatch::try_new(
+        lookup_schema_with_metadata.clone(),
+        vec![
+            ScalarValue::iter_to_array(lookup_entries.iter().map(|(min, _, _, _)| min.clone()))?,
+            ScalarValue::iter_to_array(lookup_entries.iter().map(|(_, max, _, _)| max.clone()))?,
+            Arc::new(UInt32Array::from_iter_values(
+                lookup_entries
+                    .iter()
+                    .map(|(_, _, null_count, _)| *null_count),
+            )),
+            Arc::new(UInt32Array::from_iter_values(
+                lookup_entries.iter().map(|(_, _, _, page_idx)| *page_idx),
+            )),
+        ],
+    )?;
+
+    let mut lookup_file = store
+        .new_index_file(BTREE_LOOKUP_NAME, lookup_schema_with_metadata)
+        .await?;
+    lookup_file.write_record_batch(lookup_batch).await?;
+    lookup_file.finish().await?;
+
+    // After successfully writing the merged files, delete all partition files.
+    // Deletion only happens once the merged files have been written, so debug
+    // information is not lost if the merge fails.
+    cleanup_partition_files(&store, part_lookup_files, part_page_files).await;
+
+    Ok(())
+}
 
-impl Stream for IndexReaderStream {
-    type Item = BoxFuture<'static, Result<RecordBatch>>;
+/// Clean up partition files after successful merge
+///
+/// This function safely deletes partition lookup and page files after a successful merge operation.
+/// File deletion failures are logged but do not affect the overall success of the merge operation.
+async fn cleanup_partition_files(
+    store: &Arc<dyn IndexStore>,
+    part_lookup_files: &[String],
+    part_page_files: &[String],
+) {
+    // Clean up partition lookup files
+    for file_name in part_lookup_files {
+        cleanup_single_file(
+            store,
+            file_name,
+            "part_",
+            "_page_lookup.lance",
+            "partition lookup",
+        )
+        .await;
+    }
 
-    fn poll_next(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
-        let this = self.get_mut();
-        if this.batch_idx >= this.num_batches {
-            return std::task::Poll::Ready(None);
-        }
-        let batch_num = this.batch_idx;
-        this.batch_idx += 1;
-        let reader_copy = this.reader.clone();
-        let batch_size = this.batch_size;
-        let read_task = async move {
-            reader_copy
-                .read_record_batch(batch_num as u64, batch_size)
-                .await
-        }
-        .boxed();
-        std::task::Poll::Ready(Some(read_task))
+    // Clean up partition page files
+    for file_name in part_page_files {
+        cleanup_single_file(
+            store,
+            file_name,
+            "part_",
+            "_page_data.lance",
+            "partition page",
+        )
+        .await;
     }
 }
 
-/// Parameters for a btree index
-#[derive(Debug, Serialize, Deserialize)]
-pub struct BTreeParameters {
-    /// The number of rows to include in each zone
-    pub zone_size: Option<u64>,
+/// Helper function to clean up a single partition file
+///
+/// Performs safety checks on the filename pattern before attempting deletion.
+async fn cleanup_single_file(
+    store: &Arc<dyn IndexStore>,
+    file_name: &str,
+    expected_prefix: &str,
+    expected_suffix: &str,
+    file_type: &str,
+) {
+    // Ensure we only delete files that match the expected pattern (safety check)
+    if file_name.starts_with(expected_prefix) && file_name.ends_with(expected_suffix) {
+        match store.delete_index_file(file_name).await {
+            Ok(()) => {
+                debug!("Successfully deleted {} file: {}", file_type, file_name);
+            }
+            Err(e) => {
+                // File deletion failures should not affect the overall success of the function.
+                // Log the error but continue processing other files.
+                warn!(
+                    "Failed to delete {} file '{}': {}. \
+                    This does not affect the merge operation, but may leave \
+                    partition files that should be cleaned up manually.",
+                    file_type, file_name, e
+                );
+            }
+        }
+    } else {
+        // If the filename doesn't match the expected format, log a warning but don't attempt deletion
+        warn!(
+            "Skipping deletion of file '{}' as it does not match the expected \
+            {} file pattern ({}*{})",
+            file_name, file_type, expected_prefix, expected_suffix
+        );
+    }
 }
 
-struct BTreeTrainingRequest {
-    parameters: BTreeParameters,
-    criteria: TrainingCriteria,
+/// Prefetch configuration for partition iterators
+#[derive(Debug, Clone)]
+pub struct PrefetchConfig {
+    /// Number of batches to prefetch ahead (0 means no prefetching)
+    pub prefetch_batches: usize,
 }
 
-impl BTreeTrainingRequest {
-    pub fn new(parameters: BTreeParameters) -> Self {
+impl Default for PrefetchConfig {
+    fn default() -> Self {
         Self {
-            parameters,
-            // BTree indexes need data sorted by the value column
-            criteria: TrainingCriteria::new(TrainingOrdering::Values).with_row_id(),
+            prefetch_batches: 1,
         }
     }
 }
 
-impl TrainingRequest for BTreeTrainingRequest {
-    fn as_any(&self) -> &dyn std::any::Any {
-        self
+impl PrefetchConfig {
+    /// Set the prefetch batch count
+    pub fn with_prefetch_batch(&self, batch_count: usize) -> Self {
+        Self {
+            prefetch_batches: batch_count,
+        }
     }
+}
 
-    fn criteria(&self) -> &TrainingCriteria {
-        &self.criteria
-    }
+/// Buffer entry for prefetch queue
+#[derive(Debug)]
+struct BufferEntry {
+    batch: RecordBatch,
+    start_row: usize,
+    end_row: usize,
 }
 
-#[derive(Debug, Default)]
-pub struct BTreeIndexPlugin;
+/// Running prefetch task information
+#[derive(Debug)]
+struct RunningPrefetchTask {
+    /// Task handle
+    handle: tokio::task::JoinHandle<()>,
+    /// Range being prefetched
+    range: std::ops::Range<usize>,
+}
 
-#[async_trait]
-impl ScalarIndexPlugin for BTreeIndexPlugin {
-    fn new_training_request(
-        &self,
-        params: &str,
-        field: &Field,
-    ) -> Result<Box<dyn TrainingRequest>> {
-        if field.data_type().is_nested() {
-            return Err(Error::InvalidInput {
-                source: "A btree index can only be created on a non-nested field.".into(),
-                location: location!(),
-            });
-        }
+/// Check if two ranges overlap
+fn ranges_overlap(range1: &std::ops::Range<usize>, range2: &std::ops::Range<usize>) -> bool {
+    range1.start < range2.end && range2.start < range1.end
+}
 
-        let params = serde_json::from_str::<BTreeParameters>(params)?;
-        Ok(Box::new(BTreeTrainingRequest::new(params)))
-    }
+/// Prefetch state for a partition using task-based prefetching
+struct PartitionPrefetchState {
+    /// Queue of prefetched data
+    buffer: Arc<tokio::sync::Mutex<VecDeque<BufferEntry>>>,
+    /// Reader for this partition
+    reader: Arc<dyn IndexReader>,
+    /// Total rows in this partition
+    total_rows: usize,
+    /// Queue of running prefetch tasks with their ranges
+    running_tasks: Arc<tokio::sync::Mutex<VecDeque<RunningPrefetchTask>>>,
+    /// Next position to schedule for prefetch
+    next_prefetch_position: Arc<tokio::sync::Mutex<usize>>,
+}
 
-    fn provides_exact_answer(&self) -> bool {
-        true
-    }
+/// Manager for coordinating task-based prefetch across multiple partitions
+pub struct PrefetchManager {
+    /// Prefetch state per partition
+    partition_states: HashMap<u64, PartitionPrefetchState>,
+    /// Prefetch configuration
+    config: PrefetchConfig,
+}
 
-    fn version(&self) -> u32 {
-        0
+impl PrefetchManager {
+    /// Create a new prefetch manager
+    pub fn new(config: PrefetchConfig) -> Self {
+        Self {
+            partition_states: HashMap::new(),
+            config,
+        }
     }
 
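+    // Illustrative usage (hypothetical caller): a manager that prefetches two
+    // batches ahead instead of the default single batch.
+    //
+    //     let config = PrefetchConfig::default().with_prefetch_batch(2);
+    //     let mut manager = PrefetchManager::new(config);
+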
-    fn new_query_parser(
-        &self,
-        index_name: String,
-        _index_details: &prost_types::Any,
-    ) -> Option<Box<dyn QueryParser>> {
-        Some(Box::new(SargableQueryParser::new(index_name, false)))
-    }
+    /// Initialize a partition for task-based prefetching
+    pub fn initialize_partition(&mut self, partition_id: u64, reader: Arc<dyn IndexReader>) {
+        let total_rows = reader.num_rows();
+        let buffer = Arc::new(tokio::sync::Mutex::new(VecDeque::new()));
+        let running_tasks = Arc::new(tokio::sync::Mutex::new(VecDeque::new()));
+        let next_prefetch_position = Arc::new(tokio::sync::Mutex::new(0));
 
-    async fn train_index(
-        &self,
-        data: SendableRecordBatchStream,
-        index_store: &dyn IndexStore,
-        request: Box<dyn TrainingRequest>,
-    ) -> Result<CreatedIndex> {
-        let request = request
-            .as_any()
-            .downcast_ref::<BTreeTrainingRequest>()
-            .unwrap();
-        let value_type = data
-            .schema()
-            .field_with_name(VALUE_COLUMN_NAME)?
-            .data_type()
-            .clone();
-        let flat_index_trainer = FlatIndexMetadata::new(value_type);
-        train_btree_index(
-            data,
-            &flat_index_trainer,
-            index_store,
-            request
-                .parameters
-                .zone_size
-                .unwrap_or(DEFAULT_BTREE_BATCH_SIZE),
-        )
-        .await?;
-        Ok(CreatedIndex {
-            index_details: prost_types::Any::from_msg(&pb::BTreeIndexDetails::default()).unwrap(),
-            index_version: BTREE_INDEX_VERSION,
-        })
-    }
+        let state = PartitionPrefetchState {
+            buffer,
+            reader,
+            total_rows,
+            running_tasks,
+            next_prefetch_position,
+        };
 
-    async fn load_index(
-        &self,
-        index_store: Arc<dyn IndexStore>,
-        _index_details: &prost_types::Any,
-        frag_reuse_index: Option<Arc<FragReuseIndex>>,
-        cache: LanceCache,
-    ) -> Result<Arc<dyn ScalarIndex>> {
-        Ok(BTreeIndex::load(index_store, frag_reuse_index, cache).await? as Arc<dyn ScalarIndex>)
+        self.partition_states.insert(partition_id, state);
+        debug!(
+            "Initialized partition {} for task-based prefetching",
+            partition_id
+        );
     }
-}
 
-#[cfg(test)]
-mod tests {
-    use std::sync::atomic::Ordering;
-    use std::{collections::HashMap, sync::Arc};
+    /// Submit a prefetch task for a partition to the thread pool
+    pub async fn submit_prefetch_task(&self, partition_id: u64, batch_size: usize) -> Result<()> {
+        if self.config.prefetch_batches == 0 {
+            return Ok(());
+        }
 
-    use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type};
-    use arrow_array::FixedSizeListArray;
-    use arrow_schema::DataType;
-    use datafusion::{
-        execution::{SendableRecordBatchStream, TaskContext},
-        physical_plan::{sorts::sort::SortExec, stream::RecordBatchStreamAdapter, ExecutionPlan},
-    };
-    use datafusion_common::{DataFusionError, ScalarValue};
-    use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr};
-    use deepsize::DeepSizeOf;
-    use futures::TryStreamExt;
-    use lance_core::{cache::LanceCache, utils::mask::RowIdTreeMap};
-    use lance_datafusion::{chunker::break_stream, datagen::DatafusionDatagenExt};
-    use lance_datagen::{array, gen_batch, ArrayGeneratorExt, BatchCount, RowCount};
-    use lance_io::object_store::ObjectStore;
-    use object_store::path::Path;
-    use tempfile::tempdir;
+        let Some(state) = self.partition_states.get(&partition_id) else {
+            return Ok(());
+        };
 
-    use crate::metrics::LocalMetricsCollector;
-    use crate::{
-        metrics::NoOpMetricsCollector,
-        scalar::{
-            btree::{BTreeIndex, BTREE_PAGES_NAME},
-            flat::FlatIndexMetadata,
-            lance_format::LanceIndexStore,
-            IndexStore, SargableQuery, ScalarIndex, SearchResult,
-        },
-    };
+        let reader = state.reader.clone();
+        let buffer = state.buffer.clone();
+        let running_tasks = state.running_tasks.clone();
+        let next_prefetch_position = state.next_prefetch_position.clone();
+        let total_rows = state.total_rows;
+        let effective_batch_size = self.config.prefetch_batches * batch_size;
 
-    use super::{train_btree_index, OrderableScalarValue};
+        const MAX_BUFFER_SIZE: usize = 4;
+        const MAX_RUNNING_TASKS: usize = 2;
 
-    #[test]
-    fn test_scalar_value_size() {
-        let size_of_i32 = OrderableScalarValue(ScalarValue::Int32(Some(0))).deep_size_of();
-        let size_of_many_i32 = OrderableScalarValue(ScalarValue::FixedSizeList(Arc::new(
-            FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
-                vec![Some(vec![Some(0); 128])],
-                128,
-            ),
-        )))
-        .deep_size_of();
+        // Clean up completed tasks and check limits
+        {
+            let mut tasks_guard = running_tasks.lock().await;
 
-        // deep_size_of should account for the rust type overhead
-        assert!(size_of_i32 > 4);
-        assert!(size_of_many_i32 > 128 * 4);
+            // Remove completed tasks from the front
+            while let Some(task) = tasks_guard.front() {
+                if task.handle.is_finished() {
+                    tasks_guard.pop_front();
+                } else {
+                    break;
+                }
+            }
+
+            // Check if we have too many running tasks
+            if tasks_guard.len() >= MAX_RUNNING_TASKS {
+                debug!(
+                    "Skipping prefetch for partition {} - too many running tasks ({})",
+                    partition_id,
+                    tasks_guard.len()
+                );
+                return Ok(());
+            }
+
+            // Check if any running task already covers to the end of file
+            for task in tasks_guard.iter() {
+                if task.range.end >= total_rows {
+                    debug!(
+                        "Skipping prefetch for partition {} - task already covers to EOF (range {}..{})",
+                        partition_id, task.range.start, task.range.end
+                    );
+                    return Ok(());
+                }
+            }
+        }
+
+        // Check if buffer is full
+        {
+            let buffer_guard = buffer.lock().await;
+            if buffer_guard.len() >= MAX_BUFFER_SIZE {
+                debug!(
+                    "Skipping prefetch for partition {} - buffer full",
+                    partition_id
+                );
+                return Ok(());
+            }
+        }
+
+        // Determine the next range to prefetch
+        let next_range = {
+            let mut pos_guard = next_prefetch_position.lock().await;
+            let start_pos = *pos_guard;
+
+            if start_pos >= total_rows {
+                debug!(
+                    "Skipping prefetch for partition {} - no more data to prefetch",
+                    partition_id
+                );
+                return Ok(());
+            }
+
+            let end_pos = std::cmp::min(start_pos + effective_batch_size, total_rows);
+            *pos_guard = end_pos; // Update next prefetch position
+            start_pos..end_pos
+        };
+
+        // Check if this range is already being prefetched
+        {
+            let tasks_guard = running_tasks.lock().await;
+
+            // Check for range overlap
+            for task in tasks_guard.iter() {
+                if ranges_overlap(&task.range, &next_range) {
+                    debug!(
+                        "Skipping prefetch for partition {} - range {}..{} overlaps with running task {}..{}",
+                        partition_id, next_range.start, next_range.end, task.range.start, task.range.end
+                    );
+                    return Ok(());
+                }
+            }
+        }
+
+        // All checks passed, create the actual prefetch task (only this part is async)
+        let range_clone = next_range.clone();
+        let running_tasks_for_cleanup = running_tasks.clone();
+
+        let prefetch_task = spawn_btree_prefetch(async move {
+            // Perform the actual read
+            match reader.read_range(range_clone.clone(), None).await {
+                Ok(batch) => {
+                    let entry = BufferEntry {
+                        batch,
+                        start_row: range_clone.start,
+                        end_row: range_clone.end,
+                    };
+
+                    // Add to buffer
+                    {
+                        let mut buffer_guard = buffer.lock().await;
+                        buffer_guard.push_back(entry);
+                    }
+
+                    debug!(
+                        "Prefetched {} rows ({}..{}) for partition {}",
+                        range_clone.end - range_clone.start,
+                        range_clone.start,
+                        range_clone.end,
+                        partition_id
+                    );
+                }
+                Err(err) => {
+                    warn!(
+                        "Prefetch task failed for partition {} range {}..{}: {}",
+                        partition_id, range_clone.start, range_clone.end, err
+                    );
+                }
+            }
+
+            // Remove this task from running tasks when completed
+            {
+                let mut tasks_guard = running_tasks_for_cleanup.lock().await;
+                tasks_guard.retain(|task| !task.handle.is_finished());
+            }
+        });
+
+        // Add the task to running tasks
+        {
+            let mut tasks_guard = running_tasks.lock().await;
+            tasks_guard.push_back(RunningPrefetchTask {
+                handle: prefetch_task,
+                range: next_range.clone(),
+            });
+        }
+
+        debug!(
+            "Submitted prefetch task for partition {} range {}..{}",
+            partition_id, next_range.start, next_range.end
+        );
+
+        Ok(())
+    }
+
+    /// Get data from buffer or fallback to direct read
+    pub async fn get_data_with_fallback(
+        &self,
+        partition_id: u64,
+        start_row: usize,
+        end_row: usize,
+    ) -> Result<RecordBatch> {
+        if let Some(state) = self.partition_states.get(&partition_id) {
+            // First try to get from buffer
+            {
+                let mut buffer_guard = state.buffer.lock().await;
+
+                // Remove outdated entries from the front
+                while let Some(entry) = buffer_guard.front() {
+                    if entry.end_row <= start_row {
+                        buffer_guard.pop_front();
+                    } else {
+                        break;
+                    }
+                }
+
+                // Check if we have suitable data in buffer
+                if let Some(entry) = buffer_guard.front() {
+                    if entry.start_row <= start_row && entry.end_row >= end_row {
+                        // Found matching data, extract it
+                        let entry = buffer_guard.pop_front().unwrap();
+                        drop(buffer_guard);
+
+                        let slice_start = start_row - entry.start_row;
+                        let slice_len = end_row - start_row;
+
+                        debug!(
+                            "Using buffered data for partition {} ({}..{})",
+                            partition_id, start_row, end_row
+                        );
+
+                        return Ok(entry.batch.slice(slice_start, slice_len));
+                    }
+                }
+            }
+
+            // Fallback to direct read
+            debug!(
+                "Direct read fallback for partition {} ({}..{})",
+                partition_id, start_row, end_row
+            );
+
+            state.reader.read_range(start_row..end_row, None).await
+        } else {
+            Err(Error::Internal {
+                message: format!("Partition {} not found in prefetch manager", partition_id),
+                location: location!(),
+            })
+        }
+    }
+}
+
+/// Simplified partition iterator with immediate loading since all partitions need to be accessed
+struct PartitionIterator {
+    reader: Arc<dyn IndexReader>,
+    current_batch: Option<RecordBatch>,
+    current_position: usize,
+    rows_read: usize,
+    partition_id: u64,
+    batch_size: u64,
+}
+
+impl PartitionIterator {
+    async fn new(
+        store: Arc<dyn IndexStore>,
+        page_file_name: String,
+        partition_id: u64,
+        batch_size: u64,
+    ) -> Result<Self> {
+        let reader = store.open_index_file(&page_file_name).await?;
+        Ok(Self {
+            reader,
+            current_batch: None,
+            current_position: 0,
+            rows_read: 0,
+            partition_id,
+            batch_size,
+        })
+    }
+
+    /// Get the next element, working with the prefetch manager
+    async fn next(
+        &mut self,
+        prefetch_manager: &PrefetchManager,
+    ) -> Result<Option<(ScalarValue, ScalarValue)>> {
+        // Load new batch if current one is exhausted
+        if self.needs_new_batch() {
+            if self.rows_read >= self.reader.num_rows() {
+                return Ok(None);
+            }
+            self.load_next_batch(prefetch_manager).await?;
+
+            // Submit next prefetch task
+            if let Err(err) = prefetch_manager
+                .submit_prefetch_task(self.partition_id, self.batch_size as usize)
+                .await
+            {
+                warn!(
+                    "Failed to submit prefetch task for partition {}: {}",
+                    self.partition_id, err
+                );
+            }
+        } else {
+            // Check if we've read half of the current batch, submit next prefetch task
+            let batch_half = self.current_batch.as_ref().unwrap().num_rows() / 2;
+            if self.current_position == batch_half && batch_half > 0 {
+                if let Err(err) = prefetch_manager
+                    .submit_prefetch_task(self.partition_id, self.batch_size as usize)
+                    .await
+                {
+                    warn!(
+                        "Failed to submit prefetch task for partition {}: {}",
+                        self.partition_id, err
+                    );
+                }
+            }
+        }
+
+        // Extract next value from current batch
+        if let Some(batch) = &self.current_batch {
+            let value = ScalarValue::try_from_array(batch.column(0), self.current_position)?;
+            let row_id = ScalarValue::try_from_array(batch.column(1), self.current_position)?;
+            self.current_position += 1;
+            self.rows_read += 1;
+            Ok(Some((value, row_id)))
+        } else {
+            Ok(None)
+        }
+    }
+
+    /// Check if we need to load a new batch
+    fn needs_new_batch(&self) -> bool {
+        self.current_batch.is_none()
+            || self.current_position >= self.current_batch.as_ref().unwrap().num_rows()
+    }
+
+    async fn load_next_batch(&mut self, prefetch_manager: &PrefetchManager) -> Result<()> {
+        let remaining_rows = self.reader.num_rows() - self.rows_read;
+        if remaining_rows == 0 {
+            self.current_batch = None;
+            return Ok(());
+        }
+
+        let rows_to_read = std::cmp::min(self.batch_size as usize, remaining_rows);
+        let end_row = self.rows_read + rows_to_read;
+
+        // Use the new fallback mechanism - try buffer first, then direct read
+        let batch = prefetch_manager
+            .get_data_with_fallback(self.partition_id, self.rows_read, end_row)
+            .await?;
+
+        self.current_batch = Some(batch);
+        self.current_position = 0;
+
+        Ok(())
+    }
+
+    fn get_reader(&self) -> Arc<dyn IndexReader> {
+        self.reader.clone()
+    }
+}
+
+/// Heap elements, used for priority queues in multi-way merging
+#[derive(Debug)]
+struct HeapElement {
+    value: ScalarValue,
+    row_id: ScalarValue,
+    partition_id: u64,
+}
+
+impl PartialEq for HeapElement {
+    fn eq(&self, other: &Self) -> bool {
+        self.value.eq(&other.value)
+    }
+}
+
+impl Eq for HeapElement {}
+
+impl PartialOrd for HeapElement {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        // Note: BinaryHeap is a maximum heap, we need a minimum heap,
+        // so reverse the comparison result
+        other.value.partial_cmp(&self.value)
+    }
+}
+
+impl Ord for HeapElement {
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.partial_cmp(other).unwrap_or(Ordering::Equal)
+    }
+}
+
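+// Illustrative: because `partial_cmp` above reverses the operands, `BinaryHeap`
+// (a max-heap) behaves as a min-heap, so `pop()` always yields the smallest
+// remaining value across all partitions. Schematically:
+//
+//     heap.push(/* HeapElement with value 3 */);
+//     heap.push(/* HeapElement with value 1 */);
+//     heap.pop(); // -> the element with value 1
+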
+async fn merge_page(
+    part_lookup_files: &[String],
+    page_files_map: &HashMap<u64, &String>,
+    store: &Arc<dyn IndexStore>,
+    batch_size: u64,
+    page_file: &mut Box<dyn IndexWriter>,
+    arrow_schema: Arc<Schema>,
+    prefetch_config: PrefetchConfig,
+) -> Result<Vec<(ScalarValue, ScalarValue, u32, u32)>> {
+    let mut lookup_entries = Vec::new();
+    let mut page_idx = 0u32;
+
+    debug!(
+        "Starting multi-way merge with {} partitions using prefetch manager",
+        part_lookup_files.len()
+    );
+
+    // Create prefetch manager
+    let mut prefetch_manager = PrefetchManager::new(prefetch_config.clone());
+
+    // Directly create iterators and read first element
+    let mut partition_map = HashMap::new();
+    let mut heap = BinaryHeap::new();
+
+    debug!("Initializing {} partitions", part_lookup_files.len());
+
+    // Initialize all partitions
+    for lookup_file in part_lookup_files {
+        let partition_id = extract_partition_id(lookup_file)?;
+        let page_file_name = page_files_map
+            .get(&partition_id)
+            .ok_or_else(|| Error::Internal {
+                message: format!("Page file not found for partition ID: {}", partition_id),
+                location: location!(),
+            })?
+            .to_string();
+
+        let mut iterator =
+            PartitionIterator::new(store.clone(), page_file_name, partition_id, batch_size).await?;
+
+        // Initialize partition in prefetch manager
+        let reader = iterator.get_reader();
+        prefetch_manager.initialize_partition(partition_id, reader);
+
+        // Submit initial prefetch task
+        if let Err(err) = prefetch_manager
+            .submit_prefetch_task(partition_id, batch_size as usize)
+            .await
+        {
+            warn!(
+                "Failed to submit prefetch task for partition {}: {}",
+                partition_id, err
+            );
+        }
+
+        let first_element = iterator.next(&prefetch_manager).await?;
+
+        if let Some((value, row_id)) = first_element {
+            // Put the first element into the heap
+            heap.push(HeapElement {
+                value,
+                row_id,
+                partition_id,
+            });
+        }
+
+        partition_map.insert(partition_id, iterator);
+    }
+
+    debug!(
+        "Initialized {} partitions, heap size: {}",
+        partition_map.len(),
+        heap.len()
+    );
+
+    let mut current_batch_rows = Vec::with_capacity(batch_size as usize);
+    let mut total_merged = 0usize;
+
+    // Multi-way merge main loop
+    while let Some(min_element) = heap.pop() {
+        // Add current minimum element to batch
+        current_batch_rows.push((min_element.value, min_element.row_id));
+        total_merged += 1;
+
+        // Read next element from corresponding partition
+        if let Some(iterator) = partition_map.get_mut(&min_element.partition_id) {
+            if let Some((next_value, next_row_id)) = iterator.next(&prefetch_manager).await? {
+                heap.push(HeapElement {
+                    value: next_value,
+                    row_id: next_row_id,
+                    partition_id: min_element.partition_id,
+                });
+            }
+        }
+
+        // Write when batch reaches specified size
+        if current_batch_rows.len() >= batch_size as usize {
+            write_batch_and_lookup_entry(
+                &mut current_batch_rows,
+                page_file,
+                &arrow_schema,
+                &mut lookup_entries,
+                &mut page_idx,
+            )
+            .await?;
+        }
+    }
+
+    // Write the remaining data
+    if !current_batch_rows.is_empty() {
+        write_batch_and_lookup_entry(
+            &mut current_batch_rows,
+            page_file,
+            &arrow_schema,
+            &mut lookup_entries,
+            &mut page_idx,
+        )
+        .await?;
+    }
+
+    debug!(
+        "Completed multi-way merge: merged {} rows into {} lookup entries",
+        total_merged,
+        lookup_entries.len()
+    );
+    Ok(lookup_entries)
+}
+
+/// Helper function to prepare a page batch and its lookup entry
+async fn prepare_batch_data(
+    batch_rows: Vec<(ScalarValue, ScalarValue)>,
+    arrow_schema: Arc<Schema>,
+    page_idx: u32,
+) -> Result<(RecordBatch, (ScalarValue, ScalarValue, u32, u32))> {
+    if batch_rows.is_empty() {
+        return Err(Error::Internal {
+            message: "Cannot prepare empty batch".to_string(),
+            location: location!(),
+        });
+    }
+
+    // Split the (value, row id) pairs into separate columns
+    let (values, row_ids): (Vec<_>, Vec<_>) = batch_rows.into_iter().unzip();
+
+    // Convert the scalar values into Arrow arrays
+    let values_array = ScalarValue::iter_to_array(values.into_iter())?;
+    let row_ids_array = ScalarValue::iter_to_array(row_ids.into_iter())?;
+
+    let batch = RecordBatch::try_new(arrow_schema, vec![values_array, row_ids_array])?;
+
+    // Calculate min/max/null_count for the lookup entry; the rows in a page
+    // are sorted by value, so the first and last rows hold the min and max
+    let min_val = ScalarValue::try_from_array(batch.column(0), 0)?;
+    let max_val = ScalarValue::try_from_array(batch.column(0), batch.num_rows() - 1)?;
+    let null_count = batch.column(0).null_count() as u32;
+
+    let lookup_entry = (min_val, max_val, null_count, page_idx);
+
+    Ok((batch, lookup_entry))
+}
+
+/// Helper function to write a batch and create lookup entry
+async fn write_batch_and_lookup_entry(
+    batch_rows: &mut Vec<(ScalarValue, ScalarValue)>,
+    page_file: &mut Box<dyn IndexWriter>,
+    arrow_schema: &Arc<Schema>,
+    lookup_entries: &mut Vec<(ScalarValue, ScalarValue, u32, u32)>,
+    page_idx: &mut u32,
+) -> Result<()> {
+    if batch_rows.is_empty() {
+        return Ok(());
+    }
+
+    // Take ownership of the batch data
+    let batch_data = std::mem::take(batch_rows);
+    let current_page_idx = *page_idx;
+
+    // Prepare batch data
+    let (batch, lookup_entry) =
+        prepare_batch_data(batch_data, arrow_schema.clone(), current_page_idx).await?;
+
+    lookup_entries.push(lookup_entry);
+    page_file.write_record_batch(batch).await?;
+    *page_idx += 1;
+
+    Ok(())
+}
+
+pub(crate) fn part_page_data_file_path(partition_id: u64) -> String {
+    format!("part_{}_{}", partition_id, BTREE_PAGES_NAME)
+}
+
+pub(crate) fn part_lookup_file_path(partition_id: u64) -> String {
+    format!("part_{}_{}", partition_id, BTREE_LOOKUP_NAME)
+}
+
+/// A stream that reads the original training data back out of the index
+///
+/// This is used for updating the index
+struct IndexReaderStream {
+    reader: Arc<dyn IndexReader>,
+    batch_size: u64,
+    num_batches: u32,
+    batch_idx: u32,
+}
+
+impl IndexReaderStream {
+    async fn new(reader: Arc<dyn IndexReader>, batch_size: u64) -> Self {
+        let num_batches = reader.num_batches(batch_size).await;
+        Self {
+            reader,
+            batch_size,
+            num_batches,
+            batch_idx: 0,
+        }
+    }
+}
+
+impl Stream for IndexReaderStream {
+    type Item = BoxFuture<'static, Result<RecordBatch>>;
+
+    fn poll_next(
+        self: std::pin::Pin<&mut Self>,
+        _cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        let this = self.get_mut();
+        if this.batch_idx >= this.num_batches {
+            return std::task::Poll::Ready(None);
+        }
+        let batch_num = this.batch_idx;
+        this.batch_idx += 1;
+        let reader_copy = this.reader.clone();
+        let batch_size = this.batch_size;
+        let read_task = async move {
+            reader_copy
+                .read_record_batch(batch_num as u64, batch_size)
+                .await
+        }
+        .boxed();
+        std::task::Poll::Ready(Some(read_task))
+    }
+}
+
+/// Parameters for a btree index
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BTreeParameters {
+    /// The number of rows to include in each zone
+    pub zone_size: Option<u64>,
+}
+
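+// Illustrative: the parameter string accepted by `new_training_request` below
+// is JSON deserialized into `BTreeParameters`, e.g. (hypothetical input):
+//
+//     let params: BTreeParameters = serde_json::from_str(r#"{"zone_size": 4096}"#)?;
+//     assert_eq!(params.zone_size, Some(4096));
+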
+struct BTreeTrainingRequest {
+    parameters: BTreeParameters,
+    criteria: TrainingCriteria,
+}
+
+impl BTreeTrainingRequest {
+    pub fn new(parameters: BTreeParameters) -> Self {
+        Self {
+            parameters,
+            // BTree indexes need data sorted by the value column
+            criteria: TrainingCriteria::new(TrainingOrdering::Values).with_row_id(),
+        }
+    }
+}
+
+impl TrainingRequest for BTreeTrainingRequest {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn criteria(&self) -> &TrainingCriteria {
+        &self.criteria
+    }
+}
+
+#[derive(Debug, Default)]
+pub struct BTreeIndexPlugin;
+
+#[async_trait]
+impl ScalarIndexPlugin for BTreeIndexPlugin {
+    fn new_training_request(
+        &self,
+        params: &str,
+        field: &Field,
+    ) -> Result<Box<dyn TrainingRequest>> {
+        if field.data_type().is_nested() {
+            return Err(Error::InvalidInput {
+                source: "A btree index can only be created on a non-nested field.".into(),
+                location: location!(),
+            });
+        }
+
+        let params = serde_json::from_str::<BTreeParameters>(params)?;
+        Ok(Box::new(BTreeTrainingRequest::new(params)))
+    }
+
+    fn provides_exact_answer(&self) -> bool {
+        true
+    }
+
+    fn version(&self) -> u32 {
+        0
+    }
+
+    fn new_query_parser(
+        &self,
+        index_name: String,
+        _index_details: &prost_types::Any,
+    ) -> Option<Box<dyn QueryParser>> {
+        Some(Box::new(SargableQueryParser::new(index_name, false)))
+    }
+
+    async fn train_index(
+        &self,
+        data: SendableRecordBatchStream,
+        index_store: &dyn IndexStore,
+        request: Box<dyn TrainingRequest>,
+        fragment_ids: Option<Vec<u32>>,
+    ) -> Result<CreatedIndex> {
+        let request = request
+            .as_any()
+            .downcast_ref::<BTreeTrainingRequest>()
+            .unwrap();
+        let value_type = data
+            .schema()
+            .field_with_name(VALUE_COLUMN_NAME)?
+            .data_type()
+            .clone();
+        let flat_index_trainer = FlatIndexMetadata::new(value_type);
+        train_btree_index(
+            data,
+            &flat_index_trainer,
+            index_store,
+            request
+                .parameters
+                .zone_size
+                .unwrap_or(DEFAULT_BTREE_BATCH_SIZE),
+            fragment_ids,
+        )
+        .await?;
+        Ok(CreatedIndex {
+            index_details: prost_types::Any::from_msg(&pb::BTreeIndexDetails::default()).unwrap(),
+            index_version: BTREE_INDEX_VERSION,
+        })
+    }
+
+    async fn load_index(
+        &self,
+        index_store: Arc<dyn IndexStore>,
+        _index_details: &prost_types::Any,
+        frag_reuse_index: Option<Arc<FragReuseIndex>>,
+        cache: LanceCache,
+    ) -> Result<Arc<dyn ScalarIndex>> {
+        Ok(BTreeIndex::load(index_store, frag_reuse_index, cache).await? as Arc<dyn ScalarIndex>)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::atomic::Ordering;
+    use std::{collections::HashMap, sync::Arc};
+
+    use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type};
+    use arrow_array::FixedSizeListArray;
+    use arrow_schema::DataType;
+    use datafusion::{
+        execution::{SendableRecordBatchStream, TaskContext},
+        physical_plan::{sorts::sort::SortExec, stream::RecordBatchStreamAdapter, ExecutionPlan},
+    };
+    use datafusion_common::{DataFusionError, ScalarValue};
+    use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr};
+    use deepsize::DeepSizeOf;
+    use futures::TryStreamExt;
+    use lance_core::{cache::LanceCache, utils::mask::RowIdTreeMap};
+    use lance_datafusion::{chunker::break_stream, datagen::DatafusionDatagenExt};
+    use lance_datagen::{array, gen_batch, ArrayGeneratorExt, BatchCount, RowCount};
+    use lance_io::object_store::ObjectStore;
+    use object_store::path::Path;
+    use tempfile::tempdir;
+
+    use crate::metrics::LocalMetricsCollector;
+    use crate::{
+        metrics::NoOpMetricsCollector,
+        scalar::{
+            btree::{BTreeIndex, BTREE_PAGES_NAME},
+            flat::FlatIndexMetadata,
+            lance_format::LanceIndexStore,
+            IndexStore, SargableQuery, ScalarIndex, SearchResult,
+        },
+    };
+
+    use super::{
+        part_lookup_file_path, part_page_data_file_path, train_btree_index, OrderableScalarValue,
+        DEFAULT_BTREE_BATCH_SIZE,
+    };
+
+    #[test]
+    fn test_scalar_value_size() {
+        let size_of_i32 = OrderableScalarValue(ScalarValue::Int32(Some(0))).deep_size_of();
+        let size_of_many_i32 = OrderableScalarValue(ScalarValue::FixedSizeList(Arc::new(
+            FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
+                vec![Some(vec![Some(0); 128])],
+                128,
+            ),
+        )))
+        .deep_size_of();
+
+        // deep_size_of should account for the rust type overhead
+        assert!(size_of_i32 > 4);
+        assert!(size_of_many_i32 > 128 * 4);
+    }
+
+    #[tokio::test]
+    async fn test_null_ids() {
+        let tmpdir = Arc::new(tempdir().unwrap());
+        let test_store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            Path::from_filesystem_path(tmpdir.path()).unwrap(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        // Generate 50,000 rows of random data with 80% nulls
+        let stream = gen_batch()
+            .col(
+                "value",
+                array::rand::<Float32Type>().with_nulls(&[true, false, false, false, false]),
+            )
+            .col("_rowid", array::step::<UInt64Type>())
+            .into_df_stream(RowCount::from(5000), BatchCount::from(10));
+        let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32);
+
+        train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 5000, None)
+            .await
+            .unwrap();
+
+        let index = BTreeIndex::load(test_store.clone(), None, LanceCache::no_cache())
+            .await
+            .unwrap();
+
+        assert_eq!(index.page_lookup.null_pages.len(), 10);
+
+        let remap_dir = Arc::new(tempdir().unwrap());
+        let remap_store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            Path::from_filesystem_path(remap_dir.path()).unwrap(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        // Remap with a no-op mapping. The remapped index should be identical to the original
+        index
+            .remap(&HashMap::default(), remap_store.as_ref())
+            .await
+            .unwrap();
+
+        let remap_index = BTreeIndex::load(remap_store.clone(), None, LanceCache::no_cache())
+            .await
+            .unwrap();
+
+        assert_eq!(remap_index.page_lookup, index.page_lookup);
+
+        let original_pages = test_store.open_index_file(BTREE_PAGES_NAME).await.unwrap();
+        let remapped_pages = remap_store.open_index_file(BTREE_PAGES_NAME).await.unwrap();
+
+        assert_eq!(original_pages.num_rows(), remapped_pages.num_rows());
+
+        let original_data = original_pages
+            .read_record_batch(0, original_pages.num_rows() as u64)
+            .await
+            .unwrap();
+        let remapped_data = remapped_pages
+            .read_record_batch(0, remapped_pages.num_rows() as u64)
+            .await
+            .unwrap();
+
+        assert_eq!(original_data, remapped_data);
+    }
+
+    #[tokio::test]
+    async fn test_nan_ordering() {
+        let tmpdir = Arc::new(tempdir().unwrap());
+        let test_store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            Path::from_filesystem_path(tmpdir.path()).unwrap(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        let values = vec![
+            0.0,
+            1.0,
+            2.0,
+            3.0,
+            f64::NAN,
+            f64::NEG_INFINITY,
+            f64::INFINITY,
+        ];
+
+        // This is a bit overkill but we've had bugs in the past where DF's sort
+        // didn't agree with Arrow's sort so we do an end-to-end test here
+        // and use DF to sort the data like we would in a real dataset.
+        let data = gen_batch()
+            .col("value", array::cycle::<Float64Type>(values.clone()))
+            .col("_rowid", array::step::<UInt64Type>())
+            .into_df_exec(RowCount::from(10), BatchCount::from(100));
+        let schema = data.schema();
+        let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap());
+        let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data));
+        let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap();
+        let stream = break_stream(stream, 64);
+        let stream = stream.map_err(DataFusionError::from);
+        let stream =
+            Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream;
+
+        let sub_index_trainer = FlatIndexMetadata::new(DataType::Float64);
+
+        train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64, None)
+            .await
+            .unwrap();
+
+        let index = BTreeIndex::load(test_store, None, LanceCache::no_cache())
+            .await
+            .unwrap();
+
+        for (idx, value) in values.into_iter().enumerate() {
+            let query = SargableQuery::Equals(ScalarValue::Float64(Some(value)));
+            let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
+            assert_eq!(
+                result,
+                SearchResult::Exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7)))
+            );
+        }
+    }
+
+    #[tokio::test]
+    async fn test_page_cache() {
+        let tmpdir = Arc::new(tempdir().unwrap());
+        let test_store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            Path::from_filesystem_path(tmpdir.path()).unwrap(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        let data = gen_batch()
+            .col("value", array::step::<Float32Type>())
+            .col("_rowid", array::step::<UInt64Type>())
+            .into_df_exec(RowCount::from(1000), BatchCount::from(10));
+        let schema = data.schema();
+        let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap());
+        let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data));
+        let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap();
+        let stream = break_stream(stream, 64);
+        let stream = stream.map_err(DataFusionError::from);
+        let stream =
+            Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream;
+        let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32);
+
+        train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64, None)
+            .await
+            .unwrap();
+
+        let index = BTreeIndex::load(
+            test_store,
+            None,
+            LanceCache::with_capacity(100 * 1024 * 1024),
+        )
+        .await
+        .unwrap();
+
+        let query = SargableQuery::Equals(ScalarValue::Float32(Some(0.0)));
+        let metrics = LocalMetricsCollector::default();
+        let query1 = index.search(&query, &metrics);
+        let query2 = index.search(&query, &metrics);
+        tokio::join!(query1, query2).0.unwrap();
+        assert_eq!(metrics.parts_loaded.load(Ordering::Relaxed), 1);
+    }
+
+    /// Test that fragment-based btree index construction produces exactly the same results as building a complete index
+    #[tokio::test]
+    async fn test_fragment_btree_index_consistency() {
+        // Setup stores for both indexes
+        let full_tmpdir = Arc::new(tempdir().unwrap());
+        let full_store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            Path::from_filesystem_path(full_tmpdir.path()).unwrap(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        let fragment_tmpdir = Arc::new(tempdir().unwrap());
+        let fragment_store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            Path::from_filesystem_path(fragment_tmpdir.path()).unwrap(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        let sub_index_trainer = FlatIndexMetadata::new(DataType::Int32);
+
+        // Method 1: Build complete index directly using the same data
+        // Create deterministic data for comparison - use 2 * DEFAULT_BTREE_BATCH_SIZE for testing
+        let total_count = (2 * DEFAULT_BTREE_BATCH_SIZE) as u64;
+        let full_data_gen = gen_batch()
+            .col("value", array::step::<Int32Type>())
+            .col("_rowid", array::step::<UInt64Type>())
+            .into_df_stream(RowCount::from(total_count / 2), BatchCount::from(2));
+        let full_data_source = Box::pin(RecordBatchStreamAdapter::new(
+            full_data_gen.schema(),
+            full_data_gen,
+        ));
+
+        train_btree_index(
+            full_data_source,
+            &sub_index_trainer,
+            full_store.as_ref(),
+            DEFAULT_BTREE_BATCH_SIZE,
+            None,
+        )
+        .await
+        .unwrap();
+
+        // Method 2: Build fragment-based index using the same data split into fragments
+        // Create fragment 1 index - first half of the data (0 to DEFAULT_BTREE_BATCH_SIZE-1)
+        let half_count = DEFAULT_BTREE_BATCH_SIZE;
+        let fragment1_gen = gen_batch()
+            .col("value", array::step::<Int32Type>())
+            .col("_rowid", array::step::<UInt64Type>())
+            .into_df_stream(RowCount::from(half_count), BatchCount::from(1));
+        let fragment1_data_source = Box::pin(RecordBatchStreamAdapter::new(
+            fragment1_gen.schema(),
+            fragment1_gen,
+        ));
+
+        train_btree_index(
+            fragment1_data_source,
+            &sub_index_trainer,
+            fragment_store.as_ref(),
+            DEFAULT_BTREE_BATCH_SIZE,
+            Some(vec![1]), // fragment_id = 1
+        )
+        .await
+        .unwrap();
+
+        // Create fragment 2 index - second half of the data (DEFAULT_BTREE_BATCH_SIZE to 2*DEFAULT_BTREE_BATCH_SIZE-1)
+        let start_val = DEFAULT_BTREE_BATCH_SIZE as i32;
+        let end_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32;
+        let values_second_half: Vec<i32> = (start_val..end_val).collect();
+        let row_ids_second_half: Vec<u64> = (start_val as u64..end_val as u64).collect();
+        let fragment2_gen = gen_batch()
+            .col("value", array::cycle::<Int32Type>(values_second_half))
+            .col("_rowid", array::cycle::<UInt64Type>(row_ids_second_half))
+            .into_df_stream(RowCount::from(half_count), BatchCount::from(1));
+        let fragment2_data_source = 
Box::pin(RecordBatchStreamAdapter::new( + fragment2_gen.schema(), + fragment2_gen, + )); + + train_btree_index( + fragment2_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![2]), // fragment_id = 2 + ) + .await + .unwrap(); + + // Merge the fragment files + let part_page_files = vec![ + part_page_data_file_path(1 << 32), + part_page_data_file_path(2 << 32), + ]; + + let part_lookup_files = vec![ + part_lookup_file_path(1 << 32), + part_lookup_file_path(2 << 32), + ]; + + super::merge_metadata_files( + fragment_store.clone(), + &part_page_files, + &part_lookup_files, + Option::from(1usize), + ) + .await + .unwrap(); + + // Load both indexes + let full_index = BTreeIndex::load(full_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + let merged_index = BTreeIndex::load(fragment_store.clone(), None, LanceCache::no_cache()) + .await + .unwrap(); + + // Test queries one by one to identify the exact problem + + // Test 1: Query for value 0 (should be in first page) + let query_0 = SargableQuery::Equals(ScalarValue::Int32(Some(0))); + let full_result_0 = full_index + .search(&query_0, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_0 = merged_index + .search(&query_0, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!(full_result_0, merged_result_0, "Query for value 0 failed"); + + // Test 2: Query for value in middle of first batch (should be in first page) + let mid_first_batch = (DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_mid_first = SargableQuery::Equals(ScalarValue::Int32(Some(mid_first_batch))); + let full_result_mid_first = full_index + .search(&query_mid_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_mid_first = merged_index + .search(&query_mid_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_mid_first, merged_result_mid_first, + "Query for value {} failed", + mid_first_batch + ); + + // Test 3: Query for first value in second batch (should be in second page) + let first_second_batch = DEFAULT_BTREE_BATCH_SIZE as i32; + let query_first_second = + SargableQuery::Equals(ScalarValue::Int32(Some(first_second_batch))); + let full_result_first_second = full_index + .search(&query_first_second, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_first_second = merged_index + .search(&query_first_second, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_first_second, merged_result_first_second, + "Query for value {} failed", + first_second_batch + ); + + // Test 4: Query for value in middle of second batch (should be in second page) + let mid_second_batch = (DEFAULT_BTREE_BATCH_SIZE + DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_mid_second = SargableQuery::Equals(ScalarValue::Int32(Some(mid_second_batch))); + + let full_result_mid_second = full_index + .search(&query_mid_second, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_mid_second = merged_index + .search(&query_mid_second, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_mid_second, merged_result_mid_second, + "Query for value {} failed", + mid_second_batch + ); } #[tokio::test] - async fn test_null_ids() { - let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( + async fn test_fragment_btree_index_boundary_queries() { + // Setup stores for both indexes + let full_tmpdir = Arc::new(tempdir().unwrap()); + let full_store = Arc::new(LanceIndexStore::new( 
Arc::new(ObjectStore::local()), - Path::from_filesystem_path(tmpdir.path()).unwrap(), + Path::from_filesystem_path(full_tmpdir.path()).unwrap(), Arc::new(LanceCache::no_cache()), )); - // Generate 50,000 rows of random data with 80% nulls - let stream = gen_batch() - .col( - "value", - array::rand::().with_nulls(&[true, false, false, false, false]), - ) + let fragment_tmpdir = Arc::new(tempdir().unwrap()); + let fragment_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(fragment_tmpdir.path()).unwrap(), + Arc::new(LanceCache::no_cache()), + )); + + let sub_index_trainer = FlatIndexMetadata::new(DataType::Int32); + + // Use 3 * DEFAULT_BTREE_BATCH_SIZE for more comprehensive boundary testing + let total_count = (3 * DEFAULT_BTREE_BATCH_SIZE) as u64; + + // Method 1: Build complete index directly + let full_data_gen = gen_batch() + .col("value", array::step::()) .col("_rowid", array::step::()) - .into_df_stream(RowCount::from(5000), BatchCount::from(10)); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + .into_df_stream(RowCount::from(total_count / 3), BatchCount::from(3)); + let full_data_source = Box::pin(RecordBatchStreamAdapter::new( + full_data_gen.schema(), + full_data_gen, + )); + + train_btree_index( + full_data_source, + &sub_index_trainer, + full_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + None, + ) + .await + .unwrap(); + + // Method 2: Build fragment-based index using 3 fragments + // Fragment 1: 0 to DEFAULT_BTREE_BATCH_SIZE-1 + let fragment_size = DEFAULT_BTREE_BATCH_SIZE; + let fragment1_gen = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment1_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment1_gen.schema(), + fragment1_gen, + )); + + train_btree_index( + fragment1_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![1]), + ) + .await + .unwrap(); + + // Fragment 2: DEFAULT_BTREE_BATCH_SIZE to 2*DEFAULT_BTREE_BATCH_SIZE-1 + let start_val2 = DEFAULT_BTREE_BATCH_SIZE as i32; + let end_val2 = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let values_fragment2: Vec = (start_val2..end_val2).collect(); + let row_ids_fragment2: Vec = (start_val2 as u64..end_val2 as u64).collect(); + let fragment2_gen = gen_batch() + .col("value", array::cycle::(values_fragment2)) + .col("_rowid", array::cycle::(row_ids_fragment2)) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment2_data_source = Box::pin(RecordBatchStreamAdapter::new( + fragment2_gen.schema(), + fragment2_gen, + )); + + train_btree_index( + fragment2_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![2]), + ) + .await + .unwrap(); - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 5000) + // Fragment 3: 2*DEFAULT_BTREE_BATCH_SIZE to 3*DEFAULT_BTREE_BATCH_SIZE-1 + let start_val3 = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let end_val3 = (3 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let values_fragment3: Vec = (start_val3..end_val3).collect(); + let row_ids_fragment3: Vec = (start_val3 as u64..end_val3 as u64).collect(); + let fragment3_gen = gen_batch() + .col("value", array::cycle::(values_fragment3)) + .col("_rowid", array::cycle::(row_ids_fragment3)) + .into_df_stream(RowCount::from(fragment_size), BatchCount::from(1)); + let fragment3_data_source = Box::pin(RecordBatchStreamAdapter::new( + 
fragment3_gen.schema(), + fragment3_gen, + )); + + train_btree_index( + fragment3_data_source, + &sub_index_trainer, + fragment_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE, + Some(vec![3]), + ) + .await + .unwrap(); + + // Merge all fragment files + let part_page_files = vec![ + part_page_data_file_path(1 << 32), + part_page_data_file_path(2 << 32), + part_page_data_file_path(3 << 32), + ]; + + let part_lookup_files = vec![ + part_lookup_file_path(1 << 32), + part_lookup_file_path(2 << 32), + part_lookup_file_path(3 << 32), + ]; + + super::merge_metadata_files( + fragment_store.clone(), + &part_page_files, + &part_lookup_files, + Option::from(1usize), + ) + .await + .unwrap(); + + // Load both indexes + let full_index = BTreeIndex::load(full_store.clone(), None, LanceCache::no_cache()) .await .unwrap(); - let index = BTreeIndex::load(test_store.clone(), None, LanceCache::no_cache()) + let merged_index = BTreeIndex::load(fragment_store.clone(), None, LanceCache::no_cache()) .await .unwrap(); - assert_eq!(index.page_lookup.null_pages.len(), 10); + // === Boundary Value Tests === - let remap_dir = Arc::new(tempdir().unwrap()); - let remap_store = Arc::new(LanceIndexStore::new( - Arc::new(ObjectStore::local()), - Path::from_filesystem_path(remap_dir.path()).unwrap(), - Arc::new(LanceCache::no_cache()), - )); + // Test 1: Query minimum value (boundary: data start) + let query_min = SargableQuery::Equals(ScalarValue::Int32(Some(0))); + let full_result_min = full_index + .search(&query_min, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_min = merged_index + .search(&query_min, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_min, merged_result_min, + "Query for minimum value 0 failed" + ); - // Remap with a no-op mapping. 
The remapped index should be identical to the original - index - .remap(&HashMap::default(), remap_store.as_ref()) + // Test 2: Query maximum value (boundary: data end) + let max_val = (3 * DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val))); + let full_result_max = full_index + .search(&query_max, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_max = merged_index + .search(&query_max, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_max, merged_result_max, + "Query for maximum value {} failed", + max_val + ); - let remap_index = BTreeIndex::load(remap_store.clone(), None, LanceCache::no_cache()) + // Test 3: Query fragment boundary value (last value of first fragment) + let fragment1_last = (DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_frag1_last = SargableQuery::Equals(ScalarValue::Int32(Some(fragment1_last))); + let full_result_frag1_last = full_index + .search(&query_frag1_last, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_frag1_last = merged_index + .search(&query_frag1_last, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag1_last, merged_result_frag1_last, + "Query for fragment 1 last value {} failed", + fragment1_last + ); - assert_eq!(remap_index.page_lookup, index.page_lookup); + // Test 4: Query fragment boundary value (first value of second fragment) + let fragment2_first = DEFAULT_BTREE_BATCH_SIZE as i32; + let query_frag2_first = SargableQuery::Equals(ScalarValue::Int32(Some(fragment2_first))); + let full_result_frag2_first = full_index + .search(&query_frag2_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag2_first = merged_index + .search(&query_frag2_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag2_first, merged_result_frag2_first, + "Query for fragment 2 first value {} failed", + fragment2_first + ); - let original_pages = test_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); - let remapped_pages = remap_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); + // Test 5: Query fragment boundary value (last value of second fragment) + let fragment2_last = (2 * DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_frag2_last = SargableQuery::Equals(ScalarValue::Int32(Some(fragment2_last))); + let full_result_frag2_last = full_index + .search(&query_frag2_last, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag2_last = merged_index + .search(&query_frag2_last, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag2_last, merged_result_frag2_last, + "Query for fragment 2 last value {} failed", + fragment2_last + ); - assert_eq!(original_pages.num_rows(), remapped_pages.num_rows()); + // Test 6: Query fragment boundary value (first value of third fragment) + let fragment3_first = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_frag3_first = SargableQuery::Equals(ScalarValue::Int32(Some(fragment3_first))); + let full_result_frag3_first = full_index + .search(&query_frag3_first, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_frag3_first = merged_index + .search(&query_frag3_first, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_frag3_first, merged_result_frag3_first, + "Query for fragment 3 first value {} failed", + fragment3_first + ); - let original_data = original_pages - .read_record_batch(0, original_pages.num_rows() as u64) + // === Non-existent Value Tests === + + // Test 7: 
Query value below minimum + let query_below_min = SargableQuery::Equals(ScalarValue::Int32(Some(-1))); + let full_result_below = full_index + .search(&query_below_min, &NoOpMetricsCollector) .await .unwrap(); - let remapped_data = remapped_pages - .read_record_batch(0, remapped_pages.num_rows() as u64) + let merged_result_below = merged_index + .search(&query_below_min, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_below, merged_result_below, + "Query for value below minimum (-1) failed" + ); + + // Test 8: Query value above maximum + let query_above_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val + 1))); + let full_result_above = full_index + .search(&query_above_max, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_above = merged_index + .search(&query_above_max, &NoOpMetricsCollector) .await .unwrap(); + assert_eq!( + full_result_above, + merged_result_above, + "Query for value above maximum ({}) failed", + max_val + 1 + ); - assert_eq!(original_data, remapped_data); - } + // === Range Query Tests === - #[tokio::test] - async fn test_nan_ordering() { - let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( - Arc::new(ObjectStore::local()), - Path::from_filesystem_path(tmpdir.path()).unwrap(), - Arc::new(LanceCache::no_cache()), - )); + // Test 9: Cross-fragment range query (from first fragment to second fragment) + let range_start = (DEFAULT_BTREE_BATCH_SIZE - 100) as i32; + let range_end = (DEFAULT_BTREE_BATCH_SIZE + 100) as i32; + let query_cross_frag = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(range_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(range_end))), + ); + let full_result_cross = full_index + .search(&query_cross_frag, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_cross = merged_index + .search(&query_cross_frag, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_cross, merged_result_cross, + "Cross-fragment range query [{}, {}] failed", + range_start, range_end + ); - let values = vec![ - 0.0, - 1.0, - 2.0, - 3.0, - f64::NAN, - f64::NEG_INFINITY, - f64::INFINITY, - ]; + // Test 10: Range query within single fragment + let single_frag_start = 100i32; + let single_frag_end = 200i32; + let query_single_frag = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(single_frag_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(single_frag_end))), + ); + let full_result_single = full_index + .search(&query_single_frag, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_single = merged_index + .search(&query_single_frag, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_single, merged_result_single, + "Single fragment range query [{}, {}] failed", + single_frag_start, single_frag_end + ); - // This is a bit overkill but we've had bugs in the past where DF's sort - // didn't agree with Arrow's sort so we do an end-to-end test here - // and use DF to sort the data like we would in a real dataset. 
- let data = gen_batch() - .col("value", array::cycle::(values.clone())) - .col("_rowid", array::step::()) - .into_df_exec(RowCount::from(10), BatchCount::from(100)); - let schema = data.schema(); - let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); - let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); - let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); - let stream = break_stream(stream, 64); - let stream = stream.map_err(DataFusionError::from); - let stream = - Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; + // Test 11: Large range query spanning all fragments + let large_range_start = 100i32; + let large_range_end = (3 * DEFAULT_BTREE_BATCH_SIZE - 100) as i32; + let query_large_range = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(large_range_start))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(large_range_end))), + ); + let full_result_large = full_index + .search(&query_large_range, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_large = merged_index + .search(&query_large_range, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_large, merged_result_large, + "Large range query [{}, {}] failed", + large_range_start, large_range_end + ); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float64); + // === Range Boundary Query Tests === - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64) + // Test 12: Less than query (implemented using range query, from minimum to specified value) + let lt_val = (DEFAULT_BTREE_BATCH_SIZE / 2) as i32; + let query_lt = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(0))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(lt_val))), + ); + let full_result_lt = full_index + .search(&query_lt, &NoOpMetricsCollector) .await .unwrap(); + let merged_result_lt = merged_index + .search(&query_lt, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_lt, merged_result_lt, + "Less than query (<{}) failed", + lt_val + ); - let index = BTreeIndex::load(test_store, None, LanceCache::no_cache()) + // Test 13: Greater than query (implemented using range query, from specified value to maximum) + let gt_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let max_range_val = (3 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_gt = SargableQuery::Range( + std::collections::Bound::Excluded(ScalarValue::Int32(Some(gt_val))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(max_range_val))), + ); + let full_result_gt = full_index + .search(&query_gt, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_gt = merged_index + .search(&query_gt, &NoOpMetricsCollector) .await .unwrap(); + assert_eq!( + full_result_gt, merged_result_gt, + "Greater than query (>{}) failed", + gt_val + ); - for (idx, value) in values.into_iter().enumerate() { - let query = SargableQuery::Equals(ScalarValue::Float64(Some(value))); - let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::Exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7))) - ); - } + // Test 14: Less than or equal query (implemented using range query, including boundary value) + let lte_val = (DEFAULT_BTREE_BATCH_SIZE - 1) as i32; + let query_lte = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(0))), + 
std::collections::Bound::Included(ScalarValue::Int32(Some(lte_val))), + ); + let full_result_lte = full_index + .search(&query_lte, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_lte = merged_index + .search(&query_lte, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_lte, merged_result_lte, + "Less than or equal query (<={}) failed", + lte_val + ); + + // Test 15: Greater than or equal query (implemented using range query, including boundary value) + let gte_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32; + let query_gte = SargableQuery::Range( + std::collections::Bound::Included(ScalarValue::Int32(Some(gte_val))), + std::collections::Bound::Excluded(ScalarValue::Int32(Some(max_range_val))), + ); + let full_result_gte = full_index + .search(&query_gte, &NoOpMetricsCollector) + .await + .unwrap(); + let merged_result_gte = merged_index + .search(&query_gte, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!( + full_result_gte, merged_result_gte, + "Greater than or equal query (>={}) failed", + gte_val + ); + } + + #[test] + fn test_extract_partition_id() { + // Test valid partition file names + assert_eq!( + super::extract_partition_id("part_123_page_data.lance").unwrap(), + 123 + ); + assert_eq!( + super::extract_partition_id("part_456_page_lookup.lance").unwrap(), + 456 + ); + assert_eq!( + super::extract_partition_id("part_4294967296_page_data.lance").unwrap(), + 4294967296 + ); + + // Test invalid file names + assert!(super::extract_partition_id("invalid_filename.lance").is_err()); + assert!(super::extract_partition_id("part_abc_page_data.lance").is_err()); + assert!(super::extract_partition_id("part_123").is_err()); + assert!(super::extract_partition_id("part_").is_err()); } #[tokio::test] - async fn test_page_cache() { + async fn test_cleanup_partition_files() { + use crate::scalar::lance_format::LanceIndexStore; + use lance_core::cache::LanceCache; + use lance_io::object_store::ObjectStore; + use object_store::path::Path; + use std::sync::Arc; + use tempfile::tempdir; + + // Create a test store let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( + let test_store: Arc = Arc::new(LanceIndexStore::new( Arc::new(ObjectStore::local()), Path::from_filesystem_path(tmpdir.path()).unwrap(), Arc::new(LanceCache::no_cache()), )); - let data = gen_batch() - .col("value", array::step::()) - .col("_rowid", array::step::()) - .into_df_exec(RowCount::from(1000), BatchCount::from(10)); - let schema = data.schema(); - let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); - let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); - let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); - let stream = break_stream(stream, 64); - let stream = stream.map_err(DataFusionError::from); - let stream = - Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + // Test files with different patterns + let lookup_files = vec![ + "part_123_page_lookup.lance".to_string(), + "invalid_lookup_file.lance".to_string(), + "part_456_page_lookup.lance".to_string(), + ]; - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64) - .await - .unwrap(); + let page_files = vec![ + "part_123_page_data.lance".to_string(), + "invalid_page_file.lance".to_string(), + "part_456_page_data.lance".to_string(), + ]; - let index = BTreeIndex::load( - test_store, - None, - 
LanceCache::with_capacity(100 * 1024 * 1024), - ) - .await - .unwrap(); + // The cleanup function should handle both valid and invalid file patterns gracefully + // This test mainly verifies that the function doesn't panic and handles edge cases + super::cleanup_partition_files(&test_store, &lookup_files, &page_files).await; - let query = SargableQuery::Equals(ScalarValue::Float32(Some(0.0))); - let metrics = LocalMetricsCollector::default(); - let query1 = index.search(&query, &metrics); - let query2 = index.search(&query, &metrics); - tokio::join!(query1, query2).0.unwrap(); - assert_eq!(metrics.parts_loaded.load(Ordering::Relaxed), 1); + // If we get here without panicking, the cleanup function handled all cases correctly + assert!(true); } } diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs index 4edb1cb6a0a..a4506020782 100644 --- a/rust/lance-index/src/scalar/inverted.rs +++ b/rust/lance-index/src/scalar/inverted.rs @@ -163,6 +163,7 @@ impl ScalarIndexPlugin for InvertedIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result { let request = (request as Box) .downcast::() @@ -170,7 +171,8 @@ impl ScalarIndexPlugin for InvertedIndexPlugin { source: "must provide training request created by new_training_request".into(), location: location!(), })?; - Self::train_inverted_index(data, index_store, request.parameters.clone(), None).await + Self::train_inverted_index(data, index_store, request.parameters.clone(), fragment_ids) + .await } /// Load an index from storage diff --git a/rust/lance-index/src/scalar/json.rs b/rust/lance-index/src/scalar/json.rs index 0b8a43efbe7..e36feaacfc7 100644 --- a/rust/lance-index/src/scalar/json.rs +++ b/rust/lance-index/src/scalar/json.rs @@ -768,6 +768,7 @@ impl ScalarIndexPlugin for JsonIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result { let request = (request as Box) .downcast::() @@ -797,7 +798,7 @@ impl ScalarIndexPlugin for JsonIndexPlugin { )?; let target_index = target_plugin - .train_index(converted_stream, index_store, target_request) + .train_index(converted_stream, index_store, target_request, fragment_ids) .await?; let index_details = crate::pb::JsonIndexDetails { diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs index 542aa2bc97a..64e932c47c5 100644 --- a/rust/lance-index/src/scalar/label_list.rs +++ b/rust/lance-index/src/scalar/label_list.rs @@ -398,6 +398,7 @@ impl ScalarIndexPlugin for LabelListIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result { let schema = data.schema(); let field = schema @@ -427,7 +428,7 @@ impl ScalarIndexPlugin for LabelListIndexPlugin { let data = unnest_chunks(data)?; let bitmap_plugin = BitmapIndexPlugin; bitmap_plugin - .train_index(data, index_store, request) + .train_index(data, index_store, request, fragment_ids) .await?; Ok(CreatedIndex { index_details: prost_types::Any::from_msg(&pb::LabelListIndexDetails::default()) diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs index d8a95de1eeb..4df502ead09 100644 --- a/rust/lance-index/src/scalar/lance_format.rs +++ b/rust/lance-index/src/scalar/lance_format.rs @@ -368,7 +368,7 @@ pub mod tests { ) .unwrap(); btree_plugin - .train_index(data, index_store.as_ref(), request) + .train_index(data, 
index_store.as_ref(), request, None) .await .unwrap(); } @@ -866,6 +866,7 @@ pub mod tests { &sub_index_trainer, index_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, + None, ) .await .unwrap(); @@ -911,7 +912,7 @@ pub mod tests { .new_training_request("{}", &Field::new(VALUE_COLUMN_NAME, DataType::Int32, false)) .unwrap(); BitmapIndexPlugin - .train_index(data, index_store.as_ref(), request) + .train_index(data, index_store.as_ref(), request, None) .await .unwrap(); } @@ -1399,7 +1400,7 @@ pub mod tests { ) .unwrap(); LabelListIndexPlugin - .train_index(data, index_store.as_ref(), request) + .train_index(data, index_store.as_ref(), request, None) .await .unwrap(); } diff --git a/rust/lance-index/src/scalar/ngram.rs b/rust/lance-index/src/scalar/ngram.rs index ff559dd9292..586b0a4da9a 100644 --- a/rust/lance-index/src/scalar/ngram.rs +++ b/rust/lance-index/src/scalar/ngram.rs @@ -1285,6 +1285,7 @@ impl ScalarIndexPlugin for NGramIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, _request: Box, + _fragment_ids: Option>, ) -> Result { Self::train_ngram_index(data, index_store).await?; Ok(CreatedIndex { diff --git a/rust/lance-index/src/scalar/registry.rs b/rust/lance-index/src/scalar/registry.rs index 022da729f0c..3880aad4dbe 100644 --- a/rust/lance-index/src/scalar/registry.rs +++ b/rust/lance-index/src/scalar/registry.rs @@ -119,6 +119,7 @@ pub trait ScalarIndexPlugin: Send + Sync + std::fmt::Debug { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + fragment_ids: Option>, ) -> Result; /// Returns true if the index returns an exact answer (e.g. not AtMost) diff --git a/rust/lance-index/src/scalar/zonemap.rs b/rust/lance-index/src/scalar/zonemap.rs index 748ab003863..c9097cbb8cc 100644 --- a/rust/lance-index/src/scalar/zonemap.rs +++ b/rust/lance-index/src/scalar/zonemap.rs @@ -961,6 +961,7 @@ impl ScalarIndexPlugin for ZoneMapIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + _fragment_ids: Option>, ) -> Result { let request = (request as Box) .downcast::() diff --git a/rust/lance/benches/scalar_index.rs b/rust/lance/benches/scalar_index.rs index 0742ff7f878..58b94f56318 100644 --- a/rust/lance/benches/scalar_index.rs +++ b/rust/lance/benches/scalar_index.rs @@ -71,6 +71,7 @@ impl BenchmarkFixture { &sub_index_trainer, index_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, + None, ) .await .unwrap(); diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index cdebc399547..ccae72a4865 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -284,12 +284,11 @@ pub(super) async fn build_scalar_index( training_request.criteria(), None, train, - fragment_ids, + fragment_ids.clone(), ) .await?; - plugin - .train_index(training_data, &index_store, training_request) + .train_index(training_data, &index_store, training_request, fragment_ids) .await } From 966bceb13d3dc102e129fbf703dccadd26a422e9 Mon Sep 17 00:00:00 2001 From: xloya Date: Tue, 9 Sep 2025 16:14:05 +0800 Subject: [PATCH 06/13] fix code --- python/python/lance/dataset.py | 10 +- python/python/tests/test_scalar_index.py | 386 ++++++++---------- python/src/dataset.rs | 119 ++---- rust/lance-index/src/scalar/btree.rs | 43 ++ .../src/scalar/inverted/builder.rs | 37 ++ 5 files changed, 287 insertions(+), 308 deletions(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 54bd432201a..f17f88da25a 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ 
-2734,21 +2734,19 @@ def prewarm_index(self, name: str):
     def merge_index_metadata(
         self,
         index_uuid: str,
-        index_type: Union[
-            Literal["BTREE"],
-            Literal["INVERTED"],
-        ],
+        index_type: str,
         prefetch_batch: Optional[int] = None,
     ):
         """
-        Merge an index which not commit at present.
+        Merge an index that has not been committed yet.

         Parameters
         ----------
         index_uuid: str
             The uuid of the index which want to merge.
-        index_type: Literal["BTREE", "INVERTED"]
+        index_type: str
             The type of the index.
+            Only "BTREE" and "INVERTED" are currently supported.
         prefetch_batch: int, optional
             The number of prefetch batches of sub-page files for merging.
             Default 1.
diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py
index 1b3e496dd39..aa0566ed76c 100644
--- a/python/python/tests/test_scalar_index.py
+++ b/python/python/tests/test_scalar_index.py
@@ -92,6 +92,99 @@ def data_table(indexed_dataset: lance.LanceDataset):
     return indexed_dataset.scanner().to_table()


+@pytest.fixture
+def btree_comparison_datasets(tmp_path):
+    """Setup datasets for B-tree comparison tests"""
+    # Test configuration
+    num_fragments = 3
+    rows_per_fragment = 10000
+    total_rows = num_fragments * rows_per_fragment
+
+    # Create dataset for fragment-level indexing
+    fragment_ds = generate_multi_fragment_dataset(
+        tmp_path / "fragment",
+        num_fragments=num_fragments,
+        rows_per_fragment=rows_per_fragment,
+    )
+
+    # Create dataset for complete indexing (same data structure)
+    complete_ds = generate_multi_fragment_dataset(
+        tmp_path / "complete",
+        num_fragments=num_fragments,
+        rows_per_fragment=rows_per_fragment,
+    )
+
+    import uuid
+
+    # Build fragment-level B-tree index
+    fragment_index_id = str(uuid.uuid4())
+    fragment_index_name = "fragment_btree_precise_test"
+
+    fragments = fragment_ds.get_fragments()
+    fragment_ids = [fragment.fragment_id for fragment in fragments]
+    print(f"Fragment IDs: {fragment_ids}")
+
+    # Create fragment-level indices
+    for fragment in fragments:
+        fragment_id = fragment.fragment_id
+        print(f"Creating B-tree index for fragment {fragment_id}")
+
+        fragment_ds.create_scalar_index(
+            column="id",
+            index_type="BTREE",
+            name=fragment_index_name,
+            replace=False,
+            fragment_uuid=fragment_index_id,
+            fragment_ids=[fragment_id],
+        )
+
+    # Merge fragment indices
+    fragment_ds.merge_index_metadata(fragment_index_id, index_type="BTREE")
+
+    # Create Index object for fragment-based index
+    from lance.dataset import Index
+
+    field_id = fragment_ds.schema.get_field_index("id")
+
+    fragment_index = Index(
+        uuid=fragment_index_id,
+        name=fragment_index_name,
+        fields=[field_id],
+        dataset_version=fragment_ds.version,
+        fragment_ids=set(fragment_ids),
+        index_version=0,
+    )
+
+    # Commit fragment-based index
+    create_fragment_index_op = lance.LanceOperation.CreateIndex(
+        new_indices=[fragment_index],
+        removed_indices=[],
+    )
+
+    fragment_ds_committed = lance.LanceDataset.commit(
+        fragment_ds.uri,
+        create_fragment_index_op,
+        read_version=fragment_ds.version,
+    )
+
+    # Build complete B-tree index
+    complete_index_name = f"complete_btree_{uuid.uuid4().hex[:8]}"
+    complete_ds.create_scalar_index(
+        column="id",
+        index_type="BTREE",
+        name=complete_index_name,
+    )
+    # Reload the dataset to get the indexed version
+    complete_ds = lance.dataset(complete_ds.uri)
+
+    return {
+        "fragment_ds": fragment_ds_committed,
+        "complete_ds": complete_ds,
+        "rows_per_fragment": rows_per_fragment,
+        "total_rows": total_rows,
+    }
+
+
 def test_load_indices(indexed_dataset: lance.LanceDataset):
indices = indexed_dataset.list_indices() vec_idx = next(idx for idx in indices if idx["type"] == "IVF_PQ") @@ -2999,16 +3092,13 @@ def test_distribute_btree_index_build(tmp_path): import uuid index_id = str(uuid.uuid4()) - print(f"Using index ID: {index_id}") index_name = "btree_multiple_fragment_idx" fragments = ds.get_fragments() fragment_ids = [fragment.fragment_id for fragment in fragments] - print(f"Fragment IDs: {fragment_ids}") for fragment in ds.get_fragments(): fragment_id = fragment.fragment_id - print(f"Creating B-tree index for fragment {fragment_id}") # Create B-tree scalar index for each fragment # Use the same index_name for all fragments (like in FTS test) @@ -3021,13 +3111,8 @@ def test_distribute_btree_index_build(tmp_path): fragment_ids=[fragment_id], ) - # For fragment-level indexing, we expect the method to return successfully - # but not commit the index yet - print(f"Fragment {fragment_id} B-tree index created successfully") - # Merge the B-tree index metadata ds.merge_index_metadata(index_id, index_type="BTREE") - print(ds.uri) # Create an Index object using the new dataclass format from lance.dataset import Index @@ -3057,8 +3142,6 @@ def test_distribute_btree_index_build(tmp_path): read_version=ds.version, ) - print("Successfully committed multiple fragment B-tree index") - # Verify the index was created and is functional indices = ds_committed.list_indices() assert len(indices) > 0, "No indices found after commit" @@ -3083,7 +3166,6 @@ def test_distribute_btree_index_build(tmp_path): columns=["id", "text"], ).to_table() - print(f"Search for id = {test_id} returned {results.num_rows} results") assert results.num_rows > 0, f"No results found for id = {test_id}" # Test range queries across fragments @@ -3092,7 +3174,6 @@ def test_distribute_btree_index_build(tmp_path): columns=["id", "text"], ).to_table() - print(f"Range query returned {results_range.num_rows} results") assert results_range.num_rows > 0, "No results found for range query" # Compare with complete index results to ensure consistency @@ -3131,210 +3212,6 @@ def test_distribute_btree_index_build(tmp_path): ) -def test_btree_precise_query_comparison(tmp_path): - """ - Precise comparison test between fragment-level B-tree index and complete - B-tree index. - This test creates identical datasets and compares query results in detail. 
- """ - # Test configuration - num_fragments = 3 - rows_per_fragment = 10000 - total_rows = num_fragments * rows_per_fragment - - print( - f"Creating datasets with {num_fragments} fragments," - f" {rows_per_fragment} rows each" - ) - - # Create dataset for fragment-level indexing - fragment_ds = generate_multi_fragment_dataset( - tmp_path / "fragment", - num_fragments=num_fragments, - rows_per_fragment=rows_per_fragment, - ) - - # Create dataset for complete indexing (same data structure) - complete_ds = generate_multi_fragment_dataset( - tmp_path / "complete", - num_fragments=num_fragments, - rows_per_fragment=rows_per_fragment, - ) - - import uuid - - # Build fragment-level B-tree index - fragment_index_id = str(uuid.uuid4()) - fragment_index_name = "fragment_btree_precise_test" - - fragments = fragment_ds.get_fragments() - fragment_ids = [fragment.fragment_id for fragment in fragments] - print(f"Fragment IDs: {fragment_ids}") - - # Create fragment-level indices - for fragment in fragments: - fragment_id = fragment.fragment_id - print(f"Creating B-tree index for fragment {fragment_id}") - - fragment_ds.create_scalar_index( - column="id", - index_type="BTREE", - name=fragment_index_name, - replace=False, - fragment_uuid=fragment_index_id, - fragment_ids=[fragment_id], - ) - - # Merge fragment indices - fragment_ds.merge_index_metadata(fragment_index_id, index_type="BTREE") - - # Create Index object for fragment-based index - from lance.dataset import Index - - field_id = fragment_ds.schema.get_field_index("id") - - fragment_index = Index( - uuid=fragment_index_id, - name=fragment_index_name, - fields=[field_id], - dataset_version=fragment_ds.version, - fragment_ids=set(fragment_ids), - index_version=0, - ) - - # Commit fragment-based index - create_fragment_index_op = lance.LanceOperation.CreateIndex( - new_indices=[fragment_index], - removed_indices=[], - ) - - fragment_ds_committed = lance.LanceDataset.commit( - fragment_ds.uri, - create_fragment_index_op, - read_version=fragment_ds.version, - ) - - # Build complete B-tree index - complete_index_name = "complete_btree_precise_test" - complete_ds.create_scalar_index( - column="id", - index_type="BTREE", - name=complete_index_name, - ) - - print("Both indices created successfully") - - # Detailed query comparison tests - test_cases = [ - # Test 1: Boundary values at fragment edges - {"name": "First value", "filter": "id = 0"}, - {"name": "Fragment 0 last value", "filter": f"id = {rows_per_fragment - 1}"}, - {"name": "Fragment 1 first value", "filter": f"id = {rows_per_fragment}"}, - { - "name": "Fragment 1 last value", - "filter": f"id = {2 * rows_per_fragment - 1}", - }, - {"name": "Fragment 2 first value", "filter": f"id = {2 * rows_per_fragment}"}, - {"name": "Last value", "filter": f"id = {total_rows - 1}"}, - # Test 2: Values in the middle of fragments - {"name": "Fragment 0 middle", "filter": f"id = {rows_per_fragment // 2}"}, - { - "name": "Fragment 1 middle", - "filter": f"id = {rows_per_fragment + rows_per_fragment // 2}", - }, - { - "name": "Fragment 2 middle", - "filter": f"id = {2 * rows_per_fragment + rows_per_fragment // 2}", - }, - # Test 3: Range queries within single fragments - {"name": "Range within fragment 0", "filter": "id >= 10 AND id < 20"}, - { - "name": "Range within fragment 1", - "filter": f"id >= {rows_per_fragment + 10}" - f" AND id < {rows_per_fragment + 20}", - }, - { - "name": "Range within fragment 2", - "filter": f"id >= {2 * rows_per_fragment + 10}" - f" AND id < {2 * rows_per_fragment + 20}", - }, - # 
Test 4: Range queries spanning multiple fragments - { - "name": "Cross fragment 0-1", - "filter": f"id >= {rows_per_fragment - 5} AND id < {rows_per_fragment + 5}", - }, - { - "name": "Cross fragment 1-2", - "filter": f"id >= {2 * rows_per_fragment - 5}" - f" AND id < {2 * rows_per_fragment + 5}", - }, - { - "name": "Cross all fragments", - "filter": f"id >= {rows_per_fragment // 2} AND" - f" id < {2 * rows_per_fragment + rows_per_fragment // 2}", - }, - # Test 5: Edge cases - {"name": "Non-existent small value", "filter": "id = -1"}, - {"name": "Non-existent large value", "filter": f"id = {total_rows + 100}"}, - {"name": "Large range", "filter": f"id >= 0 AND id < {total_rows}"}, - # Test 6: Comparison operators - {"name": "Less than boundary", "filter": f"id < {rows_per_fragment}"}, - { - "name": "Greater than boundary", - "filter": f"id > {2 * rows_per_fragment - 1}", - }, - {"name": "Less than or equal", "filter": f"id <= {rows_per_fragment + 50}"}, - {"name": "Greater than or equal", "filter": f"id >= {rows_per_fragment + 50}"}, - ] - - print(f"\nRunning {len(test_cases)} detailed comparison tests:") - - for i, test_case in enumerate(test_cases, 1): - test_name = test_case["name"] - filter_expr = test_case["filter"] - - print(f" {i:2d}. Testing {test_name}: {filter_expr}") - - # Query fragment-based index - fragment_results = fragment_ds_committed.scanner( - filter=filter_expr, - columns=["id", "text"], - ).to_table() - - # Query complete index - complete_results = complete_ds.scanner( - filter=filter_expr, - columns=["id", "text"], - ).to_table() - - # Compare row counts - assert fragment_results.num_rows == complete_results.num_rows, ( - f"Test '{test_name}' failed: Fragment index " - f"returned {fragment_results.num_rows} rows, " - f"but complete index returned {complete_results.num_rows}" - f" rows for filter: {filter_expr}" - ) - - # Compare actual results if there are any - if fragment_results.num_rows > 0: - # Sort both results by id for comparison - fragment_ids = sorted(fragment_results.column("id").to_pylist()) - complete_ids = sorted(complete_results.column("id").to_pylist()) - - assert fragment_ids == complete_ids, ( - f"Test '{test_name}' failed: Fragment index" - f" returned different IDs than complete index. " - f"Fragment IDs:" - f" {fragment_ids[:10]}{'...' if len(fragment_ids) > 10 else ''}, " - f"Complete IDs:" - f" {complete_ids[:10]}{'...' if len(complete_ids) > 10 else ''}" - ) - - print(f"Passed ({fragment_results.num_rows} rows)") - - print(f"All {len(test_cases)} precision tests passed.") - - def test_btree_fragment_ids_parameter_validation(tmp_path): """ Test validation of fragment_ids parameter for B-tree indices. 
@@ -3391,3 +3268,80 @@ def test_btree_backward_compatibility_no_fragment_ids(tmp_path):
     # Test that the index works
     results = ds.scanner(filter="id = 50").to_table()
     assert results.num_rows > 0
+
+
+@pytest.mark.parametrize(
+    "test_name,filter_expr",
+    [
+        # Test 1: Boundary values at fragment edges
+        ("First value", "id = 0"),
+        ("Fragment 0 last value", "id = 9999"),
+        ("Fragment 1 first value", "id = 10000"),
+        ("Fragment 1 last value", "id = 19999"),
+        ("Fragment 2 first value", "id = 20000"),
+        ("Last value", "id = 29999"),
+        # Test 2: Values in the middle of fragments
+        ("Fragment 0 middle", "id = 5000"),
+        ("Fragment 1 middle", "id = 15000"),
+        ("Fragment 2 middle", "id = 25000"),
+        # Test 3: Range queries within single fragments
+        ("Range within fragment 0", "id >= 10 AND id < 20"),
+        ("Range within fragment 1", "id >= 10010 AND id < 10020"),
+        ("Range within fragment 2", "id >= 20010 AND id < 20020"),
+        # Test 4: Range queries spanning multiple fragments
+        ("Cross fragment 0-1", "id >= 9995 AND id < 10005"),
+        ("Cross fragment 1-2", "id >= 19995 AND id < 20005"),
+        ("Cross all fragments", "id >= 5000 AND id < 25000"),
+        # Test 5: Edge cases
+        ("Non-existent small value", "id = -1"),
+        ("Non-existent large value", "id = 30100"),
+        ("Large range", "id >= 0 AND id < 30000"),
+        # Test 6: Comparison operators
+        ("Less than boundary", "id < 10000"),
+        ("Greater than boundary", "id > 19999"),
+        ("Less than or equal", "id <= 10050"),
+        ("Greater than or equal", "id >= 10050"),
+    ],
+)
+def test_btree_query_comparison_parametrized(
+    btree_comparison_datasets, test_name, filter_expr
+):
+    """
+    Parametrized B-tree index query comparison test.
+
+    Converted from the original loop-based test so that each
+    test case runs independently.
+    """
+    fragment_ds = btree_comparison_datasets["fragment_ds"]
+    complete_ds = btree_comparison_datasets["complete_ds"]
+
+    # Query fragment-based index
+    fragment_results = fragment_ds.scanner(
+        filter=filter_expr,
+        columns=["id", "text"],
+    ).to_table()
+
+    # Query complete index
+    complete_results = complete_ds.scanner(
+        filter=filter_expr,
+        columns=["id", "text"],
+    ).to_table()
+
+    # Compare row counts
+    assert fragment_results.num_rows == complete_results.num_rows, (
+        f"Test '{test_name}' failed: Fragment index "
+        f"returned {fragment_results.num_rows} rows, "
+        f"but complete index returned {complete_results.num_rows}"
+        f" rows for filter: {filter_expr}"
+    )
+
+    # Compare actual results if there are any
+    if fragment_results.num_rows > 0:
+        # Sort both results by id for comparison
+        fragment_ids = sorted(fragment_results.column("id").to_pylist())
+        complete_ids = sorted(complete_results.column("id").to_pylist())
+
+        assert fragment_ids == complete_ids, (
+            f"Test '{test_name}' failed: Fragment index "
+            f"and complete index returned different results for filter: {filter_expr}"
+        )
diff --git a/python/src/dataset.rs b/python/src/dataset.rs
index f3d0fd10e83..e581c3c60c1 100644
--- a/python/src/dataset.rs
+++ b/python/src/dataset.rs
@@ -1678,103 +1678,50 @@ impl Dataset {
         prefetch_batch: Option<usize>,
     ) -> PyResult<()> {
         RT.block_on(None, async {
-            let index_type = index_type.to_uppercase();
-            let idx_type = match index_type.as_str() {
-                "BTREE" => IndexType::BTree,
-                "INVERTED" => IndexType::Inverted,
-                _ => {
-                    return Err(Error::InvalidInput {
-                        source: format!(
-                            "Index type {} is not supported.",
-                            index_type
-                        ).into(),
-                        location: location!(),
-                    });
-                }
-            };
-
             let store = LanceIndexStore::from_dataset_for_new(self.ds.as_ref(), index_uuid)?;
             let
index_dir = self.ds.indices_dir().child(index_uuid); - if idx_type == IndexType::Inverted { - // List all partition metadata files in the index directory - let mut part_metadata_files = Vec::new(); - let mut list_stream = self.ds.object_store().list(Some(index_dir.clone())); - - while let Some(item) = list_stream.next().await { - match item { - Ok(meta) => { - let file_name = meta.location.filename().unwrap_or_default(); - // Filter files matching the pattern part_*_metadata.lance - if file_name.starts_with("part_") && file_name.ends_with("_metadata.lance") - { - part_metadata_files.push(file_name.to_string()); - } - } - Err(_) => continue, - } - } - - if part_metadata_files.is_empty() { - return Err(Error::InvalidInput { - source: format!( - "No partition metadata files found in index directory: {}", - index_dir + match index_type.to_uppercase().as_str() { + "INVERTED" => { + // List all partition metadata files in the index directory + let part_metadata_files = + lance_index::scalar::inverted::builder::list_metadata_files( + self.ds.object_store(), + &index_dir, ) - .into(), - location: location!(), - }); - } + .await?; - // Call merge_metadata_files function for inverted index - lance_index::scalar::inverted::builder::merge_metadata_files( - Arc::new(store), - &part_metadata_files, - ) + // Call merge_metadata_files function for inverted index + lance_index::scalar::inverted::builder::merge_metadata_files( + Arc::new(store), + &part_metadata_files, + ) .await - } else { - // List all partition page / lookup files in the index directory - let mut part_page_files = Vec::new(); - let mut part_lookup_files = Vec::new(); - let mut list_stream = self.ds.object_store().list(Some(index_dir.clone())); - - while let Some(item) = list_stream.next().await { - match item { - Ok(meta) => { - let file_name = meta.location.filename().unwrap_or_default(); - // Filter files matching the pattern part_*_metadata.lance - if file_name.starts_with("part_") && file_name.ends_with("_page_data.lance") - { - part_page_files.push(file_name.to_string()); - } - if file_name.starts_with("part_") && file_name.ends_with("_page_lookup.lance") - { - part_lookup_files.push(file_name.to_string()); - } - } - Err(_) => continue, - } } - if part_page_files.is_empty() || part_lookup_files.is_empty() { - return Err(Error::InvalidInput { - source: format!( - "No partition metadata files found in index directory: {} (page_files: {}, lookup_files: {})", - index_dir, part_page_files.len(), part_lookup_files.len() + "BTREE" => { + // List all partition page / lookup files in the index directory + let (part_page_files, part_lookup_files) = + lance_index::scalar::btree::list_page_lookup_files( + self.ds.object_store(), + &index_dir, ) - .into(), + .await?; + + // Call merge_metadata_files function for btree index + lance_index::scalar::btree::merge_metadata_files( + Arc::new(store), + &part_page_files, + &part_lookup_files, + prefetch_batch, + ) + .await + } + _ => { + return Err(Error::InvalidInput { + source: format!("Index type {} is not supported.", index_type).into(), location: location!(), }); } - - // Call merge_metadata_files function for btree index - lance_index::scalar::btree::merge_metadata_files( - Arc::new(store), - &part_page_files, - &part_lookup_files, - prefetch_batch, - ).await } - - })? 
        .map_err(|err| PyValueError::new_err(err.to_string()))
    }
diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs
index cad8af860c6..01c1e843c67 100644
--- a/rust/lance-index/src/scalar/btree.rs
+++ b/rust/lance-index/src/scalar/btree.rs
@@ -54,7 +54,9 @@ use lance_datafusion::{
     chunker::chunk_concat_stream,
     exec::{execute_plan, LanceExecutionOptions, OneShotExec},
 };
+use lance_io::object_store::ObjectStore;
 use log::{debug, warn};
+use object_store::path::Path;
 use roaring::RoaringBitmap;
 use serde::{Deserialize, Serialize, Serializer};
 use snafu::location;
@@ -1460,6 +1462,47 @@ fn extract_partition_id(filename: &str) -> Result<u64> {
     })
 }

+/// List and filter files from the index directory
+/// Returns (page_files, lookup_files)
+pub async fn list_page_lookup_files(
+    object_store: &ObjectStore,
+    index_dir: &Path,
+) -> Result<(Vec<String>, Vec<String>)> {
+    let mut part_page_files = Vec::new();
+    let mut part_lookup_files = Vec::new();
+
+    let mut list_stream = object_store.list(Some(index_dir.clone()));
+
+    while let Some(item) = list_stream.next().await {
+        match item {
+            Ok(meta) => {
+                let file_name = meta.location.filename().unwrap_or_default();
+                // Filter files matching the pattern part_*_page_data.lance
+                if file_name.starts_with("part_") && file_name.ends_with("_page_data.lance") {
+                    part_page_files.push(file_name.to_string());
+                }
+                // Filter files matching the pattern part_*_page_lookup.lance
+                if file_name.starts_with("part_") && file_name.ends_with("_page_lookup.lance") {
+                    part_lookup_files.push(file_name.to_string());
+                }
+            }
+            Err(_) => continue,
+        }
+    }
+
+    if part_page_files.is_empty() || part_lookup_files.is_empty() {
+        return Err(Error::Internal {
+            message: format!(
+                "No partition metadata files found in index directory: {} (page_files: {}, lookup_files: {})",
+                index_dir, part_page_files.len(), part_lookup_files.len()
+            ),
+            location: location!(),
+        });
+    }
+
+    Ok((part_page_files, part_lookup_files))
+}
+
 /// Merge multiple partition page / lookup files into a complete metadata file
 ///
 /// In a distributed environment, each worker node writes partition page / lookup files for the partitions it processes,
diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs
index 4731f63712c..c9efe1bf801 100644
--- a/rust/lance-index/src/scalar/inverted/builder.rs
+++ b/rust/lance-index/src/scalar/inverted/builder.rs
@@ -788,6 +788,43 @@ pub(crate) fn part_metadata_file_path(partition_id: u64) -> String {
     format!("part_{}_{}", partition_id, METADATA_FILE)
 }

+/// List and filter metadata files from the index directory
+/// Returns partition metadata files
+pub async fn list_metadata_files(
+    object_store: &ObjectStore,
+    index_dir: &Path,
+) -> Result<Vec<String>> {
+    // List all partition metadata files in the index directory
+    let mut part_metadata_files = Vec::new();
+    let mut list_stream = object_store.list(Some(index_dir.clone()));
+
+    while let Some(item) = list_stream.next().await {
+        match item {
+            Ok(meta) => {
+                let file_name = meta.location.filename().unwrap_or_default();
+                // Filter files matching the pattern part_*_metadata.lance
+                if file_name.starts_with("part_") && file_name.ends_with("_metadata.lance") {
+                    part_metadata_files.push(file_name.to_string());
+                }
+            }
+            Err(_) => continue,
+        }
+    }
+
+    if part_metadata_files.is_empty() {
+        return Err(Error::InvalidInput {
+            source: format!(
+                "No partition metadata files found in index directory: {}",
+                index_dir
+            )
+            .into(),
+            location: location!(),
+        });
+    }
+
+    Ok(part_metadata_files)
+}
+
 /// Merge partition metadata files with partition ID remapping to sequential IDs starting from 0
 pub async fn merge_metadata_files(
     store: Arc<dyn IndexStore>,

From bcd5b3e33774503ef8a0bda0e880134701e6a48b Mon Sep 17 00:00:00 2001
From: xloya
Date: Tue, 9 Sep 2025 19:02:26 +0800
Subject: [PATCH 07/13] fix clippy

---
 python/src/dataset.rs                | 10 ++++------
 rust/lance-index/src/scalar/btree.rs |  7 ++-----
 2 files changed, 6 insertions(+), 11 deletions(-)

diff --git a/python/src/dataset.rs b/python/src/dataset.rs
index e581c3c60c1..a3ab5b24a56 100644
--- a/python/src/dataset.rs
+++ b/python/src/dataset.rs
@@ -1715,12 +1715,10 @@ impl Dataset {
                     )
                     .await
                 }
-                _ => {
-                    return Err(Error::InvalidInput {
-                        source: format!("Index type {} is not supported.", index_type).into(),
-                        location: location!(),
-                    });
-                }
+                _ => Err(Error::InvalidInput {
+                    source: format!("Index type {} is not supported.", index_type).into(),
+                    location: location!(),
+                }),
             }
         })?
        .map_err(|err| PyValueError::new_err(err.to_string()))
diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs
index 01c1e843c67..621138b6b04 100644
--- a/rust/lance-index/src/scalar/btree.rs
+++ b/rust/lance-index/src/scalar/btree.rs
@@ -2562,7 +2562,7 @@ mod tests {
         // Method 1: Build complete index directly using the same data
         // Create deterministic data for comparison - use 2 * DEFAULT_BTREE_BATCH_SIZE for testing
-        let total_count = (2 * DEFAULT_BTREE_BATCH_SIZE) as u64;
+        let total_count = 2 * DEFAULT_BTREE_BATCH_SIZE;
         let full_data_gen = gen_batch()
@@ -2745,7 +2745,7 @@
         let sub_index_trainer = FlatIndexMetadata::new(DataType::Int32);

         // Use 3 * DEFAULT_BTREE_BATCH_SIZE for more comprehensive boundary testing
-        let total_count = (3 * DEFAULT_BTREE_BATCH_SIZE) as u64;
+        let total_count = 3 * DEFAULT_BTREE_BATCH_SIZE;

         // Method 1: Build complete index directly
         let full_data_gen = gen_batch()
@@ -3202,9 +3202,6 @@
         // The cleanup function should handle both valid and invalid file patterns gracefully
         // This test mainly verifies that the function doesn't panic and handles edge cases
         super::cleanup_partition_files(&test_store, &lookup_files, &page_files).await;
-
-        // If we get here without panicking, the cleanup function handled all cases correctly
-        assert!(true);
     }

From eaaac6bd84a8d065fa3ac4ab7be07e9bba288068 Mon Sep 17 00:00:00 2001
From: xloya
Date: Tue, 9 Sep 2025 19:07:18 +0800
Subject: [PATCH 08/13] update

---
 rust/lance-index/src/scalar/btree.rs | 471 +++++++++++++--------------
 1 file changed, 230 insertions(+), 241 deletions(-)

diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs
index 621138b6b04..34c7688e5d7 100644
--- a/rust/lance-index/src/scalar/btree.rs
+++ b/rust/lance-index/src/scalar/btree.rs
@@ -66,7 +66,6 @@ const BTREE_LOOKUP_NAME: &str = "page_lookup.lance";
 const BTREE_PAGES_NAME: &str = "page_data.lance";
 pub const DEFAULT_BTREE_BATCH_SIZE: u64 = 4096;
 const BATCH_SIZE_META_KEY: &str = "batch_size";
-
 const BTREE_INDEX_VERSION: u32 = 0;
 pub(crate) const BTREE_VALUES_COLUMN: &str = "values";
 pub(crate) const BTREE_IDS_COLUMN: &str = "ids";
@@ -1438,30 +1437,6 @@ pub async fn train_btree_index(
     Ok(())
 }

-/// Extract partition ID from partition file name
-/// Expected format: "part_{partition_id}_{suffix}.lance"
-fn extract_partition_id(filename: &str) -> Result<u64> {
-    if !filename.starts_with("part_") {
-        return Err(Error::Internal {
format!("Invalid partition file name format: {}", filename), - location: location!(), - }); - } - - let parts: Vec<&str> = filename.split('_').collect(); - if parts.len() < 3 { - return Err(Error::Internal { - message: format!("Invalid partition file name format: {}", filename), - location: location!(), - }); - } - - parts[1].parse::().map_err(|_| Error::Internal { - message: format!("Failed to parse partition ID from filename: {}", filename), - location: location!(), - }) -} - /// List and filter files from the index directory /// Returns (page_files, lookup_files) pub async fn list_page_lookup_files( @@ -1570,6 +1545,7 @@ pub async fn merge_metadata_files( prefetch_config = prefetch_config.with_prefetch_batch(batch); } + // Step 4: Merge pages and create lookup entries let lookup_entries = merge_page( part_lookup_files, &page_files_map, @@ -1583,7 +1559,7 @@ pub async fn merge_metadata_files( page_file.finish().await?; - // Step 4: Generate new lookup file based on reorganized pages + // Step 5: Generate new lookup file based on reorganized pages // Add batch_size to schema metadata let mut metadata = HashMap::new(); metadata.insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string()); @@ -1627,6 +1603,225 @@ pub async fn merge_metadata_files( Ok(()) } +async fn merge_page( + part_lookup_files: &[String], + page_files_map: &HashMap, + store: &Arc, + batch_size: u64, + page_file: &mut Box, + arrow_schema: Arc, + prefetch_config: PrefetchConfig, +) -> Result> { + let mut lookup_entries = Vec::new(); + let mut page_idx = 0u32; + + let start_time = std::time::Instant::now(); + debug!( + "Starting multi-way merge with {} partitions using prefetch manager", + part_lookup_files.len() + ); + + // Directly create iterators and read first element + let mut partition_map = HashMap::new(); + let mut heap = BinaryHeap::new(); + + debug!("Initializing {} partitions", part_lookup_files.len()); + + // Initialize all partitions + for lookup_file in part_lookup_files { + let partition_id = extract_partition_id(lookup_file)?; + let page_file_name = + (*page_files_map + .get(&partition_id) + .ok_or_else(|| Error::Internal { + message: format!("Page file not found for partition ID: {}", partition_id), + location: location!(), + })?) + .clone(); + + let mut iterator = PartitionIterator::new( + store.clone(), + page_file_name, + batch_size, + prefetch_config.clone(), + ) + .await?; + + let first_element = iterator.next().await?; + + if let Some((value, row_id)) = first_element { + // Put the first element into the heap with cached orderable value + let orderable_value = OrderableScalarValue(value.clone()); + heap.push(HeapElement { + value, + row_id, + partition_id, + orderable_value, + }); + } + + partition_map.insert(partition_id, iterator); + } + + debug!( + "Initialized {} partitions, heap size: {}", + partition_map.len(), + heap.len() + ); + + let mut current_batch_rows = Vec::with_capacity(batch_size as usize); + let mut total_merged = 0usize; + + // Multi-way merge main loop + while let Some(min_element) = heap.pop() { + // Add current minimum element to batch + current_batch_rows.push((min_element.value, min_element.row_id)); + total_merged += 1; + + // Read next element from corresponding partition + if let Some(iterator) = partition_map.get_mut(&min_element.partition_id) { + if let Some((next_value, next_row_id)) = iterator.next().await? 
+                let orderable_value = OrderableScalarValue(next_value.clone());
+                heap.push(HeapElement {
+                    value: next_value,
+                    row_id: next_row_id,
+                    partition_id: min_element.partition_id,
+                    orderable_value,
+                });
+            }
+        }
+
+        if current_batch_rows.len() >= batch_size as usize {
+            write_batch_and_lookup_entry(
+                &mut current_batch_rows,
+                page_file,
+                &arrow_schema,
+                &mut lookup_entries,
+                &mut page_idx,
+            )
+            .await?;
+        }
+    }
+
+    if !current_batch_rows.is_empty() {
+        write_batch_and_lookup_entry(
+            &mut current_batch_rows,
+            page_file,
+            &arrow_schema,
+            &mut lookup_entries,
+            &mut page_idx,
+        )
+        .await?;
+    }
+
+    let elapsed = start_time.elapsed();
+    let rows_per_second = if elapsed.as_secs_f64() > 0.0 {
+        total_merged as f64 / elapsed.as_secs_f64()
+    } else {
+        0.0
+    };
+
+    debug!(
+        "Completed multi-way merge: merged {} rows into {} lookup entries in {:.2}s ({:.0} rows/s)",
+        total_merged,
+        lookup_entries.len(),
+        elapsed.as_secs_f64(),
+        rows_per_second
+    );
+    Ok(lookup_entries)
+}
+
+/// Helper function to prepare batch data with optimized memory usage
+async fn prepare_batch_data(
+    batch_rows: Vec<(ScalarValue, ScalarValue)>,
+    arrow_schema: Arc<Schema>,
+    page_idx: u32,
+) -> Result<(RecordBatch, (ScalarValue, ScalarValue, u32, u32))> {
+    if batch_rows.is_empty() {
+        return Err(Error::Internal {
+            message: "Cannot prepare empty batch".to_string(),
+            location: location!(),
+        });
+    }
+
+    let capacity = batch_rows.len();
+    let mut values = Vec::with_capacity(capacity);
+    let mut row_ids = Vec::with_capacity(capacity);
+
+    for (value, row_id) in batch_rows.into_iter() {
+        values.push(value);
+        row_ids.push(row_id);
+    }
+
+    let (values_array, row_ids_array) = rayon::join(
+        || ScalarValue::iter_to_array(values.into_iter()),
+        || ScalarValue::iter_to_array(row_ids.into_iter()),
+    );
+
+    let values_array = values_array?;
+    let row_ids_array = row_ids_array?;
+
+    let batch = RecordBatch::try_new(arrow_schema, vec![values_array, row_ids_array])?;
+
+    // Calculate min/max/null_count for lookup entry
+    let min_val = ScalarValue::try_from_array(batch.column(0), 0)?;
+    let max_val = ScalarValue::try_from_array(batch.column(0), batch.num_rows() - 1)?;
+    let null_count = batch.column(0).null_count() as u32;
+
+    let lookup_entry = (min_val, max_val, null_count, page_idx);
+
+    Ok((batch, lookup_entry))
+}
+
+/// Helper function to write a batch and create lookup entry
+async fn write_batch_and_lookup_entry(
+    batch_rows: &mut Vec<(ScalarValue, ScalarValue)>,
+    page_file: &mut Box<dyn IndexWriter>,
+    arrow_schema: &Arc<Schema>,
+    lookup_entries: &mut Vec<(ScalarValue, ScalarValue, u32, u32)>,
+    page_idx: &mut u32,
+) -> Result<()> {
+    if batch_rows.is_empty() {
+        return Ok(());
+    }
+
+    let batch_data = std::mem::take(batch_rows);
+    let current_page_idx = *page_idx;
+
+    let (batch, lookup_entry) =
+        prepare_batch_data(batch_data, arrow_schema.clone(), current_page_idx).await?;
+
+    lookup_entries.push(lookup_entry);
+    page_file.write_record_batch(batch).await?;
+    *page_idx += 1;
+
+    Ok(())
+}
+
+/// Extract partition ID from partition file name
+/// Expected format: "part_{partition_id}_{suffix}.lance"
+fn extract_partition_id(filename: &str) -> Result<u64> {
+    if !filename.starts_with("part_") {
+        return Err(Error::Internal {
+            message: format!("Invalid partition file name format: {}", filename),
+            location: location!(),
+        });
+    }
+
+    let parts: Vec<&str> = filename.split('_').collect();
+    if parts.len() < 3 {
+        return Err(Error::Internal {
+            message: format!("Invalid partition file name format: {}", filename),
+            location: location!(),
+        });
} + + parts[1].parse::().map_err(|_| Error::Internal { + message: format!("Failed to parse partition ID from filename: {}", filename), + location: location!(), + }) +} + /// Clean up partition files after successful merge /// /// This function safely deletes partition lookup and page files after a successful merge operation. @@ -1671,15 +1866,12 @@ async fn cleanup_single_file( expected_suffix: &str, file_type: &str, ) { - // Ensure we only delete files that match the expected pattern (safety check) if file_name.starts_with(expected_prefix) && file_name.ends_with(expected_suffix) { match store.delete_index_file(file_name).await { Ok(()) => { debug!("Successfully deleted {} file: {}", file_type, file_name); } Err(e) => { - // File deletion failures should not affect the overall success of the function - // Log the error but continue processing other files warn!( "Failed to delete {} file '{}': {}. \ This does not affect the merge operation, but may leave \ @@ -1698,6 +1890,14 @@ async fn cleanup_single_file( } } +pub(crate) fn part_page_data_file_path(partition_id: u64) -> String { + format!("part_{}_{}", partition_id, BTREE_PAGES_NAME) +} + +pub(crate) fn part_lookup_file_path(partition_id: u64) -> String { + format!("part_{}_{}", partition_id, BTREE_LOOKUP_NAME) +} + /// Prefetch configuration for partition iterators #[derive(Debug, Clone)] pub struct PrefetchConfig { @@ -1932,8 +2132,6 @@ impl Eq for HeapElement {} impl PartialOrd for HeapElement { fn partial_cmp(&self, other: &Self) -> Option { - // Note: BinaryHeap is a maximum heap, we need a minimum heap, - // so reverse the comparison result Some(self.cmp(other)) } } @@ -1946,214 +2144,6 @@ impl Ord for HeapElement { } } -async fn merge_page( - part_lookup_files: &[String], - page_files_map: &HashMap, - store: &Arc, - batch_size: u64, - page_file: &mut Box, - arrow_schema: Arc, - prefetch_config: PrefetchConfig, -) -> Result> { - let mut lookup_entries = Vec::new(); - let mut page_idx = 0u32; - - let start_time = std::time::Instant::now(); - debug!( - "Starting multi-way merge with {} partitions using prefetch manager", - part_lookup_files.len() - ); - - // Directly create iterators and read first element - let mut partition_map = HashMap::new(); - let mut heap = BinaryHeap::new(); - - debug!("Initializing {} partitions", part_lookup_files.len()); - - // Initialize all partitions - for lookup_file in part_lookup_files { - let partition_id = extract_partition_id(lookup_file)?; - let page_file_name = - (*page_files_map - .get(&partition_id) - .ok_or_else(|| Error::Internal { - message: format!("Page file not found for partition ID: {}", partition_id), - location: location!(), - })?) 
- .clone(); - - let mut iterator = PartitionIterator::new( - store.clone(), - page_file_name, - batch_size, - prefetch_config.clone(), - ) - .await?; - - let first_element = iterator.next().await?; - - if let Some((value, row_id)) = first_element { - // Put the first element into the heap with cached orderable value - let orderable_value = OrderableScalarValue(value.clone()); - heap.push(HeapElement { - value, - row_id, - partition_id, - orderable_value, - }); - } - - partition_map.insert(partition_id, iterator); - } - - debug!( - "Initialized {} partitions, heap size: {}", - partition_map.len(), - heap.len() - ); - - let mut current_batch_rows = Vec::with_capacity(batch_size as usize); - let mut total_merged = 0usize; - - // Multi-way merge main loop - while let Some(min_element) = heap.pop() { - // Add current minimum element to batch - current_batch_rows.push((min_element.value, min_element.row_id)); - total_merged += 1; - - // Read next element from corresponding partition - if let Some(iterator) = partition_map.get_mut(&min_element.partition_id) { - if let Some((next_value, next_row_id)) = iterator.next().await? { - let orderable_value = OrderableScalarValue(next_value.clone()); - heap.push(HeapElement { - value: next_value, - row_id: next_row_id, - partition_id: min_element.partition_id, - orderable_value, - }); - } - } - - // Write when batch reaches specified size - if current_batch_rows.len() >= batch_size as usize { - write_batch_and_lookup_entry( - &mut current_batch_rows, - page_file, - &arrow_schema, - &mut lookup_entries, - &mut page_idx, - ) - .await?; - } - } - - // Write the remaining data - if !current_batch_rows.is_empty() { - write_batch_and_lookup_entry( - &mut current_batch_rows, - page_file, - &arrow_schema, - &mut lookup_entries, - &mut page_idx, - ) - .await?; - } - - let elapsed = start_time.elapsed(); - let rows_per_second = if elapsed.as_secs_f64() > 0.0 { - total_merged as f64 / elapsed.as_secs_f64() - } else { - 0.0 - }; - - debug!( - "Completed multi-way merge: merged {} rows into {} lookup entries in {:.2}s ({:.0} rows/s)", - total_merged, - lookup_entries.len(), - elapsed.as_secs_f64(), - rows_per_second - ); - Ok(lookup_entries) -} - -/// Helper function to prepare batch data with optimized memory usage -async fn prepare_batch_data( - batch_rows: Vec<(ScalarValue, ScalarValue)>, - arrow_schema: Arc, - page_idx: u32, -) -> Result<(RecordBatch, (ScalarValue, ScalarValue, u32, u32))> { - if batch_rows.is_empty() { - return Err(Error::Internal { - message: "Cannot prepare empty batch".to_string(), - location: location!(), - }); - } - - // Pre-allocate vectors with exact capacity to avoid reallocations - let capacity = batch_rows.len(); - let mut values = Vec::with_capacity(capacity); - let mut row_ids = Vec::with_capacity(capacity); - - // Unzip with pre-allocated vectors - for (value, row_id) in batch_rows.into_iter() { - values.push(value); - row_ids.push(row_id); - } - - // Convert to arrays in parallel - let (values_array, row_ids_array) = rayon::join( - || ScalarValue::iter_to_array(values.into_iter()), - || ScalarValue::iter_to_array(row_ids.into_iter()), - ); - - let values_array = values_array?; - let row_ids_array = row_ids_array?; - - let batch = RecordBatch::try_new(arrow_schema, vec![values_array, row_ids_array])?; - - // Calculate min/max/null_count for lookup entry - let min_val = ScalarValue::try_from_array(batch.column(0), 0)?; - let max_val = ScalarValue::try_from_array(batch.column(0), batch.num_rows() - 1)?; - let null_count = 
batch.column(0).null_count() as u32;
-
-    let lookup_entry = (min_val, max_val, null_count, page_idx);
-
-    Ok((batch, lookup_entry))
-}
-
-/// Helper function to write a batch and create lookup entry
-async fn write_batch_and_lookup_entry(
-    batch_rows: &mut Vec<(ScalarValue, ScalarValue)>,
-    page_file: &mut Box,
-    arrow_schema: &Arc,
-    lookup_entries: &mut Vec<(ScalarValue, ScalarValue, u32, u32)>,
-    page_idx: &mut u32,
-) -> Result<()> {
-    if batch_rows.is_empty() {
-        return Ok(());
-    }
-
-    let batch_data = std::mem::take(batch_rows);
-    let current_page_idx = *page_idx;
-
-    let (batch, lookup_entry) =
-        prepare_batch_data(batch_data, arrow_schema.clone(), current_page_idx).await?;
-
-    lookup_entries.push(lookup_entry);
-    page_file.write_record_batch(batch).await?;
-    *page_idx += 1;
-
-    Ok(())
-}
-
-pub(crate) fn part_page_data_file_path(partition_id: u64) -> String {
-    format!("part_{}_{}", partition_id, BTREE_PAGES_NAME)
-}
-
-pub(crate) fn part_lookup_file_path(partition_id: u64) -> String {
-    format!("part_{}_{}", partition_id, BTREE_LOOKUP_NAME)
-}
-
 /// A stream that reads the original training data back out of the index
 ///
 /// This is used for updating the index
@@ -2354,7 +2344,6 @@ mod tests {
         OrderableScalarValue, DEFAULT_BTREE_BATCH_SIZE,
     };
-    // Additional imports for new tests
     use datafusion::arrow::array::Int32Array;
     use datafusion::arrow::datatypes::{Field, Schema};
     use datafusion::arrow::record_batch::RecordBatch;
@@ -2540,7 +2529,6 @@ mod tests {
         assert_eq!(metrics.parts_loaded.load(Ordering::Relaxed), 1);
     }

-    /// Test that fragment-based btree index construction produces exactly the same results as building a complete index
    #[tokio::test]
    async fn test_fragment_btree_index_consistency() {
        // Setup stores for both indexes
@@ -3281,6 +3269,7 @@ mod tests {
        let second_with_null = heap_with_null.pop().unwrap();
        assert_eq!(second_with_null.value, ScalarValue::Int32(Some(1)));
    }
+
    #[tokio::test]
    async fn test_partition_iterator() {
        // Create test environment

From fc9aaf743cb306c37a381aff618f9c6165191ac6 Mon Sep 17 00:00:00 2001
From: xloya
Date: Thu, 11 Sep 2025 14:07:11 +0800
Subject: [PATCH 09/13] update code

---
 python/python/lance/dataset.py              |   6 +-
 python/python/lance/lance/__init__.pyi      |   2 +-
 python/python/tests/test_scalar_index.py    |  43 +-
 python/src/dataset.rs                       |  37 +-
 rust/lance-index/src/scalar.rs              |  16 +-
 rust/lance-index/src/scalar/bitmap.rs       |   1 +
 rust/lance-index/src/scalar/btree.rs        | 643 ++++--------------
 .../src/scalar/inverted/builder.rs          |  19 +-
 rust/lance-index/src/scalar/lance_format.rs |  49 +-
 rust/lance-index/src/scalar/ngram.rs        |   1 +
 rust/lance-index/src/scalar/zonemap.rs      |   1 +
 11 files changed, 215 insertions(+), 603 deletions(-)

diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index f17f88da25a..a68e544febc 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -2735,7 +2735,7 @@ def merge_index_metadata(
         self,
         index_uuid: str,
         index_type: str,
-        prefetch_batch: Optional[int] = None,
+        batch_readhead: Optional[int] = None,
     ):
         """
         Merge an index which has not been committed yet.
@@ -2747,7 +2747,7 @@
         index_type: str
             The type of the index.
             Only "BTREE" and "INVERTED" are supported now.
-        prefetch_batch: int, optional
+        batch_readhead: int, optional
             The number of prefetch batches of sub-page files for merging.
             Defaults to 1.
         """
@@ -2762,7 +2762,7 @@
                    f"merge index metadata. Received {index_type}",
                )
            )
-        return self._ds.merge_index_metadata(index_uuid, index_type, prefetch_batch)
+        return self._ds.merge_index_metadata(index_uuid, index_type, batch_readhead)

     def session(self) -> Session:
         """
diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi
index 0bae8e2f1aa..9e7bd55c395 100644
--- a/python/python/lance/lance/__init__.pyi
+++ b/python/python/lance/lance/__init__.pyi
@@ -283,7 +283,7 @@ class _Dataset:
     def drop_index(self, name: str): ...
     def prewarm_index(self, name: str): ...
     def merge_index_metadata(
-        self, index_uuid: str, index_type: str, prefetch_batch: Optional[int] = None
+        self, index_uuid: str, index_type: str, batch_readhead: Optional[int] = None
     ): ...
     def count_fragments(self) -> int: ...
     def num_small_files(self, max_rows_per_group: int) -> int: ...
diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py
index aa0566ed76c..0049130514d 100644
--- a/python/python/tests/test_scalar_index.py
+++ b/python/python/tests/test_scalar_index.py
@@ -122,12 +122,10 @@ def btree_comparison_datasets(tmp_path):

     fragments = fragment_ds.get_fragments()
     fragment_ids = [fragment.fragment_id for fragment in fragments]
-    print(f"Fragment IDs: {fragment_ids}")

     # Create fragment-level indices
     for fragment in fragments:
         fragment_id = fragment.fragment_id
-        print(f"Creating B-tree index for fragment {fragment_id}")

         fragment_ds.create_scalar_index(
             column="id",
@@ -3111,6 +3109,18 @@ def test_distribute_btree_index_build(tmp_path):
             fragment_ids=[fragment_id],
         )

+    # The dataset should remain searchable while the fragment-level
+    # indices exist but the merged index has not been committed yet.
+    # Verify this with an exact equality query that should match
+    # exactly one row.
+    test_id = 100  # Should be in first fragment
+    results = ds.scanner(
+        filter=f"id = {test_id}",
+        columns=["id", "text"],
+    ).to_table()
+
+    assert results.num_rows == 1, f"Expected exactly one row for id = {test_id}"
+
     # Merge the B-tree index metadata
     ds.merge_index_metadata(index_id, index_type="BTREE")

@@ -3166,7 +3176,7 @@
         columns=["id", "text"],
     ).to_table()

-    assert results.num_rows > 0, f"No results found for id = {test_id}"
+    assert results.num_rows == 1, f"Expected exactly one row for id = {test_id}"

     # Test range queries across fragments
     results_range = ds_committed.scanner(
@@ -3243,33 +3253,6 @@ def test_btree_fragment_ids_parameter_validation(tmp_path):
         print(f"Expected error for invalid fragment ID: {e}")


-def test_btree_backward_compatibility_no_fragment_ids(tmp_path):
-    """
-    Test that B-tree indexing remains backward compatible
-    when fragment_ids is not provided.
- """ - ds = generate_multi_fragment_dataset( - tmp_path, num_fragments=2, rows_per_fragment=10000 - ) - - # This should work exactly as before (full dataset indexing) - ds.create_scalar_index( - column="id", - index_type="BTREE", - name="full_dataset_btree_idx", - ) - - # Verify the index was created - indices = ds.list_indices() - assert len(indices) == 1 - assert indices[0]["name"] == "full_dataset_btree_idx" - assert indices[0]["type"] == "BTree" - - # Test that the index works - results = ds.scanner(filter="id = 50").to_table() - assert results.num_rows > 0 - - @pytest.mark.parametrize( "test_name,filter_expr", [ diff --git a/python/src/dataset.rs b/python/src/dataset.rs index a3ab5b24a56..800eec9a264 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1670,48 +1670,33 @@ impl Dataset { .infer_error() } - #[pyo3(signature = (index_uuid, index_type, prefetch_batch))] + #[pyo3(signature = (index_uuid, index_type, batch_readhead))] fn merge_index_metadata( &self, index_uuid: &str, index_type: &str, - prefetch_batch: Option, + batch_readhead: Option, ) -> PyResult<()> { RT.block_on(None, async { let store = LanceIndexStore::from_dataset_for_new(self.ds.as_ref(), index_uuid)?; let index_dir = self.ds.indices_dir().child(index_uuid); match index_type.to_uppercase().as_str() { "INVERTED" => { - // List all partition metadata files in the index directory - let part_metadata_files = - lance_index::scalar::inverted::builder::list_metadata_files( - self.ds.object_store(), - &index_dir, - ) - .await?; - - // Call merge_metadata_files function for inverted index - lance_index::scalar::inverted::builder::merge_metadata_files( + // Call merge_index_files function for inverted index + lance_index::scalar::inverted::builder::merge_index_files( + self.ds.object_store(), + &index_dir, Arc::new(store), - &part_metadata_files, ) .await } "BTREE" => { - // List all partition page / lookup files in the index directory - let (part_page_files, part_lookup_files) = - lance_index::scalar::btree::list_page_lookup_files( - self.ds.object_store(), - &index_dir, - ) - .await?; - - // Call merge_metadata_files function for btree index - lance_index::scalar::btree::merge_metadata_files( + // Call merge_index_files function for btree index + lance_index::scalar::btree::merge_index_files( + self.ds.object_store(), + &index_dir, Arc::new(store), - &part_page_files, - &part_lookup_files, - prefetch_batch, + batch_readhead, ) .await } diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index ff934a9c811..5ac0a0fce10 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -3,10 +3,6 @@ //! 
Scalar indices for metadata search & filtering -use std::collections::{HashMap, HashSet}; -use std::fmt::Debug; -use std::{any::Any, ops::Bound, sync::Arc}; - use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow_array::{ListArray, RecordBatch}; use arrow_schema::{Field, Schema}; @@ -15,6 +11,10 @@ use datafusion::functions::string::contains::ContainsFunc; use datafusion::functions_array::array_has; use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_common::{scalar::ScalarValue, Column}; +use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; +use std::pin::Pin; +use std::{any::Any, ops::Bound, sync::Arc}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; @@ -44,6 +44,7 @@ pub mod zonemap; use crate::frag_reuse::FragReuseIndex; pub use inverted::tokenizer::InvertedIndexParams; use lance_datafusion::udf::CONTAINS_TOKENS_UDF; +use lance_io::stream::RecordBatchStream; pub const LANCE_SCALAR_INDEX: &str = "__lance_scalar_index"; @@ -189,6 +190,13 @@ pub trait IndexReader: Send + Sync { range: std::ops::Range, projection: Option<&[&str]>, ) -> Result; + /// Reads data from the file as a stream of record batches + async fn read_stream( + &self, + batch_size: u32, + batch_readahead: u32, + projection: Option<&[&str]>, + ) -> Result>>; /// Return the number of batches in the file async fn num_batches(&self, batch_size: u64) -> u32; /// Return the number of rows in the file diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs index 09dde5297b4..aad4c93abd6 100644 --- a/rust/lance-index/src/scalar/bitmap.rs +++ b/rust/lance-index/src/scalar/bitmap.rs @@ -530,6 +530,7 @@ impl ScalarIndexPlugin for BitmapIndexPlugin { _request: Box, _fragment_ids: Option>, ) -> Result { + assert!(_fragment_ids.is_none()); Self::train_bitmap_index(data, index_store).await?; Ok(CreatedIndex { index_details: prost_types::Any::from_msg(&pb::BitmapIndexDetails::default()).unwrap(), diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 34c7688e5d7..30455453d58 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -4,7 +4,7 @@ use std::{ any::Any, cmp::Ordering, - collections::{BTreeMap, BinaryHeap, HashMap, VecDeque}, + collections::{BTreeMap, BinaryHeap, HashMap}, fmt::{Debug, Display}, ops::Bound, sync::Arc, @@ -1437,9 +1437,21 @@ pub async fn train_btree_index( Ok(()) } +pub async fn merge_index_files( + object_store: &ObjectStore, + index_dir: &Path, + store: Arc, + batch_readhead: Option, +) -> Result<()> { + // List all partition page / lookup files in the index directory + let (part_page_files, part_lookup_files) = + list_page_lookup_files(object_store, index_dir).await?; + merge_metadata_files(store, &part_page_files, &part_lookup_files, batch_readhead).await +} + /// List and filter files from the index directory /// Returns (page_files, lookup_files) -pub async fn list_page_lookup_files( +async fn list_page_lookup_files( object_store: &ObjectStore, index_dir: &Path, ) -> Result<(Vec, Vec)> { @@ -1482,11 +1494,11 @@ pub async fn list_page_lookup_files( /// /// In a distributed environment, each worker node writes partition page / lookup files for the partitions it processes, /// and this function merges these files into a final metadata file. 
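///
/// A minimal sketch of the driver-side call (via the public `merge_index_files`
/// wrapper above), assuming every worker has already written its
/// `part_{partition_id}_*` page and lookup files into the index directory; the
/// `dataset` handle and `index_uuid` below are placeholders mirroring the
/// Python binding in `python/src/dataset.rs`:
///
/// ```ignore
/// let store = LanceIndexStore::from_dataset_for_new(dataset.as_ref(), index_uuid)?;
/// let index_dir = dataset.indices_dir().child(index_uuid);
/// merge_index_files(dataset.object_store(), &index_dir, Arc::new(store), None).await?;
/// ```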
-pub async fn merge_metadata_files(
+async fn merge_metadata_files(
     store: Arc,
     part_page_files: &[String],
     part_lookup_files: &[String],
-    prefetch_batch: Option,
+    batch_readhead: Option,
 ) -> Result<()> {
     if part_lookup_files.is_empty() || part_page_files.is_empty() {
         return Err(Error::Internal {
@@ -1496,6 +1508,16 @@
     }

     // Step 1: Create lookup map for page files by partition ID
+    if part_lookup_files.len() != part_page_files.len() {
+        return Err(Error::Internal {
+            message: format!(
+                "Number of partition lookup files ({}) does not match number of partition page files ({})",
+                part_lookup_files.len(),
+                part_page_files.len()
+            ),
+            location: location!(),
+        });
+    }
     let mut page_files_map = HashMap::new();
     for page_file in part_page_files {
         let partition_id = extract_partition_id(page_file)?;
@@ -1526,8 +1548,12 @@
         .unwrap_or(DEFAULT_BTREE_BATCH_SIZE);

     // Get the value type from lookup schema (min column)
-    let lookup_batch = first_lookup_reader.read_range(0..1, None).await?;
-    let value_type = lookup_batch.column(0).data_type().clone();
+    let value_type = first_lookup_reader
+        .schema()
+        .fields
+        .first()
+        .unwrap()
+        .data_type();

     // Get page schema first
     let partition_id = extract_partition_id(part_lookup_files[0].as_str())?;
@@ -1540,20 +1566,15 @@
         .new_index_file(BTREE_PAGES_NAME, arrow_schema.clone())
         .await?;

-    let mut prefetch_config = PrefetchConfig::default();
-    if let Some(batch) = prefetch_batch {
-        prefetch_config = prefetch_config.with_prefetch_batch(batch);
-    }
-
     // Step 4: Merge pages and create lookup entries
-    let lookup_entries = merge_page(
+    let lookup_entries = merge_pages(
         part_lookup_files,
         &page_files_map,
         &store,
         batch_size,
         &mut page_file,
         arrow_schema.clone(),
-        prefetch_config,
+        batch_readhead,
     )
     .await?;

@@ -1603,31 +1624,29 @@
     Ok(())
 }

-async fn merge_page(
+/// Merge pages using DataFusion's SortPreservingMerge, which implements
+/// a K-way merge and emits fixed-size output batches
+async fn merge_pages(
     part_lookup_files: &[String],
     page_files_map: &HashMap,
     store: &Arc,
     batch_size: u64,
     page_file: &mut Box,
     arrow_schema: Arc,
-    prefetch_config: PrefetchConfig,
+    batch_readhead: Option,
 ) -> Result> {
     let mut lookup_entries = Vec::new();
     let mut page_idx = 0u32;

     let start_time = std::time::Instant::now();
     debug!(
-        "Starting multi-way merge with {} partitions using prefetch manager",
+        "Starting DataFusion SortPreservingMerge with {} partitions",
         part_lookup_files.len()
     );

-    // Directly create iterators and read first element
-    let mut partition_map = HashMap::new();
-    let mut heap = BinaryHeap::new();
+    // Create input streams for each partition
+    let mut input_streams = Vec::new();

-    debug!("Initializing {} partitions", part_lookup_files.len());
-
-    // Initialize all partitions
     for lookup_file in part_lookup_files {
         let partition_id = extract_partition_id(lookup_file)?;
         let page_file_name =
             (*page_files_map
                 .get(&partition_id)
                 .ok_or_else(|| Error::Internal {
                     message: format!("Page file not found for partition ID: {}", partition_id),
                     location: location!(),
                 })?)
.clone(); - let mut iterator = PartitionIterator::new( - store.clone(), - page_file_name, - batch_size, - prefetch_config.clone(), - ) - .await?; + let reader = store.open_index_file(&page_file_name).await?; - let first_element = iterator.next().await?; + let stream = reader + .read_stream(batch_size as u32, batch_readhead.unwrap_or(1) as u32, None) + .await?; - if let Some((value, row_id)) = first_element { - // Put the first element into the heap with cached orderable value - let orderable_value = OrderableScalarValue(value.clone()); - heap.push(HeapElement { - value, - row_id, - partition_id, - orderable_value, - }); - } + input_streams.push(stream); + } - partition_map.insert(partition_id, iterator); + if input_streams.is_empty() { + return Ok(lookup_entries); } - debug!( - "Initialized {} partitions, heap size: {}", - partition_map.len(), - heap.len() - ); + // Convert streams into execution plans + let mut input_plans: Vec> = Vec::new(); + let value_field = arrow_schema.field(0).clone().with_name(VALUE_COLUMN_NAME); + let row_id_field = arrow_schema.field(1).clone().with_name(ROW_ID); + let new_schema = Arc::new(Schema::new(vec![value_field, row_id_field])); + + for stream in input_streams { + // Convert Lance RecordBatchStream to DataFusion SendableRecordBatchStream + let df_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new( + new_schema.clone(), + stream.map(|result| result.map_err(|e| DataFusionError::ArrowError(e.into(), None))), + )); + + let plan: Arc = Arc::new(OneShotExec::new(df_stream)); + input_plans.push(plan); + } + + // Create UnionExec to combine all inputs + let union_exec = Arc::new(UnionExec::new(input_plans)); + + // Create sort expression for the first column (value column) + let value_column_index = new_schema.index_of(VALUE_COLUMN_NAME)?; + let sort_expr = PhysicalSortExpr { + expr: Arc::new(Column::new(VALUE_COLUMN_NAME, value_column_index)), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }; + + // Create SortPreservingMergeExec + let merge_exec = Arc::new(SortPreservingMergeExec::new( + LexOrdering::new(vec![sort_expr]), + union_exec, + )); - let mut current_batch_rows = Vec::with_capacity(batch_size as usize); + // Execute the plan + let merged_stream = execute_plan( + merge_exec, + LanceExecutionOptions { + use_spilling: true, + ..Default::default() + }, + )?; + + // Use chunk_concat_stream to ensure fixed-size output batches + let chunked_stream = chunk_concat_stream(merged_stream, batch_size as usize); + let mut chunked_stream = Box::pin(chunked_stream); let mut total_merged = 0usize; - // Multi-way merge main loop - while let Some(min_element) = heap.pop() { - // Add current minimum element to batch - current_batch_rows.push((min_element.value, min_element.row_id)); - total_merged += 1; - - // Read next element from corresponding partition - if let Some(iterator) = partition_map.get_mut(&min_element.partition_id) { - if let Some((next_value, next_row_id)) = iterator.next().await? 
{ - let orderable_value = OrderableScalarValue(next_value.clone()); - heap.push(HeapElement { - value: next_value, - row_id: next_row_id, - partition_id: min_element.partition_id, - orderable_value, - }); - } + // Process the chunked stream - each batch will have exactly batch_size rows + // (except possibly the last one) + while let Some(batch_result) = chunked_stream.next().await { + let batch = batch_result?; + + // Convert the entire batch to lookup entries at once + let mut current_batch_rows = Vec::with_capacity(batch.num_rows()); + for row_idx in 0..batch.num_rows() { + let value = ScalarValue::try_from_array(batch.column(0), row_idx)?; + let row_id = ScalarValue::try_from_array(batch.column(1), row_idx)?; + current_batch_rows.push((value, row_id)); } - if current_batch_rows.len() >= batch_size as usize { - write_batch_and_lookup_entry( - &mut current_batch_rows, - page_file, - &arrow_schema, - &mut lookup_entries, - &mut page_idx, - ) - .await?; - } - } + total_merged += current_batch_rows.len(); - if !current_batch_rows.is_empty() { + // Write the batch (it's already the right size due to chunk_concat_stream) write_batch_and_lookup_entry( &mut current_batch_rows, page_file, @@ -1714,20 +1747,14 @@ async fn merge_page( .await?; } - let elapsed = start_time.elapsed(); - let rows_per_second = if elapsed.as_secs_f64() > 0.0 { - total_merged as f64 / elapsed.as_secs_f64() - } else { - 0.0 - }; - + let duration = start_time.elapsed(); debug!( - "Completed multi-way merge: merged {} rows into {} lookup entries in {:.2}s ({:.0} rows/s)", + "DataFusion merge completed: merged {} records in {:?}, {} lookup entries", total_merged, - lookup_entries.len(), - elapsed.as_secs_f64(), - rows_per_second + duration, + lookup_entries.len() ); + Ok(lookup_entries) } @@ -1898,252 +1925,6 @@ pub(crate) fn part_lookup_file_path(partition_id: u64) -> String { format!("part_{}_{}", partition_id, BTREE_LOOKUP_NAME) } -/// Prefetch configuration for partition iterators -#[derive(Debug, Clone)] -pub struct PrefetchConfig { - /// Number of batches to prefetch ahead (0 means no prefetching) - pub prefetch_batches: usize, -} - -impl Default for PrefetchConfig { - fn default() -> Self { - Self { - prefetch_batches: 1, - } - } -} - -impl PrefetchConfig { - /// Set the prefetch batch count - pub fn with_prefetch_batch(&self, batch_count: usize) -> Self { - Self { - prefetch_batches: batch_count, - } - } -} - -/// Buffer entry for prefetch queue -#[derive(Debug)] -struct BufferEntry { - batch: RecordBatch, - start_row: usize, - end_row: usize, -} - -/// Partition iterator for loading partition data in batches with integrated prefetching -struct PartitionIterator { - reader: Arc, - current_batch: Option, - current_position: usize, - rows_read: usize, - batch_size: u64, - prefetch_buffer: Arc>>, - prefetch_task: Option>, -} - -impl PartitionIterator { - async fn new( - store: Arc, - page_file_name: String, - batch_size: u64, - prefetch_config: PrefetchConfig, - ) -> Result { - let reader = store.open_index_file(&page_file_name).await?; - let total_rows = reader.num_rows(); - - // Create shared prefetch buffer - let prefetch_buffer = Arc::new(tokio::sync::Mutex::new(VecDeque::new())); - - // Start prefetch task - let prefetch_task = if prefetch_config.prefetch_batches > 0 { - let reader_clone = reader.clone(); - let buffer_clone = prefetch_buffer.clone(); - let batch_size_clone = batch_size as usize; - let prefetch_batches = prefetch_config.prefetch_batches; - - Some(tokio::spawn(async move { - let mut current_pos = 0; 
- - while current_pos < total_rows { - let effective_batch_size = prefetch_batches * batch_size_clone; - let end_pos = std::cmp::min(current_pos + effective_batch_size, total_rows); - - match reader_clone.read_range(current_pos..end_pos, None).await { - Ok(batch) => { - let entry = BufferEntry { - start_row: current_pos, - end_row: current_pos + batch.num_rows(), - batch, - }; - - // Add data to buffer with size control - { - let mut buffer_guard = buffer_clone.lock().await; - const MAX_BUFFER_SIZE: usize = 4; - - // Only add if buffer is not full - if buffer_guard.len() < MAX_BUFFER_SIZE { - buffer_guard.push_back(entry); - } else { - // Buffer is full, skip this prefetch to avoid memory bloat - break; - } - } - - current_pos = end_pos; - } - Err(_) => { - // Read failed, wait before retry - tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; - } - } - - // Add some delay to avoid excessive resource consumption - tokio::time::sleep(tokio::time::Duration::from_millis(1)).await; - } - })) - } else { - None - }; - - Ok(Self { - reader, - current_batch: None, - current_position: 0, - rows_read: 0, - batch_size, - prefetch_buffer, - prefetch_task, - }) - } - - /// Get the next element with integrated prefetching - async fn next(&mut self) -> Result> { - // Load new batch if current one is exhausted - if self.needs_new_batch() { - if self.rows_read >= self.reader.num_rows() { - return Ok(None); - } - self.load_next_batch().await?; - } - - // Extract next value from current batch - if let Some(batch) = &self.current_batch { - let value = ScalarValue::try_from_array(batch.column(0), self.current_position)?; - let row_id = ScalarValue::try_from_array(batch.column(1), self.current_position)?; - self.current_position += 1; - self.rows_read += 1; - Ok(Some((value, row_id))) - } else { - Ok(None) - } - } - - /// Check if we need to load a new batch - fn needs_new_batch(&self) -> bool { - self.current_batch.is_none() - || self.current_position >= self.current_batch.as_ref().unwrap().num_rows() - } - - async fn load_next_batch(&mut self) -> Result<()> { - let remaining_rows = self.reader.num_rows() - self.rows_read; - if remaining_rows == 0 { - self.current_batch = None; - return Ok(()); - } - - let rows_to_read = std::cmp::min(self.batch_size as usize, remaining_rows); - let end_row = self.rows_read + rows_to_read; - - // First try to get data from prefetch buffer - let batch = if let Some(entry) = self.try_get_from_buffer(self.rows_read, end_row).await { - // Get required data by slicing from prefetch buffer - let slice_start = self.rows_read - entry.start_row; - let slice_len = rows_to_read; - entry.batch.slice(slice_start, slice_len) - } else { - // Fallback to direct read - self.reader - .read_range(self.rows_read..end_row, None) - .await? 
- }; - - self.current_batch = Some(batch); - self.current_position = 0; - - Ok(()) - } - - /// Try to get data from prefetch buffer - async fn try_get_from_buffer(&self, start_row: usize, end_row: usize) -> Option { - let mut buffer_guard = self.prefetch_buffer.lock().await; - - // Clean up expired buffer entries (entries that are no longer needed) - while let Some(entry) = buffer_guard.front() { - if entry.end_row <= start_row { - buffer_guard.pop_front(); - } else { - break; - } - } - - // Find entry that contains the required data range - if let Some(index) = buffer_guard - .iter() - .position(|entry| entry.start_row <= start_row && entry.end_row >= end_row) - { - buffer_guard.remove(index) - } else { - None - } - } - - #[allow(dead_code)] - fn get_reader(&self) -> Arc { - self.reader.clone() - } -} - -impl Drop for PartitionIterator { - fn drop(&mut self) { - // Cancel the prefetch task when the iterator is dropped - if let Some(task) = &self.prefetch_task { - task.abort(); - } - } -} - -/// Heap elements, used for priority queues in multi-way merging -#[derive(Debug)] -struct HeapElement { - value: ScalarValue, - row_id: ScalarValue, - partition_id: u64, - orderable_value: OrderableScalarValue, -} - -impl PartialEq for HeapElement { - fn eq(&self, other: &Self) -> bool { - self.orderable_value.eq(&other.orderable_value) - } -} - -impl Eq for HeapElement {} - -impl PartialOrd for HeapElement { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for HeapElement { - fn cmp(&self, other: &Self) -> Ordering { - // Note: BinaryHeap is a maximum heap, we need a minimum heap, - // so reverse the comparison result - other.orderable_value.cmp(&self.orderable_value) - } -} - /// A stream that reads the original training data back out of the index /// /// This is used for updating the index @@ -2340,16 +2121,9 @@ mod tests { }; use super::{ - part_lookup_file_path, part_page_data_file_path, train_btree_index, HeapElement, - OrderableScalarValue, DEFAULT_BTREE_BATCH_SIZE, + part_lookup_file_path, part_page_data_file_path, train_btree_index, OrderableScalarValue, + DEFAULT_BTREE_BATCH_SIZE, }; - - use datafusion::arrow::array::Int32Array; - use datafusion::arrow::datatypes::{Field, Schema}; - use datafusion::arrow::record_batch::RecordBatch; - use std::cmp::Ordering as CmpOrdering; - use std::collections::BinaryHeap; - #[test] fn test_scalar_value_size() { let size_of_i32 = OrderableScalarValue(ScalarValue::Int32(Some(0))).deep_size_of(); @@ -3191,189 +2965,4 @@ mod tests { // This test mainly verifies that the function doesn't panic and handles edge cases super::cleanup_partition_files(&test_store, &lookup_files, &page_files).await; } - - #[test] - fn test_heap_element_comparison() { - // Create HeapElements with different values - let element1 = HeapElement { - value: ScalarValue::Int32(Some(10)), - row_id: ScalarValue::UInt64(Some(1)), - partition_id: 0, - orderable_value: OrderableScalarValue(ScalarValue::Int32(Some(10))), - }; - - let element2 = HeapElement { - value: ScalarValue::Int32(Some(5)), - row_id: ScalarValue::UInt64(Some(2)), - partition_id: 0, - orderable_value: OrderableScalarValue(ScalarValue::Int32(Some(5))), - }; - - let element3 = HeapElement { - value: ScalarValue::Int32(Some(15)), - row_id: ScalarValue::UInt64(Some(3)), - partition_id: 0, - orderable_value: OrderableScalarValue(ScalarValue::Int32(Some(15))), - }; - - // Test direct comparison - note that BinaryHeap is a max heap, - // but we want min heap behavior, so smaller values 
should be "greater" - assert_eq!(element1.cmp(&element2), CmpOrdering::Less); // 10 < 5 in min heap (5 should come first) - assert_eq!(element2.cmp(&element1), CmpOrdering::Greater); // 5 > 10 in min heap - assert_eq!(element1.cmp(&element3), CmpOrdering::Greater); // 10 > 15 in min heap (10 should come first) - - // Test with BinaryHeap to ensure min heap behavior - let mut heap = BinaryHeap::new(); - heap.push(element1); - heap.push(element2); - heap.push(element3); - - // Should pop in ascending order (min heap behavior) - let first = heap.pop().unwrap(); - assert_eq!(first.value, ScalarValue::Int32(Some(5))); - - let second = heap.pop().unwrap(); - assert_eq!(second.value, ScalarValue::Int32(Some(10))); - - let third = heap.pop().unwrap(); - assert_eq!(third.value, ScalarValue::Int32(Some(15))); - - // Test with null values - let null_element = HeapElement { - value: ScalarValue::Int32(None), - row_id: ScalarValue::UInt64(Some(4)), - partition_id: 0, - orderable_value: OrderableScalarValue(ScalarValue::Int32(None)), - }; - - let non_null_element = HeapElement { - value: ScalarValue::Int32(Some(1)), - row_id: ScalarValue::UInt64(Some(5)), - partition_id: 0, - orderable_value: OrderableScalarValue(ScalarValue::Int32(Some(1))), - }; - - // Null values should come before non-null values in min heap - assert_eq!(null_element.cmp(&non_null_element), CmpOrdering::Greater); - assert_eq!(non_null_element.cmp(&null_element), CmpOrdering::Less); - - // Test heap with null values - let mut heap_with_null = BinaryHeap::new(); - heap_with_null.push(non_null_element); - heap_with_null.push(null_element); - - // Null should come first - let first_with_null = heap_with_null.pop().unwrap(); - assert_eq!(first_with_null.value, ScalarValue::Int32(None)); - - let second_with_null = heap_with_null.pop().unwrap(); - assert_eq!(second_with_null.value, ScalarValue::Int32(Some(1))); - } - - #[tokio::test] - async fn test_partition_iterator() { - // Create test environment - let tmpdir = Arc::new(tempdir().unwrap()); - let test_store = Arc::new(LanceIndexStore::new( - Arc::new(ObjectStore::local()), - Path::from_filesystem_path(tmpdir.path()).unwrap(), - Arc::new(LanceCache::no_cache()), - )); - - // Create test data with more rows to test iteration - let schema = Arc::new(Schema::new(vec![ - Field::new("value", DataType::Int32, false), - Field::new("row_id", DataType::UInt64, false), - ])); - - let values = Int32Array::from(vec![10, 20, 30, 40, 50, 60, 70, 80]); - let row_ids = arrow::array::UInt64Array::from(vec![1u64, 2, 3, 4, 5, 6, 7, 8]); - let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(values), Arc::new(row_ids)]) - .unwrap(); - - // Write test data to index file - let page_file_name = "test_partition_pages.lance".to_string(); - let mut index_writer = test_store - .new_index_file(&page_file_name, schema.clone()) - .await - .unwrap(); - index_writer.write_record_batch(batch).await.unwrap(); - index_writer.finish().await.unwrap(); - - // Test PartitionIterator creation and basic methods - let batch_size = 3u64; // Small batch size to test multiple iterations - - // Create PartitionIterator - let mut partition_iterator = super::PartitionIterator::new( - test_store.clone(), - page_file_name, - batch_size, - super::PrefetchConfig::default(), - ) - .await - .unwrap(); - - // Test that iterator was created with correct initial state - assert_eq!(partition_iterator.batch_size, batch_size); - assert_eq!(partition_iterator.current_position, 0); - assert_eq!(partition_iterator.rows_read, 0); - 
assert!(partition_iterator.current_batch.is_none()); - - // Test get_reader method - let reader = partition_iterator.get_reader(); - assert_eq!(reader.num_rows(), 8); // We wrote 8 rows - - // Test iteration through data - let mut collected_values = Vec::new(); - let mut collected_row_ids = Vec::new(); - let mut iteration_count = 0; - - // Iterate through all data - while let Some((value, row_id)) = partition_iterator.next().await.unwrap() { - collected_values.push(value); - collected_row_ids.push(row_id); - iteration_count += 1; - - // Prevent infinite loop in case of bugs - if iteration_count > 20 { - panic!("Too many iterations, possible infinite loop"); - } - } - - // Verify we got all the data - assert_eq!(collected_values.len(), 8); - assert_eq!(collected_row_ids.len(), 8); - - // Verify the actual values (they should match what we wrote) - let expected_values = vec![ - ScalarValue::Int32(Some(10)), - ScalarValue::Int32(Some(20)), - ScalarValue::Int32(Some(30)), - ScalarValue::Int32(Some(40)), - ScalarValue::Int32(Some(50)), - ScalarValue::Int32(Some(60)), - ScalarValue::Int32(Some(70)), - ScalarValue::Int32(Some(80)), - ]; - - let expected_row_ids = vec![ - ScalarValue::UInt64(Some(1)), - ScalarValue::UInt64(Some(2)), - ScalarValue::UInt64(Some(3)), - ScalarValue::UInt64(Some(4)), - ScalarValue::UInt64(Some(5)), - ScalarValue::UInt64(Some(6)), - ScalarValue::UInt64(Some(7)), - ScalarValue::UInt64(Some(8)), - ]; - - assert_eq!(collected_values, expected_values); - assert_eq!(collected_row_ids, expected_row_ids); - - // Test that iterator is exhausted - assert!(partition_iterator.next().await.unwrap().is_none()); - - // Verify final state - assert_eq!(partition_iterator.rows_read, 8); - } } diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs index c9efe1bf801..137219574b2 100644 --- a/rust/lance-index/src/scalar/inverted/builder.rs +++ b/rust/lance-index/src/scalar/inverted/builder.rs @@ -788,12 +788,21 @@ pub(crate) fn part_metadata_file_path(partition_id: u64) -> String { format!("part_{}_{}", partition_id, METADATA_FILE) } -/// List and filter metadata files from the index directory -/// Returns partition metadata files -pub async fn list_metadata_files( +pub async fn merge_index_files( object_store: &ObjectStore, index_dir: &Path, -) -> Result> { + store: Arc, +) -> Result<()> { + // List all partition metadata files in the index directory + let part_metadata_files = list_metadata_files(object_store, index_dir).await?; + + // Call merge_metadata_files function for inverted index + merge_metadata_files(store, &part_metadata_files).await +} + +/// List and filter metadata files from the index directory +/// Returns partition metadata files +async fn list_metadata_files(object_store: &ObjectStore, index_dir: &Path) -> Result> { // List all partition metadata files in the index directory let mut part_metadata_files = Vec::new(); let mut list_stream = object_store.list(Some(index_dir.clone())); @@ -826,7 +835,7 @@ pub async fn list_metadata_files( } /// Merge partition metadata files with partition ID remapping to sequential IDs starting from 0 -pub async fn merge_metadata_files( +async fn merge_metadata_files( store: Arc, part_metadata_files: &[String], ) -> Result<()> { diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs index 4df502ead09..cd62e422989 100644 --- a/rust/lance-index/src/scalar/lance_format.rs +++ b/rust/lance-index/src/scalar/lance_format.rs @@ -3,14 +3,10 @@ 
//! Utilities for serializing and deserializing scalar indices in the lance format

-use std::cmp::min;
-use std::collections::HashMap;
-use std::{any::Any, sync::Arc};
-
+use super::{IndexReader, IndexStore, IndexWriter};
 use arrow_array::RecordBatch;
 use arrow_schema::Schema;
 use async_trait::async_trait;
-
 use deepsize::DeepSizeOf;
 use futures::TryStreamExt;
 use lance_core::{cache::LanceCache, Error, Result};
@@ -22,12 +18,15 @@ use lance_file::{
     writer::{FileWriter, ManifestProvider},
 };
 use lance_io::scheduler::{ScanScheduler, SchedulerConfig};
+use lance_io::stream::RecordBatchStream;
 use lance_io::utils::CachedFileSize;
 use lance_io::{object_store::ObjectStore, ReadBatchParams};
 use lance_table::format::SelfDescribingFileReader;
 use object_store::path::Path;
-
-use super::{IndexReader, IndexStore, IndexWriter};
+use std::cmp::min;
+use std::collections::HashMap;
+use std::pin::Pin;
+use std::{any::Any, sync::Arc};

 /// An index store that serializes scalar indices using the lance format
 ///
@@ -128,6 +127,15 @@ impl IndexReader for FileReader {
         self.read_range(range, &projection).await
     }

+    async fn read_stream(
+        &self,
+        _batch_size: u32,
+        _batch_readahead: u32,
+        _projection: Option<&[&str]>,
+    ) -> Result>> {
+        unimplemented!("Unsupported operation in IndexReader for FileReader.");
+    }
+
     async fn num_batches(&self, _batch_size: u64) -> u32 {
         self.num_batches() as u32
     }
@@ -186,6 +194,33 @@ impl IndexReader for v2::reader::FileReader {
         Ok(batches[0].clone())
     }

+    async fn read_stream(
+        &self,
+        batch_size: u32,
+        batch_readahead: u32,
+        projection: Option<&[&str]>,
+    ) -> Result>> {
+        let projection = if let Some(projection) = projection {
+            v2::reader::ReaderProjection::from_column_names(
+                self.metadata().version(),
+                self.schema(),
+                projection,
+            )?
+        } else {
+            v2::reader::ReaderProjection::from_whole_schema(
+                self.schema(),
+                self.metadata().version(),
+            )
+        };
+        self.read_stream_projected(
+            ReadBatchParams::RangeFull,
+            batch_size,
+            batch_readahead,
+            projection,
+            FilterExpression::no_filter(),
+        )
+    }
+
     // V2 format has removed the row group concept,
     // so here we assume each batch has 4096 rows.
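    // (e.g., under that assumption a 10,000-row file is counted as
    //  ceil(10,000 / 4096) = 3 batches)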
async fn num_batches(&self, batch_size: u64) -> u32 { diff --git a/rust/lance-index/src/scalar/ngram.rs b/rust/lance-index/src/scalar/ngram.rs index 586b0a4da9a..4d23ac5af36 100644 --- a/rust/lance-index/src/scalar/ngram.rs +++ b/rust/lance-index/src/scalar/ngram.rs @@ -1287,6 +1287,7 @@ impl ScalarIndexPlugin for NGramIndexPlugin { _request: Box, _fragment_ids: Option>, ) -> Result { + assert!(_fragment_ids.is_none()); Self::train_ngram_index(data, index_store).await?; Ok(CreatedIndex { index_details: prost_types::Any::from_msg(&pb::NGramIndexDetails::default()).unwrap(), diff --git a/rust/lance-index/src/scalar/zonemap.rs b/rust/lance-index/src/scalar/zonemap.rs index c9097cbb8cc..43dfbbdb0ef 100644 --- a/rust/lance-index/src/scalar/zonemap.rs +++ b/rust/lance-index/src/scalar/zonemap.rs @@ -963,6 +963,7 @@ impl ScalarIndexPlugin for ZoneMapIndexPlugin { request: Box, _fragment_ids: Option>, ) -> Result { + assert!(_fragment_ids.is_none()); let request = (request as Box) .downcast::() .map_err(|_| Error::InvalidInput { From 8d0a5b15fc4251880afef7a4b18b6d35a73f76fb Mon Sep 17 00:00:00 2001 From: xloya Date: Fri, 12 Sep 2025 17:33:55 +0800 Subject: [PATCH 10/13] refactor code --- rust/lance-index/src/scalar/btree.rs | 191 +++++++-------------------- 1 file changed, 48 insertions(+), 143 deletions(-) diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index bd80c0919c5..8d691d254c5 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -1629,8 +1629,8 @@ async fn merge_metadata_files( Ok(()) } -/// Merge pages using DataFusion's SortPreservingMerge -/// with fixed-size output streams which implementation is the K-way merge algorithm +/// Merge pages using Datafusion's SortPreservingMergeExec +/// which implements a K-way merge algorithm with fixed-size output batches async fn merge_pages( part_lookup_files: &[String], page_files_map: &HashMap, @@ -1643,14 +1643,16 @@ async fn merge_pages( let mut lookup_entries = Vec::new(); let mut page_idx = 0u32; - let start_time = std::time::Instant::now(); debug!( - "Starting DataFusion SortPreservingMerge with {} partitions", + "Starting SortPreservingMerge with {} partitions", part_lookup_files.len() ); - // Create input streams for each partition - let mut input_streams = Vec::new(); + let mut streams: Vec = Vec::new(); + + let value_field = arrow_schema.field(0).clone().with_name(VALUE_COLUMN_NAME); + let row_id_field = arrow_schema.field(1).clone().with_name(ROW_ID); + let stream_schema = Arc::new(Schema::new(vec![value_field, row_id_field])); for lookup_file in part_lookup_files { let partition_id = extract_partition_id(lookup_file)?; @@ -1665,39 +1667,30 @@ async fn merge_pages( let reader = store.open_index_file(&page_file_name).await?; - let stream = reader - .read_stream(batch_size as u32, batch_readhead.unwrap_or(1) as u32, None) - .await?; + let reader_stream = IndexReaderStream::new(reader, batch_size).await; - input_streams.push(stream); - } + let stream = reader_stream + .map(|fut| fut.map_err(DataFusionError::from)) + .buffered(batch_readhead.unwrap_or(1)) + .boxed(); - if input_streams.is_empty() { - return Ok(lookup_entries); + let sendable_stream = + Box::pin(RecordBatchStreamAdapter::new(stream_schema.clone(), stream)); + streams.push(sendable_stream); } - // Convert streams into execution plans - let mut input_plans: Vec> = Vec::new(); - let value_field = arrow_schema.field(0).clone().with_name(VALUE_COLUMN_NAME); - let row_id_field = 
arrow_schema.field(1).clone().with_name(ROW_ID); - let new_schema = Arc::new(Schema::new(vec![value_field, row_id_field])); - - for stream in input_streams { - // Convert Lance RecordBatchStream to DataFusion SendableRecordBatchStream - let df_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new( - new_schema.clone(), - stream.map(|result| result.map_err(|e| DataFusionError::ArrowError(e.into(), None))), - )); - - let plan: Arc = Arc::new(OneShotExec::new(df_stream)); - input_plans.push(plan); + // Create execution plans for each stream + let mut inputs: Vec> = Vec::new(); + for stream in streams { + let plan = Arc::new(OneShotExec::new(stream)); + inputs.push(plan); } - // Create UnionExec to combine all inputs - let union_exec = Arc::new(UnionExec::new(input_plans)); + // Create Union execution plan to combine all partitions + let union_inputs = Arc::new(UnionExec::new(inputs)); - // Create sort expression for the first column (value column) - let value_column_index = new_schema.index_of(VALUE_COLUMN_NAME)?; + // Create SortPreservingMerge execution plan + let value_column_index = stream_schema.index_of(VALUE_COLUMN_NAME)?; let sort_expr = PhysicalSortExpr { expr: Arc::new(Column::new(VALUE_COLUMN_NAME, value_column_index)), options: SortOptions { @@ -1706,128 +1699,40 @@ async fn merge_pages( }, }; - // Create SortPreservingMergeExec let merge_exec = Arc::new(SortPreservingMergeExec::new( - LexOrdering::new(vec![sort_expr]), - union_exec, + [sort_expr].into(), + union_inputs, )); - // Execute the plan - let merged_stream = execute_plan( + let unchunked = execute_plan( merge_exec, LanceExecutionOptions { - use_spilling: true, + use_spilling: false, ..Default::default() }, )?; - // Use chunk_concat_stream to ensure fixed-size output batches - let chunked_stream = chunk_concat_stream(merged_stream, batch_size as usize); - let mut chunked_stream = Box::pin(chunked_stream); - let mut total_merged = 0usize; - - // Process the chunked stream - each batch will have exactly batch_size rows - // (except possibly the last one) - while let Some(batch_result) = chunked_stream.next().await { - let batch = batch_result?; - - // Convert the entire batch to lookup entries at once - let mut current_batch_rows = Vec::with_capacity(batch.num_rows()); - for row_idx in 0..batch.num_rows() { - let value = ScalarValue::try_from_array(batch.column(0), row_idx)?; - let row_id = ScalarValue::try_from_array(batch.column(1), row_idx)?; - current_batch_rows.push((value, row_id)); - } - - total_merged += current_batch_rows.len(); - - // Write the batch (it's already the right size due to chunk_concat_stream) - write_batch_and_lookup_entry( - &mut current_batch_rows, - page_file, - &arrow_schema, - &mut lookup_entries, - &mut page_idx, - ) - .await?; - } - - let duration = start_time.elapsed(); - debug!( - "DataFusion merge completed: merged {} records in {:?}, {} lookup entries", - total_merged, - duration, - lookup_entries.len() - ); - - Ok(lookup_entries) -} - -/// Helper function to prepare batch data with optimized memory usage -async fn prepare_batch_data( - batch_rows: Vec<(ScalarValue, ScalarValue)>, - arrow_schema: Arc, - page_idx: u32, -) -> Result<(RecordBatch, (ScalarValue, ScalarValue, u32, u32))> { - if batch_rows.is_empty() { - return Err(Error::Internal { - message: "Cannot prepare empty batch".to_string(), - location: location!(), - }); - } - - let capacity = batch_rows.len(); - let mut values = Vec::with_capacity(capacity); - let mut row_ids = Vec::with_capacity(capacity); - - 
for (value, row_id) in batch_rows.into_iter() { - values.push(value); - row_ids.push(row_id); - } - - let (values_array, row_ids_array) = rayon::join( - || ScalarValue::iter_to_array(values.into_iter()), - || ScalarValue::iter_to_array(row_ids.into_iter()), - ); + // Use chunk_concat_stream to ensure fixed batch sizes + let mut chunked_stream = chunk_concat_stream(unchunked, batch_size as usize); - let values_array = values_array?; - let row_ids_array = row_ids_array?; - - let batch = RecordBatch::try_new(arrow_schema, vec![values_array, row_ids_array])?; - - // Calculate min/max/null_count for lookup entry - let min_val = ScalarValue::try_from_array(batch.column(0), 0)?; - let max_val = ScalarValue::try_from_array(batch.column(0), batch.num_rows() - 1)?; - let null_count = batch.column(0).null_count() as u32; + // Process chunked stream + while let Some(batch) = chunked_stream.try_next().await? { + let writer_batch = RecordBatch::try_new( + arrow_schema.clone(), + vec![batch.column(0).clone(), batch.column(1).clone()], + )?; - let lookup_entry = (min_val, max_val, null_count, page_idx); + page_file.write_record_batch(writer_batch).await?; - Ok((batch, lookup_entry)) -} + let min_val = ScalarValue::try_from_array(batch.column(0), 0)?; + let max_val = ScalarValue::try_from_array(batch.column(0), batch.num_rows() - 1)?; + let null_count = batch.column(0).null_count() as u32; -/// Helper function to write a batch and create lookup entry -async fn write_batch_and_lookup_entry( - batch_rows: &mut Vec<(ScalarValue, ScalarValue)>, - page_file: &mut Box, - arrow_schema: &Arc, - lookup_entries: &mut Vec<(ScalarValue, ScalarValue, u32, u32)>, - page_idx: &mut u32, -) -> Result<()> { - if batch_rows.is_empty() { - return Ok(()); + lookup_entries.push((min_val, max_val, null_count, page_idx)); + page_idx += 1; } - let batch_data = std::mem::take(batch_rows); - let current_page_idx = *page_idx; - - let (batch, lookup_entry) = - prepare_batch_data(batch_data, arrow_schema.clone(), current_page_idx).await?; - - lookup_entries.push(lookup_entry); - page_file.write_record_batch(batch).await?; - *page_idx += 1; - - Ok(()) + Ok(lookup_entries) } /// Extract partition ID from partition file name @@ -2413,11 +2318,11 @@ mod tests { .unwrap(); // Load both indexes - let full_index = BTreeIndex::load(full_store.clone(), None, LanceCache::no_cache()) + let full_index = BTreeIndex::load(full_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); - let merged_index = BTreeIndex::load(fragment_store.clone(), None, LanceCache::no_cache()) + let merged_index = BTreeIndex::load(fragment_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); @@ -2624,11 +2529,11 @@ mod tests { .unwrap(); // Load both indexes - let full_index = BTreeIndex::load(full_store.clone(), None, LanceCache::no_cache()) + let full_index = BTreeIndex::load(full_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); - let merged_index = BTreeIndex::load(fragment_store.clone(), None, LanceCache::no_cache()) + let merged_index = BTreeIndex::load(fragment_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); From 9a2c3fef55c2c8a28400f510d53b5553515f92cd Mon Sep 17 00:00:00 2001 From: xloya Date: Fri, 12 Sep 2025 17:37:37 +0800 Subject: [PATCH 11/13] remove useless code --- rust/lance-index/src/scalar.rs | 9 ----- rust/lance-index/src/scalar/lance_format.rs | 38 --------------------- 2 files changed, 47 deletions(-) diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index 
5ac0a0fce10..deeec8f7876 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -13,7 +13,6 @@ use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_common::{scalar::ScalarValue, Column}; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; -use std::pin::Pin; use std::{any::Any, ops::Bound, sync::Arc}; use datafusion_expr::expr::ScalarFunction; @@ -44,7 +43,6 @@ pub mod zonemap; use crate::frag_reuse::FragReuseIndex; pub use inverted::tokenizer::InvertedIndexParams; use lance_datafusion::udf::CONTAINS_TOKENS_UDF; -use lance_io::stream::RecordBatchStream; pub const LANCE_SCALAR_INDEX: &str = "__lance_scalar_index"; @@ -190,13 +188,6 @@ pub trait IndexReader: Send + Sync { range: std::ops::Range, projection: Option<&[&str]>, ) -> Result; - /// Reads data from the file as a stream of record batches - async fn read_stream( - &self, - batch_size: u32, - batch_readahead: u32, - projection: Option<&[&str]>, - ) -> Result>>; /// Return the number of batches in the file async fn num_batches(&self, batch_size: u64) -> u32; /// Return the number of rows in the file diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs index 7471d62c9e3..e09813ee321 100644 --- a/rust/lance-index/src/scalar/lance_format.rs +++ b/rust/lance-index/src/scalar/lance_format.rs @@ -18,14 +18,12 @@ use lance_file::{ writer::{FileWriter, ManifestProvider}, }; use lance_io::scheduler::{ScanScheduler, SchedulerConfig}; -use lance_io::stream::RecordBatchStream; use lance_io::utils::CachedFileSize; use lance_io::{object_store::ObjectStore, ReadBatchParams}; use lance_table::format::SelfDescribingFileReader; use object_store::path::Path; use std::cmp::min; use std::collections::HashMap; -use std::pin::Pin; use std::{any::Any, sync::Arc}; /// An index store that serializes scalar indices using the lance format @@ -127,15 +125,6 @@ impl IndexReader for FileReader { self.read_range(range, &projection).await } - async fn read_stream( - &self, - _batch_size: u32, - _batch_readahead: u32, - _projection: Option<&[&str]>, - ) -> Result>> { - unimplemented!("Unsupported operation in IndexReader for FileReader."); - } - async fn num_batches(&self, _batch_size: u64) -> u32 { self.num_batches() as u32 } @@ -194,33 +183,6 @@ impl IndexReader for v2::reader::FileReader { Ok(batches[0].clone()) } - async fn read_stream( - &self, - batch_size: u32, - batch_readahead: u32, - projection: Option<&[&str]>, - ) -> Result>> { - let projection = if let Some(projection) = projection { - v2::reader::ReaderProjection::from_column_names( - self.metadata().version(), - self.schema(), - projection, - )? - } else { - v2::reader::ReaderProjection::from_whole_schema( - self.schema(), - self.metadata().version(), - ) - }; - self.read_stream_projected( - ReadBatchParams::RangeFull, - batch_size, - batch_readahead, - projection, - FilterExpression::no_filter(), - ) - } - // V2 format has removed the row group concept, // so here we assume each batch is with 4096 rows. 
async fn num_batches(&self, batch_size: u64) -> u32 { From 98b7ac2b7a23d10f1d07f1665171dcfc7894917d Mon Sep 17 00:00:00 2001 From: xloya Date: Mon, 15 Sep 2025 14:26:16 +0800 Subject: [PATCH 12/13] simplify code --- rust/lance-index/src/scalar/bloomfilter.rs | 1 + rust/lance-index/src/scalar/btree.rs | 13 +++---------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/rust/lance-index/src/scalar/bloomfilter.rs b/rust/lance-index/src/scalar/bloomfilter.rs index b5b5a1a58cc..65761dd79ce 100644 --- a/rust/lance-index/src/scalar/bloomfilter.rs +++ b/rust/lance-index/src/scalar/bloomfilter.rs @@ -1208,6 +1208,7 @@ impl ScalarIndexPlugin for BloomFilterIndexPlugin { data: SendableRecordBatchStream, index_store: &dyn IndexStore, request: Box, + _fragment_ids: Option>, ) -> Result { let request = (request as Box) .downcast::() diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index c579b23b3ef..304b31716d9 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -1655,12 +1655,12 @@ async fn merge_pages( part_lookup_files.len() ); - let mut streams: Vec = Vec::new(); - let value_field = arrow_schema.field(0).clone().with_name(VALUE_COLUMN_NAME); let row_id_field = arrow_schema.field(1).clone().with_name(ROW_ID); let stream_schema = Arc::new(Schema::new(vec![value_field, row_id_field])); + // Create execution plans for each stream + let mut inputs: Vec> = Vec::new(); for lookup_file in part_lookup_files { let partition_id = extract_partition_id(lookup_file)?; let page_file_name = @@ -1683,14 +1683,7 @@ async fn merge_pages( let sendable_stream = Box::pin(RecordBatchStreamAdapter::new(stream_schema.clone(), stream)); - streams.push(sendable_stream); - } - - // Create execution plans for each stream - let mut inputs: Vec> = Vec::new(); - for stream in streams { - let plan = Arc::new(OneShotExec::new(stream)); - inputs.push(plan); + inputs.push(Arc::new(OneShotExec::new(sendable_stream))); } // Create Union execution plan to combine all partitions From ab7fe61e6281b851ac854268f665edfe5ddc00cb Mon Sep 17 00:00:00 2001 From: xloya Date: Mon, 15 Sep 2025 15:42:52 +0800 Subject: [PATCH 13/13] update --- rust/lance-index/src/scalar/bloomfilter.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rust/lance-index/src/scalar/bloomfilter.rs b/rust/lance-index/src/scalar/bloomfilter.rs index 65761dd79ce..49dd64c03b6 100644 --- a/rust/lance-index/src/scalar/bloomfilter.rs +++ b/rust/lance-index/src/scalar/bloomfilter.rs @@ -1210,6 +1210,7 @@ impl ScalarIndexPlugin for BloomFilterIndexPlugin { request: Box, _fragment_ids: Option>, ) -> Result { + assert!(_fragment_ids.is_none()); let request = (request as Box) .downcast::() .map_err(|_| Error::InvalidInput {