diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index a87f4e24d06..94605239897 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -63,7 +63,7 @@ jobs:
         sudo apt install -y protobuf-compiler libssl-dev
     - name: Get features
       run: |
-        ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -`
+        ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | sort | uniq | paste -s -d "," -`
         echo "ALL_FEATURES=${ALL_FEATURES}" >> $GITHUB_ENV
     - name: Clippy
       run: cargo clippy --profile ci --locked --features ${{ env.ALL_FEATURES }} --all-targets -- -D warnings
@@ -104,7 +104,7 @@ jobs:
       uses: taiki-e/install-action@cargo-llvm-cov
     - name: Run tests
       run: |
-        ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -`
+        ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v -e protoc -e slow_tests | sort | uniq | paste -s -d "," -`
         cargo +nightly llvm-cov --profile ci --locked --workspace --codecov --output-path coverage.codecov --features ${ALL_FEATURES}
     - name: Upload coverage to Codecov
       uses: codecov/codecov-action@v4
@@ -131,14 +131,41 @@ jobs:
         sudo apt install -y protobuf-compiler libssl-dev pkg-config
     - name: Build tests
       run: |
-        ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -`
+        ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v -e protoc -e slow_tests | sort | uniq | paste -s -d "," -`
         cargo test --profile ci --locked --features ${ALL_FEATURES} --no-run
     - name: Start DynamoDB and S3
       run: docker compose -f docker-compose.yml up -d --wait
     - name: Run tests
       run: |
-        ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -`
+        ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v -e protoc -e slow_tests | sort | uniq | paste -s -d "," -`
         cargo test --profile ci --locked --features ${ALL_FEATURES}
+  query-integration-tests:
+    runs-on: warp-ubuntu-latest-x64-4x
+    timeout-minutes: 75
+    env:
+      # We use opt-level 1, which makes some tests 5x faster to run.
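+      # (-C debuginfo=1 keeps line tables for readable backtraces without the cost of full debug info.)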
+      RUSTFLAGS: "-C debuginfo=1 -C opt-level=1"
+    steps:
+    - uses: actions/checkout@v4
+    - name: Setup rust toolchain
+      run: |
+        rustup toolchain install stable
+        rustup default stable
+    - uses: rui314/setup-mold@v1
+    - uses: Swatinem/rust-cache@v2
+      with:
+        cache-targets: false
+        cache-workspace-crates: true
+    - name: Install dependencies
+      run: |
+        sudo apt -y -qq update
+        sudo apt install -y protobuf-compiler libssl-dev pkg-config
+    - name: Build query integration tests
+      run: |
+        cargo build --locked -p lance --no-default-features --features fp16kernels,slow_tests --tests --test integration_tests
+    - name: Run query integration tests
+      run: |
+        cargo test --locked -p lance --no-default-features --features fp16kernels,slow_tests --test integration_tests
   build-no-lock:
     runs-on: warp-ubuntu-latest-x64-8x
     timeout-minutes: 30
@@ -158,7 +185,7 @@ jobs:
         sudo apt install -y protobuf-compiler libssl-dev
     - name: Build all
       run: |
-        ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -`
+        ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v -e protoc -e slow_tests | sort | uniq | paste -s -d "," -`
         cargo build --profile ci --benches --features ${ALL_FEATURES} --tests
   mac-build:
     runs-on: warp-macos-14-arm64-6x
@@ -242,5 +269,5 @@ jobs:
         rustup default ${{ matrix.msrv }}
     - name: cargo +${{ matrix.msrv }} check
       run: |
-        ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -`
+        ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v -e protoc -e slow_tests | sort | uniq | paste -s -d "," -`
         cargo check --profile ci --workspace --tests --benches --features ${ALL_FEATURES}
diff --git a/Cargo.lock b/Cargo.lock
index f89d5c74c1a..95923c70bd7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4460,6 +4460,7 @@ dependencies = [
  "mock_instant",
  "moka",
  "object_store",
+ "paste",
  "permutation",
  "pin-project",
  "pprof",
diff --git a/rust/lance/Cargo.toml b/rust/lance/Cargo.toml
index c422a5bcf45..25c30230b35 100644
--- a/rust/lance/Cargo.toml
+++ b/rust/lance/Cargo.toml
@@ -106,6 +106,7 @@ test-log.workspace = true
 tracing-chrome = "0.7.1"
 rstest = { workspace = true }
 tracking-allocator = { version = "0.4", features = ["tracing-compat"] }
+paste = "1.0"
 # For S3 / DynamoDB tests
 aws-config = { workspace = true }
 aws-sdk-s3 = { workspace = true }
@@ -133,6 +134,8 @@ gcp = ["lance-io/gcp"]
 azure = ["lance-io/azure"]
 oss = ["lance-io/oss"]
 huggingface = ["lance-io/huggingface"]
+# Enable slow integration tests (excluded from ALL_FEATURES in CI; run in a dedicated job)
+slow_tests = []

 [[bin]]
 name = "lq"
diff --git a/rust/lance/src/index/create.rs b/rust/lance/src/index/create.rs
index 2724b3a3cb4..fdac7395dff 100644
--- a/rust/lance/src/index/create.rs
+++ b/rust/lance/src/index/create.rs
@@ -229,7 +229,16 @@ impl<'a> CreateIndexBuilder<'a> {
                 )
                 .await?
             }
-            (IndexType::Vector, LANCE_VECTOR_INDEX) => {
+            (
+                IndexType::Vector
+                | IndexType::IvfPq
+                | IndexType::IvfSq
+                | IndexType::IvfFlat
+                | IndexType::IvfHnswFlat
+                | IndexType::IvfHnswPq
+                | IndexType::IvfHnswSq,
+                LANCE_VECTOR_INDEX,
+            ) => {
                 // Vector index params.
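+                // (All of the concrete vector index types above funnel into the same parameter handling below.)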
                 let vec_params = self
                     .params
diff --git a/rust/lance/tests/integration_tests.rs b/rust/lance/tests/integration_tests.rs
new file mode 100644
index 00000000000..81c2535dd9c
--- /dev/null
+++ b/rust/lance/tests/integration_tests.rs
@@ -0,0 +1,9 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+// NOTE: we only create one integration test binary, to keep compilation overhead down.
+
+#[cfg(feature = "slow_tests")]
+mod query;
+#[cfg(feature = "slow_tests")]
+mod utils;
diff --git a/rust/lance/tests/query/mod.rs b/rust/lance/tests/query/mod.rs
new file mode 100644
index 00000000000..5816d786f89
--- /dev/null
+++ b/rust/lance/tests/query/mod.rs
@@ -0,0 +1,176 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+use std::sync::Arc;
+
+use arrow_array::{cast::AsArray, types::Float32Type, RecordBatch, UInt32Array};
+use arrow_select::concat::concat_batches;
+use datafusion::datasource::MemTable;
+use datafusion::prelude::SessionContext;
+use lance::dataset::scanner::ColumnOrdering;
+use lance::Dataset;
+use lance_datafusion::udf::register_functions;
+
+/// Creates a fresh SessionContext with Lance UDFs registered.
+fn create_datafusion_context() -> SessionContext {
+    let ctx = SessionContext::new();
+    register_functions(&ctx);
+    ctx
+}
+
+mod primitives;
+mod vectors;
+
+/// Scanning and ordering by id should give the same result as the original.
+async fn test_scan(original: &RecordBatch, ds: &Dataset) {
+    let mut scanner = ds.scan();
+    scanner
+        .order_by(Some(vec![ColumnOrdering::asc_nulls_first(
+            "id".to_string(),
+        )]))
+        .unwrap();
+    let scanned = scanner.try_into_batch().await.unwrap();
+
+    assert_eq!(original, &scanned);
+}
+
+/// Taking specific rows should give the same result as taking from the original.
+async fn test_take(original: &RecordBatch, ds: &Dataset) {
+    let num_rows = original.num_rows();
+    let cases: Vec<Vec<usize>> = vec![
+        vec![0, 1, 2],                    // First few rows
+        vec![5, 3, 1],                    // Out of order
+        vec![0],                          // Single row
+        vec![],                           // Empty
+        (0..num_rows.min(10)).collect(),  // Sequential
+        vec![num_rows - 1, 0],            // Last and first
+        vec![1, 1, 2],                    // Duplicate indices
+        vec![0, 0, 0],                    // All same index
+        vec![num_rows - 1, num_rows - 1], // Duplicate of last row
+    ];
+
+    for indices in cases {
+        // Convert to u64 for Lance take
+        let indices_u64: Vec<u64> = indices.iter().map(|&i| i as u64).collect();
+
+        let taken_ds = ds.take(&indices_u64, ds.schema().clone()).await.unwrap();
+
+        // Take from RecordBatch using arrow::compute
+        let indices_u32: Vec<u32> = indices.iter().map(|&i| i as u32).collect();
+        let indices_array = UInt32Array::from(indices_u32);
+        let taken_rb = arrow::compute::take_record_batch(original, &indices_array).unwrap();
+
+        assert_eq!(
+            taken_rb, taken_ds,
+            "Take results don't match for indices: {:?}",
+            indices
+        );
+    }
+}
+
+/// Querying with a filter should give the same result as filtering the original
+/// record batch in DataFusion.
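+/// (The predicate is evaluated twice: once through Lance's scanner and once as
+/// DataFusion SQL over a MemTable; the two results must match row-for-row.)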
+async fn test_filter(original: &RecordBatch, ds: &Dataset, predicate: &str) {
+    // Scan with filter and order
+    let mut scanner = ds.scan();
+    scanner
+        .filter(predicate)
+        .unwrap()
+        .order_by(Some(vec![ColumnOrdering::asc_nulls_first(
+            "id".to_string(),
+        )]))
+        .unwrap();
+    let scanned = scanner.try_into_batch().await.unwrap();
+
+    let ctx = create_datafusion_context();
+    let table = MemTable::try_new(original.schema(), vec![vec![original.clone()]]).unwrap();
+    ctx.register_table("t", Arc::new(table)).unwrap();
+
+    let sql = format!("SELECT * FROM t WHERE {} ORDER BY id", predicate);
+    let df = ctx.sql(&sql).await.unwrap();
+    let expected_batches = df.collect().await.unwrap();
+    let expected = concat_batches(&original.schema(), &expected_batches).unwrap();
+
+    assert_eq!(&expected, &scanned);
+}
+
+/// Test that an exhaustive ANN query gives the same results as brute-force
+/// KNN against the original batch.
+///
+/// By exhaustive ANN, we mean we search all the partitions so we get perfect recall.
+async fn test_ann(original: &RecordBatch, ds: &Dataset, column: &str, predicate: Option<&str>) {
+    // Extract the first vector from the column as the query vector
+    let vector_column = original.column_by_name(column).unwrap();
+    let fixed_size_list = vector_column.as_fixed_size_list();
+
+    // Extract the first vector's values as a new array
+    let vector_values = fixed_size_list
+        .values()
+        .slice(0, fixed_size_list.value_length() as usize);
+    let query_vector = vector_values;
+
+    let mut scanner = ds.scan();
+    scanner
+        .nearest(column, query_vector.as_ref(), 10)
+        .unwrap()
+        .prefilter(true)
+        .refine(2);
+    if let Some(pred) = predicate {
+        scanner.filter(pred).unwrap();
+    }
+    let result = scanner.try_into_batch().await.unwrap();
+
+    // Use DataFusion to apply the same vector search using SQL
+    let ctx = create_datafusion_context();
+    let table = MemTable::try_new(original.schema(), vec![vec![original.clone()]]).unwrap();
+    ctx.register_table("t", Arc::new(table)).unwrap();
+
+    // Convert the query vector to a SQL array literal
+    let float_array = query_vector.as_primitive::<Float32Type>();
+    let vector_values_str = float_array
+        .values()
+        .iter()
+        .map(|v| v.to_string())
+        .collect::<Vec<_>>()
+        .join(", ");
+
+    // DataFusion's built-in `array_distance` function uses L2 distance.
+    let sql = format!(
+        "SELECT * FROM t {} ORDER BY array_distance(t.{}, [{}]) LIMIT 10",
+        if let Some(pred) = predicate {
+            format!("WHERE {}", pred)
+        } else {
+            String::new()
+        },
+        column,
+        vector_values_str
+    );
+
+    let df = ctx.sql(&sql).await.unwrap();
+    let expected_batches = df.collect().await.unwrap();
+    let expected = concat_batches(&original.schema(), &expected_batches).unwrap();
+
+    // Compare only the main data (excluding the _distance column which Lance adds).
+    // We validate that both return the same number of rows and the same row ordering.
+    // Note: We don't validate the _distance column values because:
+    //   1. ANN indices provide approximate distances, not exact values
+    //   2. Some distance functions return ordering values (e.g., squared Euclidean
+    //      without the final sqrt step) rather than true distances
+    assert_eq!(
+        expected.num_rows(),
+        result.num_rows(),
+        "Different number of results"
+    );
+
+    // Compare all of the original columns (the Lance result may carry an extra _distance column)
+    for (col_idx, field) in original.schema().fields().iter().enumerate() {
+        let expected_col = expected.column(col_idx);
+        let result_col = result.column(col_idx);
+        assert_eq!(
+            expected_col,
+            result_col,
+            "Column '{}' differs between DataFusion and Lance results",
+            field.name()
+        );
+    }
+}
diff --git a/rust/lance/tests/query/primitives.rs b/rust/lance/tests/query/primitives.rs
new file mode 100644
index 00000000000..cb978074d37
--- /dev/null
+++ b/rust/lance/tests/query/primitives.rs
@@ -0,0 +1,92 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+use arrow::datatypes::*;
+use arrow_array::RecordBatch;
+use arrow_schema::DataType;
+use lance::Dataset;
+
+use lance_datagen::{array, gen_batch, ArrayGeneratorExt, RowCount};
+use lance_index::IndexType;
+
+use super::{test_filter, test_scan, test_take};
+use crate::utils::DatasetTestCases;
+
+#[tokio::test]
+async fn test_query_bool() {
+    let batch = gen_batch()
+        .col("id", array::step::<Int32Type>())
+        .col(
+            "value",
+            array::cycle_bool(vec![true, false]).with_random_nulls(0.1),
+        )
+        .into_batch_rows(RowCount::from(60))
+        .unwrap();
+    DatasetTestCases::from_data(batch)
+        .with_index_types(
+            "value",
+            // TODO: fix bug with bitmap and btree https://github.com/lancedb/lance/issues/4756
+            // TODO: fix bug with zone map https://github.com/lancedb/lance/issues/4758
+            // TODO: Add boolean to bloom filter supported types https://github.com/lancedb/lance/issues/4757
+            // [None, Some(IndexType::Bitmap), Some(IndexType::BTree), Some(IndexType::BloomFilter), Some(IndexType::ZoneMap)],
+            [None],
+        )
+        .run(|ds: Dataset, original: RecordBatch| async move {
+            test_scan(&original, &ds).await;
+            test_take(&original, &ds).await;
+            test_filter(&original, &ds, "value").await;
+            test_filter(&original, &ds, "NOT value").await;
+        })
+        .await
+}
+
+#[rstest::rstest]
+#[case::int8(DataType::Int8)]
+#[case::int16(DataType::Int16)]
+#[case::int32(DataType::Int32)]
+#[case::int64(DataType::Int64)]
+#[case::uint8(DataType::UInt8)]
+#[case::uint16(DataType::UInt16)]
+#[case::uint32(DataType::UInt32)]
+#[case::uint64(DataType::UInt64)]
+#[tokio::test]
+async fn test_query_integer(#[case] data_type: DataType) {
+    let value_generator = match data_type {
+        DataType::Int8 => array::rand_primitive::<Int8Type>(data_type),
+        DataType::Int16 => array::rand_primitive::<Int16Type>(data_type),
+        DataType::Int32 => array::rand_primitive::<Int32Type>(data_type),
+        DataType::Int64 => array::rand_primitive::<Int64Type>(data_type),
+        DataType::UInt8 => array::rand_primitive::<UInt8Type>(data_type),
+        DataType::UInt16 => array::rand_primitive::<UInt16Type>(data_type),
+        DataType::UInt32 => array::rand_primitive::<UInt32Type>(data_type),
+        DataType::UInt64 => array::rand_primitive::<UInt64Type>(data_type),
+        _ => unreachable!(),
+    };
+
+    let batch = gen_batch()
+        .col("id", array::step::<Int32Type>())
+        .col("value", value_generator.with_random_nulls(0.1))
+        .into_batch_rows(RowCount::from(60))
+        .unwrap();
+    DatasetTestCases::from_data(batch)
+        .with_index_types(
+            "value",
+            // TODO: add zone map and bloom filter once we fix https://github.com/lancedb/lance/issues/4758
+            [None, Some(IndexType::Bitmap), Some(IndexType::BTree)],
+        )
+        .run(|ds: Dataset, original: RecordBatch| async move {
+            test_scan(&original, &ds).await;
+            test_take(&original, &ds).await;
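+            // Predicates below cover range, negation, and null-handling paths.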
+            test_filter(&original, &ds, "value > 20").await;
+            test_filter(&original, &ds, "NOT (value > 20)").await;
+            test_filter(&original, &ds, "value is null").await;
+            test_filter(&original, &ds, "value is not null").await;
+        })
+        .await
+}
+
+// TODO: floats (including NaN, +/-Inf, +/-0)
+// TODO: decimals
+// TODO: binary
+// TODO: strings (including LargeString and StringView)
+// TODO: timestamps
diff --git a/rust/lance/tests/query/vectors.rs b/rust/lance/tests/query/vectors.rs
new file mode 100644
index 00000000000..9d8c640a7e9
--- /dev/null
+++ b/rust/lance/tests/query/vectors.rs
@@ -0,0 +1,64 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+use super::{test_ann, test_scan, test_take};
+use crate::utils::DatasetTestCases;
+use arrow::datatypes::{Date32Type, Float32Type, Int32Type};
+use arrow_array::RecordBatch;
+use lance::Dataset;
+use lance_datagen::{array, gen_batch, ArrayGeneratorExt, Dimension, RowCount};
+use lance_index::IndexType;
+
+fn date_as_i32(date: &str) -> i32 {
+    // Return as i32 days since the Unix epoch.
+    use chrono::{NaiveDate, TimeZone, Utc};
+
+    let parsed_date =
+        NaiveDate::parse_from_str(date, "%Y-%m-%d").expect("Date should be in YYYY-MM-DD format");
+
+    let unix_epoch = Utc.timestamp_opt(0, 0).unwrap().date_naive();
+
+    (parsed_date - unix_epoch).num_days() as i32
+}
+
+#[tokio::test]
+async fn test_query_prefilter_date() {
+    let batch = gen_batch()
+        .col("id", array::step::<Int32Type>())
+        .col(
+            "value",
+            array::step_custom::<Date32Type>(date_as_i32("2020-01-01"), 1).with_random_nulls(0.1),
+        )
+        .col("vec", array::rand_vec::<Float32Type>(Dimension::from(16)))
+        .into_batch_rows(RowCount::from(256))
+        .unwrap();
+    DatasetTestCases::from_data(batch)
+        .with_index_types("value", [None, Some(IndexType::BTree)])
+        .with_index_types(
+            "vec",
+            [
+                None,
+                Some(IndexType::IvfPq),
+                Some(IndexType::IvfSq),
+                Some(IndexType::IvfFlat),
+                // TODO: HNSW results are very flaky.
+                // Some(IndexType::IvfHnswFlat),
+                // Some(IndexType::IvfHnswPq),
+                // Some(IndexType::IvfHnswSq),
+            ],
+        )
+        .run(|ds: Dataset, original: RecordBatch| async move {
+            test_scan(&original, &ds).await;
+            test_take(&original, &ds).await;
+            test_ann(&original, &ds, "vec", None).await;
+            test_ann(&original, &ds, "vec", Some("value is not null")).await;
+            test_ann(
+                &original,
+                &ds,
+                "vec",
+                Some("value >= DATE '2020-01-03' AND value <= DATE '2020-01-25'"),
+            )
+            .await;
+        })
+        .await
+}
diff --git a/rust/lance/tests/utils/mod.rs b/rust/lance/tests/utils/mod.rs
new file mode 100644
index 00000000000..9ef9f39b10d
--- /dev/null
+++ b/rust/lance/tests/utils/mod.rs
@@ -0,0 +1,304 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+use std::panic::AssertUnwindSafe;
+use std::sync::Arc;
+
+use arrow_array::{ArrayRef, Int32Array, RecordBatch};
+use futures::FutureExt;
+use lance::index::vector::VectorIndexParams;
+use lance::{
+    dataset::{InsertBuilder, WriteParams},
+    Dataset,
+};
+use lance_index::scalar::ScalarIndexParams;
+use lance_index::vector::hnsw::builder::HnswBuildParams;
+use lance_index::vector::ivf::IvfBuildParams;
+use lance_index::vector::pq::PQBuildParams;
+use lance_index::vector::sq::builder::SQBuildParams;
+use lance_index::{DatasetIndexExt, IndexParams, IndexType};
+use lance_linalg::distance::{DistanceType, MetricType};
+
+#[derive(Clone, Copy, Debug)]
+pub enum Fragmentation {
+    /// All data in a single file.
+    SingleFragment,
+    /// Data is spread across multiple fragments, one file per fragment.
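+    /// (`build_dataset` caps files at 3 rows in this mode, so even small batches span several fragments.)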
+    MultiFragment,
+}
+
+#[derive(Clone, Copy, Debug)]
+pub enum DeletionState {
+    /// No deletions are applied.
+    NoDeletions,
+    /// Delete odd rows.
+    DeleteOdd,
+    /// Delete even rows.
+    DeleteEven,
+}
+
+pub struct DatasetTestCases {
+    original: RecordBatch,
+    index_options: Vec<(String, Vec<Option<IndexType>>)>,
+}
+
+impl DatasetTestCases {
+    pub fn from_data(original: RecordBatch) -> Self {
+        Self {
+            original,
+            index_options: Vec::new(),
+        }
+    }
+
+    pub fn with_index_types(
+        mut self,
+        column: impl Into<String>,
+        index_types: impl IntoIterator<Item = Option<IndexType>>,
+    ) -> Self {
+        self.index_options
+            .push((column.into(), index_types.into_iter().collect()));
+        self
+    }
+
+    fn generate_index_combinations(&self) -> Vec<Vec<(&str, IndexType)>> {
+        if self.index_options.is_empty() {
+            return vec![vec![]];
+        }
+
+        fn generate_recursive<'a>(
+            options: &'a [(String, Vec<Option<IndexType>>)],
+            current_idx: usize,
+            current_combination: Vec<(&'a str, IndexType)>,
+            results: &mut Vec<Vec<(&'a str, IndexType)>>,
+        ) {
+            if current_idx == options.len() {
+                // Only add non-empty combinations (filter out the all-None case)
+                if !current_combination.is_empty() {
+                    results.push(current_combination);
+                }
+                return;
+            }
+
+            let (column, index_types) = &options[current_idx];
+
+            // Try each index type for this column (including None)
+            for index_type_opt in index_types {
+                let mut next_combination = current_combination.clone();
+                if let Some(index_type) = index_type_opt {
+                    next_combination.push((column.as_str(), *index_type));
+                }
+                generate_recursive(options, current_idx + 1, next_combination, results);
+            }
+        }
+
+        let mut results = Vec::new();
+        generate_recursive(&self.index_options, 0, Vec::new(), &mut results);
+        results
+    }
+
+    pub async fn run<F, Fut>(self, test_fn: F) -> Fut::Output
+    where
+        F: Fn(Dataset, RecordBatch) -> Fut,
+        Fut: std::future::Future<Output = ()>,
+    {
+        for fragmentation in [Fragmentation::SingleFragment, Fragmentation::MultiFragment] {
+            for deletion in [
+                DeletionState::NoDeletions,
+                DeletionState::DeleteOdd,
+                DeletionState::DeleteEven,
+            ] {
+                let index_combinations = self.generate_index_combinations();
+                for indices in index_combinations {
+                    let ds =
+                        build_dataset(self.original.clone(), fragmentation, deletion, &indices)
+                            .await;
+                    let context = format!(
+                        "fragmentation: {:?}, deletion: {:?}, index: {:?}",
+                        fragmentation, deletion, indices
+                    );
+                    // Catch unwind so we can add test context to the panic.
+                    AssertUnwindSafe(test_fn(ds, self.original.clone()))
+                        .catch_unwind()
+                        .await
+                        .unwrap_or_else(|_| panic!("Test failed for {}", context));
+                }
+            }
+        }
+    }
+}
+
+/// Create an in-memory dataset with the given state and data.
+///
+/// The data in the dataset will exactly match the `original` batch. (Extra filler
+/// rows are inserted and then deleted to realize the requested `DeletionState`.)
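+///
+/// Indices are created only after the filler rows have been deleted, so each index
+/// is built over a dataset whose rows match `original`.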
+async fn build_dataset(
+    original: RecordBatch,
+    fragmentation: Fragmentation,
+    deletion: DeletionState,
+    indices: &[(&str, IndexType)],
+) -> Dataset {
+    let data_to_write = fill_deleted_rows(&original, deletion);
+
+    let max_rows_per_file = if let Fragmentation::MultiFragment = fragmentation {
+        3
+    } else {
+        1_000_000
+    };
+
+    let mut ds = InsertBuilder::new("memory://")
+        .with_params(&WriteParams {
+            max_rows_per_file,
+            ..Default::default()
+        })
+        .execute(vec![data_to_write])
+        .await
+        .expect("Failed to create test dataset");
+
+    ds.delete("id = -1")
+        .await
+        .expect("Failed to delete filler rows (id = -1)");
+
+    assert_eq!(ds.count_rows(None).await.unwrap(), original.num_rows());
+
+    for (column, index_type) in indices.iter() {
+        // TODO: when possible, make indices cover a portion of rows and not be
+        // aligned between indices.
+
+        // Index parameters are chosen to make search results deterministic for small
+        // test datasets, not for production use.
+        let index_params: Box<dyn IndexParams> = match index_type {
+            IndexType::BTree
+            | IndexType::Bitmap
+            | IndexType::LabelList
+            | IndexType::NGram
+            | IndexType::ZoneMap
+            | IndexType::Inverted
+            | IndexType::BloomFilter => Box::new(ScalarIndexParams::for_builtin(
+                (*index_type).try_into().unwrap(),
+            )),
+            IndexType::IvfFlat => {
+                // Use a small number of partitions for testing
+                Box::new(VectorIndexParams::ivf_flat(2, MetricType::L2))
+            }
+            IndexType::IvfPq => {
+                // Simple PQ params for testing
+                Box::new(VectorIndexParams::ivf_pq(2, 8, 2, MetricType::L2, 10))
+            }
+            IndexType::IvfSq => Box::new(VectorIndexParams::with_ivf_sq_params(
+                DistanceType::L2,
+                IvfBuildParams::new(2),
+                SQBuildParams::default(),
+            )),
+            IndexType::IvfHnswFlat => Box::new(VectorIndexParams::with_ivf_flat_params(
+                DistanceType::L2,
+                IvfBuildParams::new(2),
+            )),
+            IndexType::IvfHnswPq => Box::new(VectorIndexParams::with_ivf_hnsw_pq_params(
+                DistanceType::L2,
+                IvfBuildParams::new(2),
+                HnswBuildParams::default().ef_construction(200),
+                PQBuildParams::new(2, 8),
+            )),
+            IndexType::IvfHnswSq => Box::new(VectorIndexParams::with_ivf_hnsw_sq_params(
+                DistanceType::L2,
+                IvfBuildParams::new(2),
+                HnswBuildParams::default().ef_construction(200),
+                SQBuildParams::default(),
+            )),
+            _ => {
+                // For other index types, use default scalar params
+                Box::new(ScalarIndexParams::default())
+            }
+        };
+
+        ds.create_index_builder(&[column], *index_type, index_params.as_ref())
+            .await
+            .unwrap_or_else(|e| {
+                panic!(
+                    "Failed to create index on column '{}' with type {:?}: {}",
+                    column, index_type, e
+                )
+            });
+    }
+
+    ds
+}
+
+/// Insert filler rows into a record batch such that applying the deletions to the
+/// output will yield the input. For example, given `deletions: DeletionState::DeleteOdd`
+/// and the table:
+///
+/// ```text
+/// id | value
+///  1 | "a"
+///  2 | "b"
+/// ```
+///
+/// Produce:
+///
+/// ```text
+/// id | value
+/// -1 | "a"   (filler row)
+///  1 | "a"
+/// -1 | "a"   (filler row)
+///  2 | "b"
+/// ```
+///
+/// Every filler row copies the values of the batch's first row, with the special
+/// identifier -1 in the `id` column marking it as a filler row.
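+///
+/// Assumes the batch has an Int32 `id` column that never contains -1 on its own.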
+fn fill_deleted_rows(batch: &RecordBatch, deletions: DeletionState) -> RecordBatch {
+    // Early return for no deletions
+    if let DeletionState::NoDeletions = deletions {
+        return batch.clone();
+    }
+
+    // Create a filler batch by taking the first row and replacing id with -1
+    let schema = batch.schema();
+    let mut filler_columns: Vec<ArrayRef> = Vec::new();
+
+    for (i, field) in schema.fields().iter().enumerate() {
+        if field.name() == "id" {
+            // Create an array with a single -1 value
+            filler_columns.push(Arc::new(Int32Array::from(vec![-1])));
+        } else {
+            // Take the first value from the original column
+            let original_column = batch.column(i);
+            let sliced = original_column.slice(0, 1);
+            filler_columns.push(sliced);
+        }
+    }
+
+    let filler_batch = RecordBatch::try_new(schema.clone(), filler_columns).unwrap();
+
+    // Create an array of filler batches, one for each row that will be deleted
+    let num_rows = batch.num_rows();
+    let filler_batches = vec![filler_batch; num_rows];
+
+    // Concatenate all filler batches into one
+    let all_fillers = arrow_select::concat::concat_batches(&schema, &filler_batches).unwrap();
+
+    // Create indices for interleaving based on the deletion pattern.
+    // Format: (batch_index, row_index) where batch_index 0 = original, 1 = fillers
+    let mut indices: Vec<(usize, usize)> = Vec::new();
+
+    match deletions {
+        DeletionState::DeleteOdd => {
+            // Pattern: filler, original[0], filler, original[1], ...
+            for i in 0..num_rows {
+                indices.push((1, i)); // filler batch, row i
+                indices.push((0, i)); // original batch, row i
+            }
+        }
+        DeletionState::DeleteEven => {
+            // Pattern: original[0], filler, original[1], filler, ...
+            for i in 0..num_rows {
+                indices.push((0, i)); // original batch, row i
+                indices.push((1, i)); // filler batch, row i
+            }
+        }
+        DeletionState::NoDeletions => unreachable!(),
+    }
+
+    // Use interleave to reorder according to our indices
+    arrow::compute::interleave_record_batch(&[batch, &all_fillers], &indices).unwrap()
+}
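A sketch of how a follow-up test module could plug into this harness, modeled on `primitives.rs` above. The `test_query_float` case is hypothetical (floats are still a TODO in this diff), and it assumes `array::rand_primitive::<Float32Type>` accepts `DataType::Float32` the same way the integer arms do:

```rust
// Hypothetical example, not part of this diff: a float variant of the
// primitives tests, reusing the same harness and helpers.
#[tokio::test]
async fn test_query_float() {
    let batch = gen_batch()
        .col("id", array::step::<Int32Type>())
        .col(
            "value",
            array::rand_primitive::<Float32Type>(DataType::Float32).with_random_nulls(0.1),
        )
        .into_batch_rows(RowCount::from(60))
        .unwrap();
    DatasetTestCases::from_data(batch)
        .with_index_types("value", [None, Some(IndexType::BTree)])
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            test_filter(&original, &ds, "value > 0.5").await;
            test_filter(&original, &ds, "value is null").await;
        })
        .await
}
```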