diff --git a/java/lance-jni/src/blocking_scanner.rs b/java/lance-jni/src/blocking_scanner.rs index 8c6bc402544..9400b1111c6 100644 --- a/java/lance-jni/src/blocking_scanner.rs +++ b/java/lance-jni/src/blocking_scanner.rs @@ -5,14 +5,16 @@ use std::sync::Arc; use crate::error::{Error, Result}; use crate::ffi::JNIEnvExt; -use crate::traits::{import_vec_from_method, import_vec_to_rust}; +use crate::traits::{export_vec, import_vec_from_method, import_vec_to_rust, IntoJava}; use arrow::array::Float32Array; use arrow::{ffi::FFI_ArrowSchema, ffi_stream::FFI_ArrowArrayStream}; use arrow_schema::SchemaRef; -use jni::objects::{JObject, JString}; +use jni::objects::{JObject, JString, JValueGen}; use jni::sys::{jboolean, jint, JNI_TRUE}; use jni::{sys::jlong, JNIEnv}; -use lance::dataset::scanner::{ColumnOrdering, DatasetRecordBatchStream, Scanner}; +use lance::dataset::scanner::{ + ColumnOrdering, DatasetRecordBatchStream, Scanner, Split, SplitFragment, SplitOptions, +}; use lance_index::scalar::inverted::query::{ BooleanQuery as FtsBooleanQuery, BoostQuery as FtsBoostQuery, FtsQuery, MatchQuery as FtsMatchQuery, MultiMatchQuery as FtsMultiMatchQuery, Occur as FtsOccur, @@ -24,7 +26,6 @@ use lance_linalg::distance::DistanceType; use crate::{ blocking_dataset::{BlockingDataset, NATIVE_DATASET}, - traits::IntoJava, RT, }; @@ -56,6 +57,11 @@ impl BlockingScanner { let res = RT.block_on(self.inner.count_rows())?; Ok(res) } + + pub fn plan_splits(&self, options: Option) -> Result> { + let res = RT.block_on(self.inner.plan_splits(options))?; + Ok(res) + } } fn build_full_text_search_query<'a>(env: &mut JNIEnv<'a>, java_obj: JObject) -> Result { @@ -481,3 +487,80 @@ fn inner_count_rows(env: &mut JNIEnv, j_scanner: JObject) -> Result { unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }?; scanner_guard.count_rows() } + +#[no_mangle] +pub extern "system" fn Java_org_lance_ipc_LanceScanner_nativePlanSplits<'local>( + mut env: JNIEnv<'local>, + j_scanner: 
JObject, + options_obj: JObject, // Optional<SplitOptions> +) -> JObject<'local> { + ok_or_throw!(env, inner_plan_splits(&mut env, j_scanner, options_obj)) +} + +fn inner_plan_splits<'local>( + env: &mut JNIEnv<'local>, + j_scanner: JObject, + options_obj: JObject, +) -> Result<JObject<'local>> { + let options = extract_split_options(env, &options_obj)?; + let splits = { + let scanner_guard = + unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }?; + scanner_guard.plan_splits(options)? + }; + export_vec(env, &splits) +} + +fn extract_split_options(env: &mut JNIEnv, options_obj: &JObject) -> Result<Option<SplitOptions>> { + if options_obj.is_null() { + return Ok(None); + } + + let is_present = env.call_method(options_obj, "isPresent", "()Z", &[])?.z()?; + + if !is_present { + return Ok(None); + } + + let options_inner = env + .call_method(options_obj, "get", "()Ljava/lang/Object;", &[])? + .l()?; + + let max_size_bytes = env.get_optional_i64_from_method(&options_inner, "getMaxSizeBytes")?; + let max_row_count = env.get_optional_i64_from_method(&options_inner, "getMaxRowCount")?; + + Ok(Some(SplitOptions { + max_size_bytes: max_size_bytes.map(|v| v as usize), + max_row_count: max_row_count.map(|v| v as usize), + })) +} + +const SPLIT_CLASS: &str = "org/lance/ipc/Split"; +const SPLIT_CONSTRUCTOR_SIG: &str = "(Ljava/util/List;)V"; +const SPLIT_FRAGMENT_CLASS: &str = "org/lance/ipc/SplitFragment"; +const SPLIT_FRAGMENT_CONSTRUCTOR_SIG: &str = "(Lorg/lance/FragmentMetadata;J)V"; + +impl IntoJava for &SplitFragment { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result<JObject<'a>> { + let fragment = self.fragment.into_java(env)?; + Ok(env.new_object( + SPLIT_FRAGMENT_CLASS, + SPLIT_FRAGMENT_CONSTRUCTOR_SIG, + &[ + JValueGen::Object(&fragment), + JValueGen::Long(self.max_row_count as i64), + ], + )?)
+ } +} + +impl IntoJava for &Split { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result<JObject<'a>> { + let fragments = export_vec(env, &self.fragments)?; + Ok(env.new_object( + SPLIT_CLASS, + SPLIT_CONSTRUCTOR_SIG, + &[JValueGen::Object(&fragments)], + )?) + } +} diff --git a/java/src/main/java/org/lance/ipc/LanceScanner.java b/java/src/main/java/org/lance/ipc/LanceScanner.java index 804b7ea22f3..caab3ee74c3 100644 --- a/java/src/main/java/org/lance/ipc/LanceScanner.java +++ b/java/src/main/java/org/lance/ipc/LanceScanner.java @@ -164,4 +164,38 @@ public long countRows() { } private native long nativeCountRows(); + + /** + * Plan splits for parallel scanning of the dataset. + * + *

<p>Splits can be used to distribute scanning work across multiple workers or threads. Each + * split contains one or more fragments that can be scanned together. + * + * @return a list of splits for parallel scanning + */ + public List<Split> planSplits() { + return planSplits(Optional.empty()); + } + + /** + * Plan splits for parallel scanning of the dataset with custom options. + * + *

<p>Splits can be used to distribute scanning work across multiple workers or threads. Each + * split contains one or more fragments that can be scanned together. + * + * @param options options for configuring split generation + * @return a list of splits for parallel scanning + */ + public List<Split> planSplits(SplitOptions options) { + return planSplits(Optional.ofNullable(options)); + } + + private List<Split> planSplits(Optional<SplitOptions> options) { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + Preconditions.checkArgument(nativeScannerHandle != 0, "Scanner is closed"); + return nativePlanSplits(options); + } + } + + private native List<Split> nativePlanSplits(Optional<SplitOptions> options); +} diff --git a/java/src/main/java/org/lance/ipc/Split.java b/java/src/main/java/org/lance/ipc/Split.java new file mode 100644 index 00000000000..b053c59536e --- /dev/null +++ b/java/src/main/java/org/lance/ipc/Split.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.ipc; + +import com.google.common.base.MoreObjects; + +import java.io.Serializable; +import java.util.List; +import java.util.Objects; + +/** + * Represents a split for parallel scanning of fragments. + * + *

<p>A split contains one or more fragments that can be scanned together. Splits can be used to + * distribute scanning work across multiple workers or threads. + */ +public class Split implements Serializable { + private static final long serialVersionUID = 1L; + private final List<SplitFragment> fragments; + + /** + * Creates a new Split. + * + * @param fragments the list of fragments in this split + */ + public Split(List<SplitFragment> fragments) { + this.fragments = fragments; + } + + /** + * Returns the list of fragments in this split. + * + * @return the list of split fragments + */ + public List<SplitFragment> getFragments() { + return fragments; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Split split = (Split) o; + return Objects.equals(fragments, split.fragments); + } + + @Override + public int hashCode() { + return Objects.hash(fragments); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("fragments", fragments).toString(); + } +} diff --git a/java/src/main/java/org/lance/ipc/SplitFragment.java b/java/src/main/java/org/lance/ipc/SplitFragment.java new file mode 100644 index 00000000000..63ec1c981b2 --- /dev/null +++ b/java/src/main/java/org/lance/ipc/SplitFragment.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.lance.ipc; + +import org.lance.FragmentMetadata; + +import com.google.common.base.MoreObjects; + +import java.io.Serializable; +import java.util.Objects; + +/** + * A fragment within a {@link Split}, along with metadata about the expected number of rows that + * will be scanned from it. + */ +public class SplitFragment implements Serializable { + private static final long serialVersionUID = 1L; + private final FragmentMetadata fragment; + private final long maxRowCount; + + /** + * Creates a new SplitFragment. + * + * @param fragment the fragment metadata + * @param maxRowCount an upper bound on the number of rows that will be read from this fragment + * after applying any filters or index pruning + */ + public SplitFragment(FragmentMetadata fragment, long maxRowCount) { + this.fragment = fragment; + this.maxRowCount = maxRowCount; + } + + /** + * Returns the fragment metadata. + * + * @return the fragment metadata + */ + public FragmentMetadata getFragment() { + return fragment; + } + + /** + * Returns an upper bound on the number of rows that will be read from this fragment after + * applying any filters or index pruning. 
+ * + * @return the maximum row count + */ + public long getMaxRowCount() { + return maxRowCount; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SplitFragment that = (SplitFragment) o; + return maxRowCount == that.maxRowCount && Objects.equals(fragment, that.fragment); + } + + @Override + public int hashCode() { + return Objects.hash(fragment, maxRowCount); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("fragment", fragment) + .add("maxRowCount", maxRowCount) + .toString(); + } +} diff --git a/java/src/main/java/org/lance/ipc/SplitOptions.java b/java/src/main/java/org/lance/ipc/SplitOptions.java new file mode 100644 index 00000000000..bd9645bd33b --- /dev/null +++ b/java/src/main/java/org/lance/ipc/SplitOptions.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.ipc; + +import java.util.Optional; + +/** + * Options for configuring split generation in a scanner. + * + *

<p>This class allows specifying constraints on the maximum size and row count for splits. Both + * fields are optional; if neither is set, default behavior will be used. + */ +public class SplitOptions { + private final Optional<Long> maxSizeBytes; + private final Optional<Long> maxRowCount; + + private SplitOptions(Builder builder) { + this.maxSizeBytes = builder.maxSizeBytes; + this.maxRowCount = builder.maxRowCount; + } + + /** + * Returns the maximum size in bytes per split. + * + * @return the maximum size in bytes, or empty if not set + */ + public Optional<Long> getMaxSizeBytes() { + return maxSizeBytes; + } + + /** + * Returns the maximum number of rows per split. + * + * @return the maximum row count, or empty if not set + */ + public Optional<Long> getMaxRowCount() { + return maxRowCount; + } + + /** Builder for {@link SplitOptions}. */ + public static class Builder { + private Optional<Long> maxSizeBytes = Optional.empty(); + private Optional<Long> maxRowCount = Optional.empty(); + + /** + * Sets the maximum size in bytes per split. + * + *

The scanner estimates the row size from the output schema and calculates how many rows fit + * within this budget. + * + * @param maxSizeBytes the maximum size in bytes + * @return this builder + */ + public Builder maxSizeBytes(long maxSizeBytes) { + this.maxSizeBytes = Optional.of(maxSizeBytes); + return this; + } + + /** + * Sets the maximum number of rows per split. + * + * @param maxRowCount the maximum number of rows + * @return this builder + */ + public Builder maxRowCount(long maxRowCount) { + this.maxRowCount = Optional.of(maxRowCount); + return this; + } + + /** + * Builds a new {@link SplitOptions} instance. + * + * @return a new SplitOptions instance + */ + public SplitOptions build() { + return new SplitOptions(this); + } + } +} diff --git a/java/src/test/java/org/lance/ScannerTest.java b/java/src/test/java/org/lance/ScannerTest.java index 9fe844e9334..b2991d956e8 100644 --- a/java/src/test/java/org/lance/ScannerTest.java +++ b/java/src/test/java/org/lance/ScannerTest.java @@ -16,6 +16,9 @@ import org.lance.ipc.ColumnOrdering; import org.lance.ipc.LanceScanner; import org.lance.ipc.ScanOptions; +import org.lance.ipc.Split; +import org.lance.ipc.SplitFragment; +import org.lance.ipc.SplitOptions; import org.apache.arrow.dataset.scanner.Scanner; import org.apache.arrow.memory.BufferAllocator; @@ -554,4 +557,82 @@ private void validScanResult(Dataset dataset, int fragmentId, int rowCount) thro } } } + + @Test + void testPlanSplits(@TempDir Path tempDir) throws Exception { + String datasetPath = tempDir.resolve("plan_splits").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + int totalRows = 100; + try (Dataset dataset = testDataset.write(1, totalRows)) { + try (LanceScanner scanner = dataset.newScan(new ScanOptions.Builder().build())) { + // Test planSplits without options + List splits = 
scanner.planSplits(); + assertFalse(splits.isEmpty(), "Should return at least one split"); + + // Verify each split has fragments with valid metadata + for (Split split : splits) { + List fragments = split.getFragments(); + assertFalse(fragments.isEmpty(), "Each split should have at least one fragment"); + for (SplitFragment sf : fragments) { + FragmentMetadata fm = sf.getFragment(); + assertTrue(fm.getId() >= 0, "Fragment ID should be non-negative"); + assertTrue(sf.getMaxRowCount() > 0, "Max row count should be positive"); + } + } + } + } + } + } + + @Test + void testPlanSplitsWithOptions(@TempDir Path tempDir) throws Exception { + String datasetPath = tempDir.resolve("plan_splits_options").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + int totalRows = 100; + try (Dataset dataset = testDataset.write(1, totalRows)) { + try (LanceScanner scanner = dataset.newScan(new ScanOptions.Builder().build())) { + // Test planSplits with maxRowCount option + SplitOptions options = new SplitOptions.Builder().maxRowCount(20).build(); + List splits = scanner.planSplits(options); + assertFalse(splits.isEmpty(), "Should return at least one split"); + + // With max 20 rows per split and 100 total rows, + // we should have multiple splits + assertTrue(splits.size() >= 1, "Should have at least one split"); + + // Verify each split fragment respects the row count + for (Split split : splits) { + for (SplitFragment sf : split.getFragments()) { + assertTrue(sf.getMaxRowCount() > 0, "Max row count should be positive"); + } + } + } + } + } + } + + @Test + void testPlanSplitsWithMaxSizeBytes(@TempDir Path tempDir) throws Exception { + String datasetPath = tempDir.resolve("plan_splits_max_bytes").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = + new 
TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + int totalRows = 100; + try (Dataset dataset = testDataset.write(1, totalRows)) { + try (LanceScanner scanner = dataset.newScan(new ScanOptions.Builder().build())) { + // Test planSplits with maxSizeBytes option + SplitOptions options = new SplitOptions.Builder().maxSizeBytes(1024).build(); + List splits = scanner.planSplits(options); + assertFalse(splits.isEmpty(), "Should return at least one split"); + } + } + } + } } diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 7b19ad72c80..a2ee300a1f2 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -5189,6 +5189,41 @@ def analyze_plan(self) -> str: return self._scanner.analyze_plan() + def plan_splits( + self, + max_size_bytes: Optional[int] = None, + max_row_count: Optional[int] = None, + ) -> List[List["FragmentMetadata"]]: + """Plan splits for distributed scanning. + + This method analyzes the scanner's filter and uses indices to determine + which fragments need to be scanned and approximately how many rows each + fragment will return. It then groups fragments into splits that can be + processed independently. + + The scanner estimates the size of each row based on the output schema + projection and uses that to determine how many rows fit within the + target split size. + + Parameters + ---------- + max_size_bytes : int, optional + The target maximum size in bytes for each split. Defaults to 128MB. + max_row_count : int, optional + The maximum number of rows for each split. If specified, this takes + precedence over max_size_bytes. + + Returns + ------- + List[List[FragmentMetadata]] + A list of splits, where each split is a list of FragmentMetadata objects. + Each split can be processed independently for distributed scanning. 
+ """ + + return self._scanner.plan_splits( + max_size_bytes=max_size_bytes, max_row_count=max_row_count + ) + class DatasetOptimizer: def __init__(self, dataset: LanceDataset): diff --git a/python/src/scanner.rs b/python/src/scanner.rs index 1e85af6711f..b4e19e8e10d 100644 --- a/python/src/scanner.rs +++ b/python/src/scanner.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::pyarrow::*; use arrow_array::RecordBatchReader; -use lance::dataset::scanner::ExecutionSummaryCounts; +use lance::dataset::scanner::{ExecutionSummaryCounts, SplitOptions}; use pyo3::prelude::*; use pyo3::pyclass; @@ -30,6 +30,7 @@ use pyo3::exceptions::PyValueError; use crate::reader::LanceReader; use crate::rt; use crate::schema::logical_arrow_schema; +use crate::utils::PyLance; /// This will be wrapped by a python class to provide /// additional functionality @@ -150,4 +151,37 @@ impl Scanner { Ok(PyArrowType(Box::new(reader))) } + + #[pyo3(signature = (max_size_bytes=None, max_row_count=None))] + fn plan_splits<'py>( + self_: PyRef<'py, Self>, + max_size_bytes: Option, + max_row_count: Option, + ) -> PyResult>>> { + let scanner = self_.scanner.clone(); + let options = if max_size_bytes.is_some() || max_row_count.is_some() { + Some(SplitOptions { + max_size_bytes, + max_row_count, + }) + } else { + None + }; + let splits = rt() + .spawn(Some(self_.py()), async move { + scanner.plan_splits(options).await + })? 
+ .map_err(|err| PyValueError::new_err(err.to_string()))?; + + splits + .into_iter() + .map(|split| { + split + .fragments + .into_iter() + .map(|sf| PyLance(sf.fragment).into_pyobject(self_.py())) + .collect::, _>>() + }) + .collect::, _>>() + } } diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs index a1f56d48a84..0c21027b38a 100644 --- a/rust/lance-core/src/utils/mask.rs +++ b/rust/lance-core/src/utils/mask.rs @@ -657,6 +657,10 @@ impl RowAddrTreeMap { }), }) } + + pub fn fragments(&self) -> Vec { + self.inner.keys().cloned().collect() + } } impl std::ops::BitOr for RowAddrTreeMap { diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index a0812d6caf4..27b2579c6ac 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::collections::{HashMap, HashSet}; use std::ops::Range; use std::pin::Pin; use std::sync::{Arc, LazyLock}; @@ -600,6 +601,53 @@ pub struct Scanner { autoproject_scoring_columns: bool, } +/// Represents a split for parallel scanning of fragments +/// +/// A split contains one or more fragments that can be scanned together. +/// Splits can be used to distribute scanning work across multiple workers or threads. +#[derive(Debug, Clone)] +pub struct Split { + pub fragments: Vec, +} + +/// A fragment within a [`Split`], along with metadata about the expected +/// number of rows that will be scanned from it. +#[derive(Debug, Clone)] +pub struct SplitFragment { + /// The fragment to scan. + pub fragment: Fragment, + /// An upper bound on the number of rows that will be read from this + /// fragment after applying any filters or index pruning. + pub max_row_count: usize, +} + +/// Strategy for packing fragments into splits. +#[derive(Debug, Clone)] +pub enum SplitPackStrategy { + /// Target a maximum size in bytes per split. 
The scanner estimates the row + /// size from the output schema and calculates how many rows fit within this + /// budget. + MaxSizeBytes(usize), + /// Target a maximum number of rows per split. + MaxRowCount(usize), +} + +/// Options for configuring split generation. +/// +/// This struct allows specifying constraints on the maximum size and row count +/// for splits. Both fields are optional; if neither is set, default behavior +/// will be used. +#[derive(Debug, Clone, Default)] +pub struct SplitOptions { + /// Maximum size in bytes per split. + /// + /// The scanner estimates the row size from the output schema and calculates + /// how many rows fit within this budget. + pub max_size_bytes: Option, + /// Maximum number of rows per split. + pub max_row_count: Option, +} + /// Represents a user-requested take operation #[derive(Debug, Clone)] pub enum TakeOperation { @@ -858,6 +906,217 @@ impl Scanner { self.fragments.is_some() } + /// Plan splits for distributed or parallel scanning of the dataset. + /// + /// This method analyzes the fragments to be scanned and groups them into [`Split`]s + /// that can be processed independently by multiple workers or threads. It uses a + /// bin-packing algorithm to create balanced splits based on estimated row counts. 
+ /// + /// The maximum number of rows per split is determined by taking the minimum of: + /// - `max_row_count` from [`SplitOptions`] (if provided) + /// - The estimated row count from `max_size_bytes` in [`SplitOptions`] (if provided) + /// - If neither is provided, uses a default split size of 128MB to estimate row count + pub async fn plan_splits(&self, options: Option) -> Result> { + // Collect initial set of fragments to scan + let fragments = if let Some(fragments) = self.fragments.as_ref() { + Arc::new(fragments.clone()) + } else { + Arc::new(self.dataset.fragments().as_ref().clone()) + }; + + // Use indices to prune fragments and compute max row counts per fragment + let mut frag_max_row_counts: HashMap = HashMap::new(); + let mut covered_frag_ids: HashSet<_> = HashSet::new(); + + let filter_plan = self.create_filter_plan(true).await?; + if let Some(index_expr) = filter_plan.expr_filter_plan.index_query.as_ref() { + // Partition fragments by coverage of the index expression + let (covered_frags, _) = self + .partition_frags_by_coverage(index_expr, fragments.clone()) + .await?; + covered_frags.iter().for_each(|frag| { + covered_frag_ids.insert(frag.id); + }); + + // Evaluate the index expression to retrieve a bitmask of matching rows + let expr_result = index_expr + .evaluate(self.dataset.as_ref(), &NoOpMetricsCollector) + .await?; + match expr_result { + IndexExprResult::Exact(mask) | IndexExprResult::AtMost(mask) => { + match mask { + RowAddrMask::AllowList(bitmap) => { + // Iterate over covered fragments and update row counts + let allow_frag_ids: HashSet = + bitmap.fragments().into_iter().collect(); + + for frag in &covered_frags { + if allow_frag_ids.contains(&(frag.id as u32)) { + let row_count = match bitmap.get_fragment_bitmap(frag.id as u32) + { + Some(frag_bitmap) => frag_bitmap.len() as usize, + None => { + // Since we know `frag.id` is in the bitmap, this `None + // means the fragment bitmap is full. 
Use the total + // number of rows in the fragment. + frag.num_rows().unwrap_or(usize::MAX) + } + }; + + frag_max_row_counts.insert(frag.id, row_count); + } else { + // PRUNE fragment since no rows match + } + } + } + RowAddrMask::BlockList(bitmap) => { + // Iterate over covered fragments and update row counts + let blocked_frag_ids: HashSet = + bitmap.fragments().into_iter().collect(); + + for frag in &covered_frags { + if !blocked_frag_ids.contains(&(frag.id as u32)) { + // Fragment is not blocked, so all rows are allowed + frag_max_row_counts + .insert(frag.id, frag.num_rows().unwrap_or(usize::MAX)); + } else { + match bitmap.get_fragment_bitmap(frag.id as u32) { + Some(frag_bitmap) => { + let blocked_row_count = frag_bitmap.len() as usize; + let row_count = match frag.num_rows() { + Some(row_count) => row_count - blocked_row_count, + None => usize::MAX, + }; + frag_max_row_counts.insert(frag.id, row_count); + } + None => { + // PRUNE fragment since no rows match + // Since we know `frag.id` is in the bitmap, this `None + // means the fragment bitmap is full. + } + } + } + } + } + } + } + IndexExprResult::AtLeast(_) => { + // In the `AtLeast` case some of the rows in the block list may be false + // positives. Therefore, we can not prune any fragments as we can not guarantee + // any fragments will not have matching rows. + } + } + } + + // Estimate row counts for fragments not covered by indices. 
+ fragments + .iter() + .filter(|frag| !covered_frag_ids.contains(&frag.id)) + .for_each(|frag| { + // Estimate the number of rows in the fragment that satisfy the filter + let max_row_count = match frag.num_rows() { + Some(count) => count, + None => usize::MAX, + }; + frag_max_row_counts.insert(frag.id, max_row_count); + }); + + // Bin pack fragments into splits for parallel processing + const DEFAULT_SPLIT_SIZE: usize = 128 * 1024 * 1024; + const DEFAULT_VARIABLE_FIELD_SIZE: usize = 64; + + let options = options.unwrap_or_default(); + + // Helper to estimate row count from a byte size + let estimate_rows_from_bytes = |max_bytes: usize| -> Result { + let output_schema = self.projection_plan.output_schema()?; + let estimated_row_size: usize = output_schema + .fields() + .iter() + .map(|f| { + f.data_type() + .byte_width_opt() + .unwrap_or(DEFAULT_VARIABLE_FIELD_SIZE) + }) + .sum(); + Ok(max_bytes / estimated_row_size.max(1)) + }; + + let max_rows_per_split = match (options.max_row_count, options.max_size_bytes) { + (Some(max_rows), Some(max_bytes)) => { + // Use the minimum of both constraints + let rows_from_bytes = estimate_rows_from_bytes(max_bytes)?; + max_rows.min(rows_from_bytes) + } + (Some(max_rows), None) => max_rows, + (None, Some(max_bytes)) => estimate_rows_from_bytes(max_bytes)?, + (None, None) => estimate_rows_from_bytes(DEFAULT_SPLIT_SIZE)?, + }; + + let bins = Self::bin_pack(&frag_max_row_counts, max_rows_per_split); + + // Convert bins to splits + let fragment_map: HashMap = fragments.iter().map(|f| (f.id, f)).collect(); + let splits = bins + .into_iter() + .map(|bin| { + let frags = bin + .into_iter() + .filter_map(|id| { + fragment_map.get(&id).map(|&f| SplitFragment { + fragment: f.clone(), + max_row_count: frag_max_row_counts + .get(&id) + .copied() + .unwrap_or(usize::MAX), + }) + }) + .collect(); + Split { fragments: frags } + }) + .collect(); + + Ok(splits) + } + + /// Packs IDs into bins where each bin's total count is less than 
`maximum_count`. + /// + /// Uses a first-fit decreasing algorithm: items are sorted by count in descending + /// order, then each item is placed in the first bin that has room for it. + fn bin_pack(items: &HashMap, maximum_count: usize) -> Vec> { + // Convert to vec and sort by count descending for better packing + let mut items: Vec<(u64, usize)> = items.iter().map(|(&k, &v)| (k, v)).collect(); + items.sort_by(|a, b| b.1.cmp(&a.1)); + + let mut bins: Vec<(Vec, usize)> = Vec::new(); // (ids, current_count) + + for (id, count) in items { + // Items that exceed the maximum get their own bin + if count > maximum_count { + bins.push((vec![id], count)); + continue; + } + + // Find first bin with enough remaining capacity + let mut placed = false; + for (bin_ids, bin_count) in &mut bins { + if *bin_count + count <= maximum_count { + bin_ids.push(id); + *bin_count += count; + placed = true; + break; + } + } + + // Create new bin if no existing bin has room + if !placed { + bins.push((vec![id], count)); + } + } + + bins.into_iter().map(|(ids, _)| ids).collect() + } + /// Empty Projection (useful for count queries) /// /// The row_address will be scanned (no I/O required) but not included in the output @@ -9083,4 +9342,189 @@ mod test { runtime.handle().metrics().num_alive_tasks() ); } + + #[test] + fn test_bin_pack_empty() { + let items: HashMap = HashMap::new(); + let bins = Scanner::bin_pack(&items, 100); + assert!(bins.is_empty()); + } + + #[test] + fn test_bin_pack_mixed_sizes() { + // maximum = 100 + // Items: 70, 50, 40, 20, 10 + // Sorted descending: 70, 50, 40, 20, 10 + // Bin 1: 70, try 50 (70+50=120 > 100, no), try 40 (70+40=110 > 100, no), + // try 20 (70+20=90 <= 100, yes), try 10 (90+10=100 <= 100, yes) -> [70, 20, 10] + // Bin 2: 50, try 40 (50+40=90 <= 100, yes) -> [50, 40] + let items = HashMap::from([(1, 70), (2, 50), (3, 40), (4, 20), (5, 10)]); + let bins = Scanner::bin_pack(&items, 100); + + // Verify total items across all bins + let total_items: 
usize = bins.iter().map(|b| b.len()).sum(); + assert_eq!(total_items, 5); + + // Each bin's total count should be <= maximum (or a single oversized item) + let item_map: HashMap = [(1, 70), (2, 50), (3, 40), (4, 20), (5, 10)].into(); + for bin in &bins { + let bin_total: usize = bin.iter().map(|id| item_map[id]).sum(); + assert!(bin_total <= 100, "bin total {} exceeds maximum", bin_total); + } + } + + #[test] + fn test_bin_pack_oversized_items_get_own_bin() { + let items = HashMap::from([(1, 200), (2, 150), (3, 30)]); + let bins = Scanner::bin_pack(&items, 100); + + // Oversized items (200, 150) each get their own bin + // Item 30 could fit in a new bin + assert_eq!(bins.len(), 3); + + // Each bin should have exactly one item + for bin in &bins { + assert_eq!(bin.len(), 1); + } + } + + #[tokio::test] + async fn test_plan_splits_basic() { + // Create a dataset with 4 fragments of 100 rows each, single i32 column (4 bytes) + let dataset = lance_datagen::gen_batch() + .col("i", array::step::()) + .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(100)) + .await + .unwrap(); + + let splits = dataset.scan().plan_splits(None).await.unwrap(); + + // Default split size is 128MB, each fragment has 100 rows of 4 bytes = 400 bytes + // max_rows_per_split = 128*1024*1024 / 4 = 33554432 + // All 400 rows fit in one split + assert_eq!(splits.len(), 1); + assert_eq!(splits[0].fragments.len(), 4); + } + + #[tokio::test] + async fn test_plan_splits_with_small_split_size() { + // Create 4 fragments of 100 rows, single i32 column (4 bytes per row) + let dataset = lance_datagen::gen_batch() + .col("i", array::step::()) + .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(100)) + .await + .unwrap(); + + // Set split size to 200 bytes -> max_rows = 200/4 = 50 rows per split + // Each fragment has 100 rows >= 50, so each gets its own bin + let options = SplitOptions { + max_size_bytes: Some(200), + ..Default::default() + }; + let splits = 
dataset.scan().plan_splits(Some(options)).await.unwrap(); + + assert_eq!(splits.len(), 4); + for split in &splits { + assert_eq!(split.fragments.len(), 1); + } + } + + #[tokio::test] + async fn test_plan_splits_grouping() { + // Create 4 fragments of 50 rows, single i32 column (4 bytes per row) + let dataset = lance_datagen::gen_batch() + .col("i", array::step::()) + .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(50)) + .await + .unwrap(); + + // Set split size to 400 bytes -> max_rows = 400/4 = 100 rows per split + // Each fragment has 50 rows, so two fragments can fit per split (50+50=100, not < 100) + // Actually 50+50=100 is NOT < 100, so each fragment gets its own bin + let options = SplitOptions { + max_size_bytes: Some(400), + ..Default::default() + }; + let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); + assert_eq!(splits.len(), 4); + + // With 404 bytes: max_rows = 404/4 = 101 rows per split + // Each fragment has 50 rows, 50+50=100 < 101, so two fragments per split + let options = SplitOptions { + max_size_bytes: Some(404), + ..Default::default() + }; + let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); + assert_eq!(splits.len(), 2); + for split in &splits { + assert_eq!(split.fragments.len(), 2); + } + } + + #[tokio::test] + async fn test_plan_splits_with_projection() { + // Create dataset with two i32 columns (8 bytes per row) + let dataset = lance_datagen::gen_batch() + .col("a", array::step::()) + .col("b", array::step::()) + .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(100)) + .await + .unwrap(); + + // Full projection: 8 bytes per row, max_rows = 800/8 = 100 + // Each fragment has 100 rows >= 100, so each gets its own bin + let options = SplitOptions { + max_size_bytes: Some(800), + ..Default::default() + }; + let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); + assert_eq!(splits.len(), 4); + + // Single column projection: 4 bytes per row, max_rows = 
800/4 = 200 + // Each fragment has 100 rows, 100+100=200 NOT < 200, so each gets its own bin + let mut scanner = dataset.scan(); + scanner.project(&["a"]).unwrap(); + let splits = scanner + .plan_splits(Some(SplitOptions { + max_size_bytes: Some(800), + ..Default::default() + })) + .await + .unwrap(); + assert_eq!(splits.len(), 4); + + // Single column projection with larger budget: max_rows = 804/4 = 201 + // 100 + 100 = 200 < 201, so two fragments per split + let mut scanner = dataset.scan(); + scanner.project(&["a"]).unwrap(); + let splits = scanner + .plan_splits(Some(SplitOptions { + max_size_bytes: Some(804), + ..Default::default() + })) + .await + .unwrap(); + assert_eq!(splits.len(), 2); + } + + #[tokio::test] + async fn test_plan_splits_with_fragments() { + // Create dataset with 4 fragments + let dataset = lance_datagen::gen_batch() + .col("i", array::step::()) + .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(100)) + .await + .unwrap(); + + // Only scan 2 specific fragments + let frags: Vec<_> = dataset.fragments()[..2].to_vec(); + + let mut scanner = dataset.scan(); + scanner.with_fragments(frags); + + // Large split size so everything fits in one split + let splits = scanner.plan_splits(None).await.unwrap(); + assert_eq!(splits.len(), 1); + assert_eq!(splits[0].fragments.len(), 2); + } }