From 046fe6193ac92d9af5878ddb3a984690bfd1af3a Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 22 Jan 2026 12:38:12 -0600 Subject: [PATCH 1/6] initial pass Signed-off-by: Daniel Rammer --- rust/lance-core/src/utils/mask.rs | 16 ++- rust/lance/src/dataset/scanner.rs | 183 +++++++++++++++++++++++++++++- 2 files changed, 194 insertions(+), 5 deletions(-) diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs index a1f56d48a84..30a75db5060 100644 --- a/rust/lance-core/src/utils/mask.rs +++ b/rust/lance-core/src/utils/mask.rs @@ -333,6 +333,13 @@ impl RowAddrSelection { res } } + + pub fn len(&self) -> Option { + match self { + Self::Full => None, + Self::Partial(bitmap) => Some(bitmap.len()), + } + } } impl RowSetOps for RowAddrTreeMap { @@ -345,10 +352,7 @@ impl RowSetOps for RowAddrTreeMap { fn len(&self) -> Option { self.inner .values() - .map(|row_addr_selection| match row_addr_selection { - RowAddrSelection::Full => None, - RowAddrSelection::Partial(indices) => Some(indices.len()), - }) + .map(|row_addr_selection| row_addr_selection.len()) .try_fold(0_u64, |acc, next| next.map(|next| next + acc)) } @@ -657,6 +661,10 @@ impl RowAddrTreeMap { }), }) } + + pub fn fragments(&self) -> Vec { + self.inner.keys().cloned().collect() + } } impl std::ops::BitOr for RowAddrTreeMap { diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index a0812d6caf4..ae4555d6626 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -1,10 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::collections::{HashMap, HashSet}; use std::ops::Range; use std::pin::Pin; use std::sync::{Arc, LazyLock}; use std::task::{Context, Poll}; +use std::usize; use arrow::array::AsArray; use arrow_array::{Array, Float32Array, Int64Array, RecordBatch}; @@ -50,7 +52,7 @@ use lance_core::datatypes::{ }; use lance_core::error::LanceOptionExt; use lance_core::utils::address::RowAddress; -use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap}; +use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap, RowSetOps}; use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_core::{ROW_ADDR, ROW_ID, ROW_OFFSET}; use lance_datafusion::exec::{ @@ -600,6 +602,32 @@ pub struct Scanner { autoproject_scoring_columns: bool, } +/// Represents a split for parallel scanning of fragments +/// +/// A split contains one or more fragments that can be scanned together. +/// Splits can be used to distribute scanning work across multiple workers or threads. +#[derive(Debug, Clone)] +pub struct Split { + pub fragments: Vec, +} + +/// Options for configuring how splits are created. +pub struct SplitOptions { + /// A suggested number of rows per split. The scanner will attempt to create splits + /// with approximately this many rows, but the actual number may vary depending on + /// the sizes of the fragments. + /// For example, if a single fragment has more rows than this value, we will not split + /// the fragment, but include it in a single split. + max_rows_per_split: Option, +} + +impl SplitOptions { + /// Create a new SplitOptions with the given max_rows_per_split. 
+ pub fn new(max_rows_per_split: Option) -> Self { + Self { max_rows_per_split } + } +} + /// Represents a user-requested take operation #[derive(Debug, Clone)] pub enum TakeOperation { @@ -858,6 +886,159 @@ impl Scanner { self.fragments.is_some() } + // TODO @hamersaw - docs + pub async fn plan_splits(&self, options: Option) -> Result> { + // Collect initial set of fragments to scan + let fragments = if let Some(fragments) = self.fragments.as_ref() { + Arc::new(fragments.clone()) + } else { + Arc::new(self.dataset.fragments().as_ref().clone()) + }; + + // Use indices to prune fragments + let mut frag_max_row_counts: HashMap = HashMap::new(); + let mut covered_frag_ids: HashSet<_> = HashSet::new(); + + let filter_plan = self.create_filter_plan(true).await?; + if let Some(index_expr) = filter_plan.expr_filter_plan.index_query.as_ref() { + // Partition fragments by coverage of the index expression + let (covered_frags, _) = self + .partition_frags_by_coverage(index_expr, fragments.clone()) + .await?; + covered_frags.iter().for_each(|frag| { + covered_frag_ids.insert(frag.id); + }); + + // Evaluate the index expression to retrieve a bitmask of matching rows + let expr_result = index_expr + .evaluate(self.dataset.as_ref(), &NoOpMetricsCollector) + .await?; + match expr_result { + IndexExprResult::Exact(mask) | IndexExprResult::AtMost(mask) => { + match mask { + RowAddrMask::AllowList(bitmap) => { + // Iterate over covered fragments and update row counts + let allow_frag_ids: HashSet = + bitmap.fragments().into_iter().collect(); + + for frag in &covered_frags { + if allow_frag_ids.contains(&(frag.id as u32)) { + let row_count = match bitmap.get_fragment_bitmap(frag.id as u32) + { + Some(frag_bitmap) => frag_bitmap.len() as usize, + None => { + // Since we know `bitmap.contains(frag.id as u32)` is + // true, this `None` means the fragment bitmap is full. + // Use the total number of rows in the fragment. + frag.num_rows().unwrap_or(usize::MAX) + } + }; + + frag_max_row_counts.insert(frag.id, row_count); + } else { + // PRUNE fragment since no rows match + } + } + } + RowAddrMask::BlockList(bitmap) => { + let _blocked_frag_ids: HashSet = + bitmap.fragments().into_iter().collect(); + + // TODO @hamersaw - figure out how to handle block list pruning correctly + /*for frag in &covered_frags { + if !bitmap.contains(frag.id as u32) { + // All rows in the fragment match + frag_max_row_counts.insert(frag.id, max_row_count); + } + }*/ + } + } + } + IndexExprResult::AtLeast(_) => { + // In the `AtLeast` case some of the rows in the block list may be false + // positives. Therefore, we can not prune any fragments as we can not guarantee + // any fragments will not have matching rows. + } + } + } + + // Estimate row counts for fragments not covered by indices. + fragments + .iter() + .filter(|frag| !covered_frag_ids.contains(&frag.id)) + .for_each(|frag| { + // Estimate the number of rows in the fragment that satisfy the filter + let max_row_count = match frag.num_rows() { + Some(count) => count, + None => usize::MAX, + }; + frag_max_row_counts.insert(frag.id, max_row_count); + + // TODO - Query `ZoneMaps` to prune rows within fragments + // TODO - do we want to create splits with mixed covered / uncovered fragments? + }); + + // Bin pack fragments into splits for parallel processing + let max_rows_per_split = options + .and_then(|o| o.max_rows_per_split) + .unwrap_or(500_000); // TODO @hamersaw - default?!? 
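For intuition, a worked example of what reaches the packing step below (hypothetical fragment ids and row counts, not taken from any real dataset):

        // Suppose index pruning produced these estimated matching-row counts,
        // keyed by fragment id: {0: 400_000, 1: 120_000, 2: 90_000}.
        // With the 500_000-row default and the first-fit-decreasing packing
        // below (strictly-less-than capacity check in this revision),
        // fragment 0 opens the first bin, fragment 1 does not fit beside it
        // (400_000 + 120_000 >= 500_000) and opens a second bin, and
        // fragment 2 fits back into the first bin (400_000 + 90_000 < 500_000),
        // yielding the splits [{fragment 0, fragment 2}, {fragment 1}].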
+ + let bins = Self::bin_pack(frag_max_row_counts, max_rows_per_split); + + // Convert bins to splits + let fragment_map: HashMap = fragments.iter().map(|f| (f.id, f)).collect(); + let splits = bins + .into_iter() + .map(|bin| { + let frags = bin + .into_iter() + .filter_map(|id| fragment_map.get(&id).map(|&f| f.clone())) + .collect(); + Split { fragments: frags } + }) + .collect(); + + Ok(splits) + } + + /// Packs IDs into bins where each bin's total count is less than `maximum_count`. + /// + /// Uses a first-fit decreasing algorithm: items are sorted by count in descending + /// order, then each item is placed in the first bin that has room for it. + fn bin_pack(items: HashMap, maximum_count: usize) -> Vec> { + // Convert to vec and sort by count descending for better packing + let mut items: Vec<(u64, usize)> = items.into_iter().collect(); + items.sort_by(|a, b| b.1.cmp(&a.1)); + + let mut bins: Vec<(Vec, usize)> = Vec::new(); // (ids, current_count) + + for (id, count) in items { + // Items that exceed the maximum get their own bin + if count >= maximum_count { + bins.push((vec![id], count)); + continue; + } + + // Find first bin with enough remaining capacity + let mut placed = false; + for (bin_ids, bin_count) in &mut bins { + if *bin_count + count < maximum_count { + bin_ids.push(id); + *bin_count += count; + placed = true; + break; + } + } + + // Create new bin if no existing bin has room + if !placed { + bins.push((vec![id], count)); + } + } + + bins.into_iter().map(|(ids, _)| ids).collect() + } + /// Empty Projection (useful for count queries) /// /// The row_address will be scanned (no I/O required) but not included in the output From c176bb9d056b1f69c1adfb9b54b687931b0a3797 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 22 Jan 2026 15:54:15 -0600 Subject: [PATCH 2/6] updated to use max_split_size_bytes Signed-off-by: Daniel Rammer --- python/python/lance/dataset.py | 28 ++++++++++ python/src/scanner.rs | 31 ++++++++++- rust/lance/src/dataset/scanner.rs | 88 +++++++++++++++++++++---------- 3 files changed, 119 insertions(+), 28 deletions(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 7b19ad72c80..e306957133f 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -5189,6 +5189,34 @@ def analyze_plan(self) -> str: return self._scanner.analyze_plan() + def plan_splits( + self, max_split_size_bytes: Optional[int] = None + ) -> List[List["FragmentMetadata"]]: + """Plan splits for distributed scanning. + + This method analyzes the scanner's filter and uses indices to determine + which fragments need to be scanned and approximately how many rows each + fragment will return. It then groups fragments into splits that can be + processed independently. + + The scanner estimates the size of each row based on the output schema + projection and uses that to determine how many rows fit within the + target split size. + + Parameters + ---------- + max_split_size_bytes : int, optional + The target maximum size in bytes for each split. Defaults to 128MB. + + Returns + ------- + List[List[FragmentMetadata]] + A list of splits, where each split is a list of FragmentMetadata objects. + Each split can be processed independently for distributed scanning. 
+ """ + + return self._scanner.plan_splits(max_split_size_bytes=max_split_size_bytes) + class DatasetOptimizer: def __init__(self, dataset: LanceDataset): diff --git a/python/src/scanner.rs b/python/src/scanner.rs index 1e85af6711f..8e0706534b7 100644 --- a/python/src/scanner.rs +++ b/python/src/scanner.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::pyarrow::*; use arrow_array::RecordBatchReader; -use lance::dataset::scanner::ExecutionSummaryCounts; +use lance::dataset::scanner::{ExecutionSummaryCounts, SplitOptions}; use pyo3::prelude::*; use pyo3::pyclass; @@ -30,6 +30,7 @@ use pyo3::exceptions::PyValueError; use crate::reader::LanceReader; use crate::rt; use crate::schema::logical_arrow_schema; +use crate::utils::PyLance; /// This will be wrapped by a python class to provide /// additional functionality @@ -150,4 +151,32 @@ impl Scanner { Ok(PyArrowType(Box::new(reader))) } + + #[pyo3(signature = (max_split_size_bytes=None))] + fn plan_splits<'py>( + self_: PyRef<'py, Self>, + max_split_size_bytes: Option, + ) -> PyResult>>> { + let scanner = self_.scanner.clone(); + let mut options = SplitOptions::default(); + if let Some(size) = max_split_size_bytes { + options = options.with_max_split_size_bytes(size); + } + let splits = rt() + .spawn(Some(self_.py()), async move { + scanner.plan_splits(Some(options)).await + })? + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + splits + .into_iter() + .map(|split| { + split + .fragments + .into_iter() + .map(|frag| PyLance(frag).into_pyobject(self_.py())) + .collect::, _>>() + }) + .collect::, _>>() + } } diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index ae4555d6626..96d30e9f2df 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -52,7 +52,7 @@ use lance_core::datatypes::{ }; use lance_core::error::LanceOptionExt; use lance_core::utils::address::RowAddress; -use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap, RowSetOps}; +use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap}; use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_core::{ROW_ADDR, ROW_ID, ROW_OFFSET}; use lance_datafusion::exec::{ @@ -612,19 +612,19 @@ pub struct Split { } /// Options for configuring how splits are created. +#[derive(Debug, Clone, Default)] pub struct SplitOptions { - /// A suggested number of rows per split. The scanner will attempt to create splits - /// with approximately this many rows, but the actual number may vary depending on - /// the sizes of the fragments. - /// For example, if a single fragment has more rows than this value, we will not split - /// the fragment, but include it in a single split. - max_rows_per_split: Option, + /// The target maximum size in bytes for each split. The scanner estimates + /// the row size from the output schema and calculates how many rows fit + /// within this budget. Defaults to 128MB if not specified. + max_split_size_bytes: Option, } impl SplitOptions { - /// Create a new SplitOptions with the given max_rows_per_split. - pub fn new(max_rows_per_split: Option) -> Self { - Self { max_rows_per_split } + /// Set the target maximum size in bytes for each split. + pub fn with_max_split_size_bytes(mut self, max_split_size_bytes: usize) -> Self { + self.max_split_size_bytes = Some(max_split_size_bytes); + self } } @@ -886,7 +886,11 @@ impl Scanner { self.fragments.is_some() } - // TODO @hamersaw - docs + /// Plan splits for distributed or parallel scanning of the dataset. 
+ /// + /// This method analyzes the fragments to be scanned and groups them into [`Split`]s + /// that can be processed independently by multiple workers or threads. It uses a + /// bin-packing algorithm to create balanced splits based on estimated row counts. pub async fn plan_splits(&self, options: Option) -> Result> { // Collect initial set of fragments to scan let fragments = if let Some(fragments) = self.fragments.as_ref() { @@ -927,9 +931,9 @@ impl Scanner { { Some(frag_bitmap) => frag_bitmap.len() as usize, None => { - // Since we know `bitmap.contains(frag.id as u32)` is - // true, this `None` means the fragment bitmap is full. - // Use the total number of rows in the fragment. + // Since we know `frag.id` is in the bitmap, this `None + // means the fragment bitmap is full. Use the total + // number of rows in the fragment. frag.num_rows().unwrap_or(usize::MAX) } }; @@ -941,16 +945,33 @@ impl Scanner { } } RowAddrMask::BlockList(bitmap) => { - let _blocked_frag_ids: HashSet = + // Iterate over covered fragments and update row counts + let blocked_frag_ids: HashSet = bitmap.fragments().into_iter().collect(); - // TODO @hamersaw - figure out how to handle block list pruning correctly - /*for frag in &covered_frags { - if !bitmap.contains(frag.id as u32) { - // All rows in the fragment match - frag_max_row_counts.insert(frag.id, max_row_count); + for frag in &covered_frags { + if !blocked_frag_ids.contains(&(frag.id as u32)) { + // Fragment is not blocked, so all rows are allowed + frag_max_row_counts + .insert(frag.id, frag.num_rows().unwrap_or(usize::MAX)); + } else { + match bitmap.get_fragment_bitmap(frag.id as u32) { + Some(frag_bitmap) => { + let blocked_row_count = frag_bitmap.len() as usize; + let row_count = match frag.num_rows() { + Some(row_count) => row_count - blocked_row_count, + None => usize::MAX, + }; + frag_max_row_counts.insert(frag.id, row_count); + } + None => { + // PRUNE fragment since no rows match + // Since we know `frag.id` is in the bitmap, this `None + // means the fragment bitmap is full. + } + } } - }*/ + } } } } @@ -973,15 +994,28 @@ impl Scanner { None => usize::MAX, }; frag_max_row_counts.insert(frag.id, max_row_count); - - // TODO - Query `ZoneMaps` to prune rows within fragments - // TODO - do we want to create splits with mixed covered / uncovered fragments? }); // Bin pack fragments into splits for parallel processing - let max_rows_per_split = options - .and_then(|o| o.max_rows_per_split) - .unwrap_or(500_000); // TODO @hamersaw - default?!? 
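To make the byte budget concrete, a worked example of the estimate introduced below (hypothetical output schema, not taken from the patch):

        // An output schema of one Int32 (4 bytes) and one Float64 (8 bytes)
        // gives an estimated row size of 12 bytes, so the 128 MiB default
        // budget allows 128 * 1024 * 1024 / 12 = 11_184_810 rows per split.
        // A variable-width column (e.g. Utf8) contributes the 64-byte
        // fallback instead of an exact width.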
+ const DEFAULT_SPLIT_SIZE: usize = 128 * 1024 * 1024; + const DEFAULT_VARIABLE_FIELD_SIZE: usize = 64; + + let target_split_size = options + .and_then(|o| o.max_split_size_bytes) + .unwrap_or(DEFAULT_SPLIT_SIZE); + + let output_schema = self.projection_plan.output_schema()?; + let estimated_row_size: usize = output_schema + .fields() + .iter() + .map(|f| { + f.data_type() + .byte_width_opt() + .unwrap_or(DEFAULT_VARIABLE_FIELD_SIZE) + }) + .sum(); + + let max_rows_per_split = target_split_size / estimated_row_size.max(1); let bins = Self::bin_pack(frag_max_row_counts, max_rows_per_split); From dc752bb086aebb73b2bf337f7d4ae72548aed949 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 22 Jan 2026 16:19:09 -0600 Subject: [PATCH 3/6] added unit tests Signed-off-by: Daniel Rammer --- rust/lance/src/dataset/scanner.rs | 252 ++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 96d30e9f2df..24a72410b10 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -9298,4 +9298,256 @@ mod test { runtime.handle().metrics().num_alive_tasks() ); } + + #[test] + fn test_bin_pack_empty() { + let items: HashMap = HashMap::new(); + let bins = Scanner::bin_pack(items, 100); + assert!(bins.is_empty()); + } + + #[test] + fn test_bin_pack_single_item_fits() { + let items = HashMap::from([(1, 50)]); + let bins = Scanner::bin_pack(items, 100); + assert_eq!(bins.len(), 1); + assert_eq!(bins[0], vec![1]); + } + + #[test] + fn test_bin_pack_single_item_exceeds_maximum() { + let items = HashMap::from([(1, 200)]); + let bins = Scanner::bin_pack(items, 100); + assert_eq!(bins.len(), 1); + assert_eq!(bins[0], vec![1]); + } + + #[test] + fn test_bin_pack_single_item_equals_maximum() { + // Items with count >= maximum_count get their own bin + let items = HashMap::from([(1, 100)]); + let bins = Scanner::bin_pack(items, 100); + assert_eq!(bins.len(), 1); + assert_eq!(bins[0], vec![1]); + } + + #[test] + fn test_bin_pack_multiple_items_fit_one_bin() { + let items = HashMap::from([(1, 30), (2, 30), (3, 30)]); + let bins = Scanner::bin_pack(items, 100); + assert_eq!(bins.len(), 1); + let mut ids: Vec = bins[0].clone(); + ids.sort(); + assert_eq!(ids, vec![1, 2, 3]); + } + + #[test] + fn test_bin_pack_multiple_bins_needed() { + // Each item is 60, maximum is 100, so only one item per bin + // (60 + 60 = 120 >= 100, not strictly less than) + let items = HashMap::from([(1, 60), (2, 60), (3, 60)]); + let bins = Scanner::bin_pack(items, 100); + assert_eq!(bins.len(), 3); + } + + #[test] + fn test_bin_pack_mixed_sizes() { + // maximum = 100 + // Items: 70, 50, 40, 20, 10 + // Sorted descending: 70, 50, 40, 20, 10 + // Bin 1: 70, then try 50 (70+50=120 >= 100, no), try 40 (70+40=110 >= 100, no), + // try 20 (70+20=90 < 100, yes) -> [70, 20] + // Bin 2: 50, then try 40 (50+40=90 < 100, yes) -> [50, 40] + // Bin 3: try 10 in bin1 (90+10=100, NOT < 100), try bin2 (90+10=100, NOT < 100) -> [10] + let items = HashMap::from([(1, 70), (2, 50), (3, 40), (4, 20), (5, 10)]); + let bins = Scanner::bin_pack(items, 100); + + // Verify total items across all bins + let total_items: usize = bins.iter().map(|b| b.len()).sum(); + assert_eq!(total_items, 5); + + // Each bin's total count should be < maximum (or a single oversized item) + let item_map: HashMap = [(1, 70), (2, 50), (3, 40), (4, 20), (5, 10)].into(); + for bin in &bins { + let bin_total: usize = bin.iter().map(|id| item_map[id]).sum(); + // Bins with 
a single oversized item can exceed, but here all are < 100 + assert!(bin_total <= 100, "bin total {} exceeds maximum", bin_total); + } + } + + #[test] + fn test_bin_pack_oversized_items_get_own_bin() { + let items = HashMap::from([(1, 200), (2, 150), (3, 30)]); + let bins = Scanner::bin_pack(items, 100); + + // Oversized items (200, 150) each get their own bin + // Item 30 could fit in a new bin + assert_eq!(bins.len(), 3); + + // Each bin should have exactly one item + for bin in &bins { + assert_eq!(bin.len(), 1); + } + } + + #[test] + fn test_bin_pack_boundary_condition() { + // Test the strict less-than condition: bin_count + count < maximum_count + // Two items of 49 should fit in one bin (49 + 49 = 98 < 100) + let items = HashMap::from([(1, 49), (2, 49)]); + let bins = Scanner::bin_pack(items, 100); + assert_eq!(bins.len(), 1); + + // Two items of 50 should NOT fit in one bin (50 + 50 = 100, not < 100) + let items = HashMap::from([(1, 50), (2, 50)]); + let bins = Scanner::bin_pack(items, 100); + assert_eq!(bins.len(), 2); + } + + #[test] + fn test_bin_pack_all_ids_preserved() { + let items: HashMap = (0..10).map(|i| (i, 25)).collect(); + let bins = Scanner::bin_pack(items.clone(), 100); + + let mut all_ids: Vec = bins.into_iter().flatten().collect(); + all_ids.sort(); + let mut expected_ids: Vec = items.keys().copied().collect(); + expected_ids.sort(); + assert_eq!(all_ids, expected_ids); + } + + #[tokio::test] + async fn test_plan_splits_basic() { + // Create a dataset with 4 fragments of 100 rows each, single i32 column (4 bytes) + let dataset = lance_datagen::gen_batch() + .col("i", array::step::()) + .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(100)) + .await + .unwrap(); + + let splits = dataset.scan().plan_splits(None).await.unwrap(); + + // Default split size is 128MB, each fragment has 100 rows of 4 bytes = 400 bytes + // max_rows_per_split = 128*1024*1024 / 4 = 33554432 + // All 400 rows fit in one split + assert_eq!(splits.len(), 1); + assert_eq!(splits[0].fragments.len(), 4); + } + + #[tokio::test] + async fn test_plan_splits_with_small_split_size() { + // Create 4 fragments of 100 rows, single i32 column (4 bytes per row) + let dataset = lance_datagen::gen_batch() + .col("i", array::step::()) + .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(100)) + .await + .unwrap(); + + // Set split size to 200 bytes -> max_rows = 200/4 = 50 rows per split + // Each fragment has 100 rows >= 50, so each gets its own bin + let options = SplitOptions::default().with_max_split_size_bytes(200); + let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); + + assert_eq!(splits.len(), 4); + for split in &splits { + assert_eq!(split.fragments.len(), 1); + } + } + + #[tokio::test] + async fn test_plan_splits_grouping() { + // Create 4 fragments of 50 rows, single i32 column (4 bytes per row) + let dataset = lance_datagen::gen_batch() + .col("i", array::step::()) + .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(50)) + .await + .unwrap(); + + // Set split size to 400 bytes -> max_rows = 400/4 = 100 rows per split + // Each fragment has 50 rows, so two fragments can fit per split (50+50=100, not < 100) + // Actually 50+50=100 is NOT < 100, so each fragment gets its own bin + let options = SplitOptions::default().with_max_split_size_bytes(400); + let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); + assert_eq!(splits.len(), 4); + + // With 404 bytes: max_rows = 404/4 = 101 rows per split + // Each fragment has 50 
rows, 50+50=100 < 101, so two fragments per split + let options = SplitOptions::default().with_max_split_size_bytes(404); + let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); + assert_eq!(splits.len(), 2); + for split in &splits { + assert_eq!(split.fragments.len(), 2); + } + } + + #[tokio::test] + async fn test_plan_splits_with_projection() { + // Create dataset with two i32 columns (8 bytes per row) + let dataset = lance_datagen::gen_batch() + .col("a", array::step::()) + .col("b", array::step::()) + .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(100)) + .await + .unwrap(); + + // Full projection: 8 bytes per row, max_rows = 800/8 = 100 + // Each fragment has 100 rows >= 100, so each gets its own bin + let options = SplitOptions::default().with_max_split_size_bytes(800); + let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); + assert_eq!(splits.len(), 4); + + // Single column projection: 4 bytes per row, max_rows = 800/4 = 200 + // Each fragment has 100 rows, 100+100=200 NOT < 200, so each gets its own bin + let mut scanner = dataset.scan(); + scanner.project(&["a"]).unwrap(); + let splits = scanner + .plan_splits(Some(SplitOptions::default().with_max_split_size_bytes(800))) + .await + .unwrap(); + assert_eq!(splits.len(), 4); + + // Single column projection with larger budget: max_rows = 804/4 = 201 + // 100 + 100 = 200 < 201, so two fragments per split + let mut scanner = dataset.scan(); + scanner.project(&["a"]).unwrap(); + let splits = scanner + .plan_splits(Some(SplitOptions::default().with_max_split_size_bytes(804))) + .await + .unwrap(); + assert_eq!(splits.len(), 2); + } + + #[tokio::test] + async fn test_plan_splits_with_fragments() { + // Create dataset with 4 fragments + let dataset = lance_datagen::gen_batch() + .col("i", array::step::()) + .into_ram_dataset(FragmentCount::from(4), FragmentRowCount::from(100)) + .await + .unwrap(); + + // Only scan 2 specific fragments + let frags: Vec<_> = dataset.fragments()[..2].to_vec(); + + let mut scanner = dataset.scan(); + scanner.with_fragments(frags); + + // Large split size so everything fits in one split + let splits = scanner.plan_splits(None).await.unwrap(); + assert_eq!(splits.len(), 1); + assert_eq!(splits[0].fragments.len(), 2); + } + + #[tokio::test] + async fn test_plan_splits_single_fragment() { + let dataset = lance_datagen::gen_batch() + .col("i", array::step::()) + .into_ram_dataset(FragmentCount::from(1), FragmentRowCount::from(100)) + .await + .unwrap(); + + let splits = dataset.scan().plan_splits(None).await.unwrap(); + assert_eq!(splits.len(), 1); + assert_eq!(splits[0].fragments.len(), 1); + } } From f1650d6a025237e84bb0fd2e3127c011d86583b0 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Fri, 23 Jan 2026 09:27:57 -0600 Subject: [PATCH 4/6] adding SplitStrategy to allow users to specify split size or max rows Signed-off-by: Daniel Rammer --- python/src/scanner.rs | 11 +- rust/lance-core/src/utils/mask.rs | 12 +- rust/lance/src/dataset/scanner.rs | 214 +++++++++++------------------- 3 files changed, 85 insertions(+), 152 deletions(-) diff --git a/python/src/scanner.rs b/python/src/scanner.rs index 8e0706534b7..bf49106dc47 100644 --- a/python/src/scanner.rs +++ b/python/src/scanner.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::pyarrow::*; use arrow_array::RecordBatchReader; -use lance::dataset::scanner::{ExecutionSummaryCounts, SplitOptions}; +use lance::dataset::scanner::{ExecutionSummaryCounts, SplitPackStrategy}; use pyo3::prelude::*; use 
pyo3::pyclass; @@ -158,13 +158,10 @@ impl Scanner { max_split_size_bytes: Option, ) -> PyResult>>> { let scanner = self_.scanner.clone(); - let mut options = SplitOptions::default(); - if let Some(size) = max_split_size_bytes { - options = options.with_max_split_size_bytes(size); - } + let strategy = max_split_size_bytes.map(SplitPackStrategy::MaxSizeBytes); let splits = rt() .spawn(Some(self_.py()), async move { - scanner.plan_splits(Some(options)).await + scanner.plan_splits(strategy).await })? .map_err(|err| PyValueError::new_err(err.to_string()))?; @@ -174,7 +171,7 @@ impl Scanner { split .fragments .into_iter() - .map(|frag| PyLance(frag).into_pyobject(self_.py())) + .map(|sf| PyLance(sf.fragment).into_pyobject(self_.py())) .collect::, _>>() }) .collect::, _>>() diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs index 30a75db5060..0c21027b38a 100644 --- a/rust/lance-core/src/utils/mask.rs +++ b/rust/lance-core/src/utils/mask.rs @@ -333,13 +333,6 @@ impl RowAddrSelection { res } } - - pub fn len(&self) -> Option { - match self { - Self::Full => None, - Self::Partial(bitmap) => Some(bitmap.len()), - } - } } impl RowSetOps for RowAddrTreeMap { @@ -352,7 +345,10 @@ impl RowSetOps for RowAddrTreeMap { fn len(&self) -> Option { self.inner .values() - .map(|row_addr_selection| row_addr_selection.len()) + .map(|row_addr_selection| match row_addr_selection { + RowAddrSelection::Full => None, + RowAddrSelection::Partial(indices) => Some(indices.len()), + }) .try_fold(0_u64, |acc, next| next.map(|next| next + acc)) } diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 24a72410b10..866f0f774ab 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -6,7 +6,6 @@ use std::ops::Range; use std::pin::Pin; use std::sync::{Arc, LazyLock}; use std::task::{Context, Poll}; -use std::usize; use arrow::array::AsArray; use arrow_array::{Array, Float32Array, Int64Array, RecordBatch}; @@ -608,24 +607,29 @@ pub struct Scanner { /// Splits can be used to distribute scanning work across multiple workers or threads. #[derive(Debug, Clone)] pub struct Split { - pub fragments: Vec, + pub fragments: Vec, } -/// Options for configuring how splits are created. -#[derive(Debug, Clone, Default)] -pub struct SplitOptions { - /// The target maximum size in bytes for each split. The scanner estimates - /// the row size from the output schema and calculates how many rows fit - /// within this budget. Defaults to 128MB if not specified. - max_split_size_bytes: Option, +/// A fragment within a [`Split`], along with metadata about the expected +/// number of rows that will be scanned from it. +#[derive(Debug, Clone)] +pub struct SplitFragment { + /// The fragment to scan. + pub fragment: Fragment, + /// An upper bound on the number of rows that will be read from this + /// fragment after applying any filters or index pruning. + pub max_row_count: usize, } -impl SplitOptions { - /// Set the target maximum size in bytes for each split. - pub fn with_max_split_size_bytes(mut self, max_split_size_bytes: usize) -> Self { - self.max_split_size_bytes = Some(max_split_size_bytes); - self - } +/// Strategy for packing fragments into splits. +#[derive(Debug, Clone)] +pub enum SplitPackStrategy { + /// Target a maximum size in bytes per split. The scanner estimates the row + /// size from the output schema and calculates how many rows fit within this + /// budget. + MaxSizeBytes(usize), + /// Target a maximum number of rows per split. 
+ MaxRowCount(usize), } /// Represents a user-requested take operation @@ -891,7 +895,7 @@ impl Scanner { /// This method analyzes the fragments to be scanned and groups them into [`Split`]s /// that can be processed independently by multiple workers or threads. It uses a /// bin-packing algorithm to create balanced splits based on estimated row counts. - pub async fn plan_splits(&self, options: Option) -> Result> { + pub async fn plan_splits(&self, strategy: Option) -> Result> { // Collect initial set of fragments to scan let fragments = if let Some(fragments) = self.fragments.as_ref() { Arc::new(fragments.clone()) @@ -899,7 +903,7 @@ impl Scanner { Arc::new(self.dataset.fragments().as_ref().clone()) }; - // Use indices to prune fragments + // Use indices to prune fragments and compute max row counts per fragment let mut frag_max_row_counts: HashMap = HashMap::new(); let mut covered_frag_ids: HashSet<_> = HashSet::new(); @@ -1000,24 +1004,37 @@ impl Scanner { const DEFAULT_SPLIT_SIZE: usize = 128 * 1024 * 1024; const DEFAULT_VARIABLE_FIELD_SIZE: usize = 64; - let target_split_size = options - .and_then(|o| o.max_split_size_bytes) - .unwrap_or(DEFAULT_SPLIT_SIZE); - - let output_schema = self.projection_plan.output_schema()?; - let estimated_row_size: usize = output_schema - .fields() - .iter() - .map(|f| { - f.data_type() - .byte_width_opt() - .unwrap_or(DEFAULT_VARIABLE_FIELD_SIZE) - }) - .sum(); - - let max_rows_per_split = target_split_size / estimated_row_size.max(1); + let max_rows_per_split = match strategy { + Some(SplitPackStrategy::MaxRowCount(max_row_count)) => max_row_count, + Some(SplitPackStrategy::MaxSizeBytes(max_bytes)) => { + let output_schema = self.projection_plan.output_schema()?; + let estimated_row_size: usize = output_schema + .fields() + .iter() + .map(|f| { + f.data_type() + .byte_width_opt() + .unwrap_or(DEFAULT_VARIABLE_FIELD_SIZE) + }) + .sum(); + max_bytes / estimated_row_size.max(1) + } + None => { + let output_schema = self.projection_plan.output_schema()?; + let estimated_row_size: usize = output_schema + .fields() + .iter() + .map(|f| { + f.data_type() + .byte_width_opt() + .unwrap_or(DEFAULT_VARIABLE_FIELD_SIZE) + }) + .sum(); + DEFAULT_SPLIT_SIZE / estimated_row_size.max(1) + } + }; - let bins = Self::bin_pack(frag_max_row_counts, max_rows_per_split); + let bins = Self::bin_pack(&frag_max_row_counts, max_rows_per_split); // Convert bins to splits let fragment_map: HashMap = fragments.iter().map(|f| (f.id, f)).collect(); @@ -1026,7 +1043,15 @@ impl Scanner { .map(|bin| { let frags = bin .into_iter() - .filter_map(|id| fragment_map.get(&id).map(|&f| f.clone())) + .filter_map(|id| { + fragment_map.get(&id).map(|&f| SplitFragment { + fragment: f.clone(), + max_row_count: frag_max_row_counts + .get(&id) + .copied() + .unwrap_or(usize::MAX), + }) + }) .collect(); Split { fragments: frags } }) @@ -1039,16 +1064,16 @@ impl Scanner { /// /// Uses a first-fit decreasing algorithm: items are sorted by count in descending /// order, then each item is placed in the first bin that has room for it. 
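A minimal standalone sketch of first-fit decreasing, independent of the Scanner type and using the same `<=` capacity check as this revision (illustrative only, not part of the patch):

    fn first_fit_decreasing(mut counts: Vec<usize>, cap: usize) -> Vec<Vec<usize>> {
        // Sort descending so the largest items are placed first.
        counts.sort_unstable_by(|a, b| b.cmp(a));
        let mut bins: Vec<(Vec<usize>, usize)> = Vec::new();
        for c in counts {
            // Place the item in the first bin with room; otherwise open a new bin.
            if let Some((items, total)) = bins.iter_mut().find(|(_, t)| *t + c <= cap) {
                items.push(c);
                *total += c;
            } else {
                bins.push((vec![c], c));
            }
        }
        bins.into_iter().map(|(items, _)| items).collect()
    }

    // first_fit_decreasing(vec![70, 50, 40, 20, 10], 100) packs as
    // [[70, 20, 10], [50, 40]], matching `test_bin_pack_mixed_sizes` below.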
- fn bin_pack(items: HashMap, maximum_count: usize) -> Vec> { + fn bin_pack(items: &HashMap, maximum_count: usize) -> Vec> { // Convert to vec and sort by count descending for better packing - let mut items: Vec<(u64, usize)> = items.into_iter().collect(); + let mut items: Vec<(u64, usize)> = items.iter().map(|(&k, &v)| (k, v)).collect(); items.sort_by(|a, b| b.1.cmp(&a.1)); let mut bins: Vec<(Vec, usize)> = Vec::new(); // (ids, current_count) for (id, count) in items { // Items that exceed the maximum get their own bin - if count >= maximum_count { + if count > maximum_count { bins.push((vec![id], count)); continue; } @@ -1056,7 +1081,7 @@ impl Scanner { // Find first bin with enough remaining capacity let mut placed = false; for (bin_ids, bin_count) in &mut bins { - if *bin_count + count < maximum_count { + if *bin_count + count <= maximum_count { bin_ids.push(id); *bin_count += count; placed = true; @@ -9302,75 +9327,29 @@ mod test { #[test] fn test_bin_pack_empty() { let items: HashMap = HashMap::new(); - let bins = Scanner::bin_pack(items, 100); + let bins = Scanner::bin_pack(&items, 100); assert!(bins.is_empty()); } - #[test] - fn test_bin_pack_single_item_fits() { - let items = HashMap::from([(1, 50)]); - let bins = Scanner::bin_pack(items, 100); - assert_eq!(bins.len(), 1); - assert_eq!(bins[0], vec![1]); - } - - #[test] - fn test_bin_pack_single_item_exceeds_maximum() { - let items = HashMap::from([(1, 200)]); - let bins = Scanner::bin_pack(items, 100); - assert_eq!(bins.len(), 1); - assert_eq!(bins[0], vec![1]); - } - - #[test] - fn test_bin_pack_single_item_equals_maximum() { - // Items with count >= maximum_count get their own bin - let items = HashMap::from([(1, 100)]); - let bins = Scanner::bin_pack(items, 100); - assert_eq!(bins.len(), 1); - assert_eq!(bins[0], vec![1]); - } - - #[test] - fn test_bin_pack_multiple_items_fit_one_bin() { - let items = HashMap::from([(1, 30), (2, 30), (3, 30)]); - let bins = Scanner::bin_pack(items, 100); - assert_eq!(bins.len(), 1); - let mut ids: Vec = bins[0].clone(); - ids.sort(); - assert_eq!(ids, vec![1, 2, 3]); - } - - #[test] - fn test_bin_pack_multiple_bins_needed() { - // Each item is 60, maximum is 100, so only one item per bin - // (60 + 60 = 120 >= 100, not strictly less than) - let items = HashMap::from([(1, 60), (2, 60), (3, 60)]); - let bins = Scanner::bin_pack(items, 100); - assert_eq!(bins.len(), 3); - } - #[test] fn test_bin_pack_mixed_sizes() { // maximum = 100 // Items: 70, 50, 40, 20, 10 // Sorted descending: 70, 50, 40, 20, 10 - // Bin 1: 70, then try 50 (70+50=120 >= 100, no), try 40 (70+40=110 >= 100, no), - // try 20 (70+20=90 < 100, yes) -> [70, 20] - // Bin 2: 50, then try 40 (50+40=90 < 100, yes) -> [50, 40] - // Bin 3: try 10 in bin1 (90+10=100, NOT < 100), try bin2 (90+10=100, NOT < 100) -> [10] + // Bin 1: 70, try 50 (70+50=120 > 100, no), try 40 (70+40=110 > 100, no), + // try 20 (70+20=90 <= 100, yes), try 10 (90+10=100 <= 100, yes) -> [70, 20, 10] + // Bin 2: 50, try 40 (50+40=90 <= 100, yes) -> [50, 40] let items = HashMap::from([(1, 70), (2, 50), (3, 40), (4, 20), (5, 10)]); - let bins = Scanner::bin_pack(items, 100); + let bins = Scanner::bin_pack(&items, 100); // Verify total items across all bins let total_items: usize = bins.iter().map(|b| b.len()).sum(); assert_eq!(total_items, 5); - // Each bin's total count should be < maximum (or a single oversized item) + // Each bin's total count should be <= maximum (or a single oversized item) let item_map: HashMap = [(1, 70), (2, 50), (3, 40), (4, 20), (5, 
10)].into(); for bin in &bins { let bin_total: usize = bin.iter().map(|id| item_map[id]).sum(); - // Bins with a single oversized item can exceed, but here all are < 100 assert!(bin_total <= 100, "bin total {} exceeds maximum", bin_total); } } @@ -9378,7 +9357,7 @@ mod test { #[test] fn test_bin_pack_oversized_items_get_own_bin() { let items = HashMap::from([(1, 200), (2, 150), (3, 30)]); - let bins = Scanner::bin_pack(items, 100); + let bins = Scanner::bin_pack(&items, 100); // Oversized items (200, 150) each get their own bin // Item 30 could fit in a new bin @@ -9390,32 +9369,6 @@ mod test { } } - #[test] - fn test_bin_pack_boundary_condition() { - // Test the strict less-than condition: bin_count + count < maximum_count - // Two items of 49 should fit in one bin (49 + 49 = 98 < 100) - let items = HashMap::from([(1, 49), (2, 49)]); - let bins = Scanner::bin_pack(items, 100); - assert_eq!(bins.len(), 1); - - // Two items of 50 should NOT fit in one bin (50 + 50 = 100, not < 100) - let items = HashMap::from([(1, 50), (2, 50)]); - let bins = Scanner::bin_pack(items, 100); - assert_eq!(bins.len(), 2); - } - - #[test] - fn test_bin_pack_all_ids_preserved() { - let items: HashMap = (0..10).map(|i| (i, 25)).collect(); - let bins = Scanner::bin_pack(items.clone(), 100); - - let mut all_ids: Vec = bins.into_iter().flatten().collect(); - all_ids.sort(); - let mut expected_ids: Vec = items.keys().copied().collect(); - expected_ids.sort(); - assert_eq!(all_ids, expected_ids); - } - #[tokio::test] async fn test_plan_splits_basic() { // Create a dataset with 4 fragments of 100 rows each, single i32 column (4 bytes) @@ -9445,7 +9398,7 @@ mod test { // Set split size to 200 bytes -> max_rows = 200/4 = 50 rows per split // Each fragment has 100 rows >= 50, so each gets its own bin - let options = SplitOptions::default().with_max_split_size_bytes(200); + let options = SplitPackStrategy::MaxSizeBytes(200); let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); assert_eq!(splits.len(), 4); @@ -9466,13 +9419,13 @@ mod test { // Set split size to 400 bytes -> max_rows = 400/4 = 100 rows per split // Each fragment has 50 rows, so two fragments can fit per split (50+50=100, not < 100) // Actually 50+50=100 is NOT < 100, so each fragment gets its own bin - let options = SplitOptions::default().with_max_split_size_bytes(400); + let options = SplitPackStrategy::MaxSizeBytes(400); let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); assert_eq!(splits.len(), 4); // With 404 bytes: max_rows = 404/4 = 101 rows per split // Each fragment has 50 rows, 50+50=100 < 101, so two fragments per split - let options = SplitOptions::default().with_max_split_size_bytes(404); + let options = SplitPackStrategy::MaxSizeBytes(404); let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); assert_eq!(splits.len(), 2); for split in &splits { @@ -9492,7 +9445,7 @@ mod test { // Full projection: 8 bytes per row, max_rows = 800/8 = 100 // Each fragment has 100 rows >= 100, so each gets its own bin - let options = SplitOptions::default().with_max_split_size_bytes(800); + let options = SplitPackStrategy::MaxSizeBytes(800); let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); assert_eq!(splits.len(), 4); @@ -9501,7 +9454,7 @@ mod test { let mut scanner = dataset.scan(); scanner.project(&["a"]).unwrap(); let splits = scanner - .plan_splits(Some(SplitOptions::default().with_max_split_size_bytes(800))) + .plan_splits(Some(SplitPackStrategy::MaxSizeBytes(800))) .await .unwrap(); 
assert_eq!(splits.len(), 4); @@ -9511,7 +9464,7 @@ mod test { let mut scanner = dataset.scan(); scanner.project(&["a"]).unwrap(); let splits = scanner - .plan_splits(Some(SplitOptions::default().with_max_split_size_bytes(804))) + .plan_splits(Some(SplitPackStrategy::MaxSizeBytes(804))) .await .unwrap(); assert_eq!(splits.len(), 2); @@ -9537,17 +9490,4 @@ mod test { assert_eq!(splits.len(), 1); assert_eq!(splits[0].fragments.len(), 2); } - - #[tokio::test] - async fn test_plan_splits_single_fragment() { - let dataset = lance_datagen::gen_batch() - .col("i", array::step::()) - .into_ram_dataset(FragmentCount::from(1), FragmentRowCount::from(100)) - .await - .unwrap(); - - let splits = dataset.scan().plan_splits(None).await.unwrap(); - assert_eq!(splits.len(), 1); - assert_eq!(splits[0].fragments.len(), 1); - } } From 53338604aab5e369c9539b7327f9a875b86aa0d2 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Mon, 26 Jan 2026 09:29:49 -0600 Subject: [PATCH 5/6] updated to use separate max_size_bytes and max_row_count parameters and use the min if both provided Signed-off-by: Daniel Rammer --- python/python/lance/dataset.py | 13 +++- python/src/scanner.rs | 18 +++-- rust/lance/src/dataset/scanner.rs | 105 ++++++++++++++++++++---------- 3 files changed, 94 insertions(+), 42 deletions(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index e306957133f..a2ee300a1f2 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -5190,7 +5190,9 @@ def analyze_plan(self) -> str: return self._scanner.analyze_plan() def plan_splits( - self, max_split_size_bytes: Optional[int] = None + self, + max_size_bytes: Optional[int] = None, + max_row_count: Optional[int] = None, ) -> List[List["FragmentMetadata"]]: """Plan splits for distributed scanning. @@ -5205,8 +5207,11 @@ def plan_splits( Parameters ---------- - max_split_size_bytes : int, optional + max_size_bytes : int, optional The target maximum size in bytes for each split. Defaults to 128MB. + max_row_count : int, optional + The maximum number of rows for each split. If specified, this takes + precedence over max_size_bytes. Returns ------- @@ -5215,7 +5220,9 @@ def plan_splits( Each split can be processed independently for distributed scanning. 
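        Examples
        --------
        A sketch of driver-side planning (hypothetical filter and worker
        hand-off; assumes each returned ``FragmentMetadata`` exposes its
        fragment id)::

            splits = dataset.scanner(filter="a > 100").plan_splits(
                max_size_bytes=128 * 1024 * 1024
            )
            for split in splits:
                frag_ids = [frag.id for frag in split]
                # ship frag_ids to a worker; the worker re-opens the dataset
                # and scans only those fragments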
""" - return self._scanner.plan_splits(max_split_size_bytes=max_split_size_bytes) + return self._scanner.plan_splits( + max_size_bytes=max_size_bytes, max_row_count=max_row_count + ) class DatasetOptimizer: diff --git a/python/src/scanner.rs b/python/src/scanner.rs index bf49106dc47..b4e19e8e10d 100644 --- a/python/src/scanner.rs +++ b/python/src/scanner.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::pyarrow::*; use arrow_array::RecordBatchReader; -use lance::dataset::scanner::{ExecutionSummaryCounts, SplitPackStrategy}; +use lance::dataset::scanner::{ExecutionSummaryCounts, SplitOptions}; use pyo3::prelude::*; use pyo3::pyclass; @@ -152,16 +152,24 @@ impl Scanner { Ok(PyArrowType(Box::new(reader))) } - #[pyo3(signature = (max_split_size_bytes=None))] + #[pyo3(signature = (max_size_bytes=None, max_row_count=None))] fn plan_splits<'py>( self_: PyRef<'py, Self>, - max_split_size_bytes: Option, + max_size_bytes: Option, + max_row_count: Option, ) -> PyResult>>> { let scanner = self_.scanner.clone(); - let strategy = max_split_size_bytes.map(SplitPackStrategy::MaxSizeBytes); + let options = if max_size_bytes.is_some() || max_row_count.is_some() { + Some(SplitOptions { + max_size_bytes, + max_row_count, + }) + } else { + None + }; let splits = rt() .spawn(Some(self_.py()), async move { - scanner.plan_splits(strategy).await + scanner.plan_splits(options).await })? .map_err(|err| PyValueError::new_err(err.to_string()))?; diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 866f0f774ab..27b2579c6ac 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -632,6 +632,22 @@ pub enum SplitPackStrategy { MaxRowCount(usize), } +/// Options for configuring split generation. +/// +/// This struct allows specifying constraints on the maximum size and row count +/// for splits. Both fields are optional; if neither is set, default behavior +/// will be used. +#[derive(Debug, Clone, Default)] +pub struct SplitOptions { + /// Maximum size in bytes per split. + /// + /// The scanner estimates the row size from the output schema and calculates + /// how many rows fit within this budget. + pub max_size_bytes: Option, + /// Maximum number of rows per split. + pub max_row_count: Option, +} + /// Represents a user-requested take operation #[derive(Debug, Clone)] pub enum TakeOperation { @@ -895,7 +911,12 @@ impl Scanner { /// This method analyzes the fragments to be scanned and groups them into [`Split`]s /// that can be processed independently by multiple workers or threads. It uses a /// bin-packing algorithm to create balanced splits based on estimated row counts. 
- pub async fn plan_splits(&self, strategy: Option) -> Result> { + /// + /// The maximum number of rows per split is determined by taking the minimum of: + /// - `max_row_count` from [`SplitOptions`] (if provided) + /// - The estimated row count from `max_size_bytes` in [`SplitOptions`] (if provided) + /// - If neither is provided, uses a default split size of 128MB to estimate row count + pub async fn plan_splits(&self, options: Option) -> Result> { // Collect initial set of fragments to scan let fragments = if let Some(fragments) = self.fragments.as_ref() { Arc::new(fragments.clone()) @@ -1004,34 +1025,32 @@ impl Scanner { const DEFAULT_SPLIT_SIZE: usize = 128 * 1024 * 1024; const DEFAULT_VARIABLE_FIELD_SIZE: usize = 64; - let max_rows_per_split = match strategy { - Some(SplitPackStrategy::MaxRowCount(max_row_count)) => max_row_count, - Some(SplitPackStrategy::MaxSizeBytes(max_bytes)) => { - let output_schema = self.projection_plan.output_schema()?; - let estimated_row_size: usize = output_schema - .fields() - .iter() - .map(|f| { - f.data_type() - .byte_width_opt() - .unwrap_or(DEFAULT_VARIABLE_FIELD_SIZE) - }) - .sum(); - max_bytes / estimated_row_size.max(1) - } - None => { - let output_schema = self.projection_plan.output_schema()?; - let estimated_row_size: usize = output_schema - .fields() - .iter() - .map(|f| { - f.data_type() - .byte_width_opt() - .unwrap_or(DEFAULT_VARIABLE_FIELD_SIZE) - }) - .sum(); - DEFAULT_SPLIT_SIZE / estimated_row_size.max(1) + let options = options.unwrap_or_default(); + + // Helper to estimate row count from a byte size + let estimate_rows_from_bytes = |max_bytes: usize| -> Result { + let output_schema = self.projection_plan.output_schema()?; + let estimated_row_size: usize = output_schema + .fields() + .iter() + .map(|f| { + f.data_type() + .byte_width_opt() + .unwrap_or(DEFAULT_VARIABLE_FIELD_SIZE) + }) + .sum(); + Ok(max_bytes / estimated_row_size.max(1)) + }; + + let max_rows_per_split = match (options.max_row_count, options.max_size_bytes) { + (Some(max_rows), Some(max_bytes)) => { + // Use the minimum of both constraints + let rows_from_bytes = estimate_rows_from_bytes(max_bytes)?; + max_rows.min(rows_from_bytes) } + (Some(max_rows), None) => max_rows, + (None, Some(max_bytes)) => estimate_rows_from_bytes(max_bytes)?, + (None, None) => estimate_rows_from_bytes(DEFAULT_SPLIT_SIZE)?, }; let bins = Self::bin_pack(&frag_max_row_counts, max_rows_per_split); @@ -9398,7 +9417,10 @@ mod test { // Set split size to 200 bytes -> max_rows = 200/4 = 50 rows per split // Each fragment has 100 rows >= 50, so each gets its own bin - let options = SplitPackStrategy::MaxSizeBytes(200); + let options = SplitOptions { + max_size_bytes: Some(200), + ..Default::default() + }; let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); assert_eq!(splits.len(), 4); @@ -9419,13 +9441,19 @@ mod test { // Set split size to 400 bytes -> max_rows = 400/4 = 100 rows per split // Each fragment has 50 rows, so two fragments can fit per split (50+50=100, not < 100) // Actually 50+50=100 is NOT < 100, so each fragment gets its own bin - let options = SplitPackStrategy::MaxSizeBytes(400); + let options = SplitOptions { + max_size_bytes: Some(400), + ..Default::default() + }; let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); assert_eq!(splits.len(), 4); // With 404 bytes: max_rows = 404/4 = 101 rows per split // Each fragment has 50 rows, 50+50=100 < 101, so two fragments per split - let options = SplitPackStrategy::MaxSizeBytes(404); + let options = 
SplitOptions { + max_size_bytes: Some(404), + ..Default::default() + }; let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); assert_eq!(splits.len(), 2); for split in &splits { @@ -9445,7 +9473,10 @@ mod test { // Full projection: 8 bytes per row, max_rows = 800/8 = 100 // Each fragment has 100 rows >= 100, so each gets its own bin - let options = SplitPackStrategy::MaxSizeBytes(800); + let options = SplitOptions { + max_size_bytes: Some(800), + ..Default::default() + }; let splits = dataset.scan().plan_splits(Some(options)).await.unwrap(); assert_eq!(splits.len(), 4); @@ -9454,7 +9485,10 @@ mod test { let mut scanner = dataset.scan(); scanner.project(&["a"]).unwrap(); let splits = scanner - .plan_splits(Some(SplitPackStrategy::MaxSizeBytes(800))) + .plan_splits(Some(SplitOptions { + max_size_bytes: Some(800), + ..Default::default() + })) .await .unwrap(); assert_eq!(splits.len(), 4); @@ -9464,7 +9498,10 @@ mod test { let mut scanner = dataset.scan(); scanner.project(&["a"]).unwrap(); let splits = scanner - .plan_splits(Some(SplitPackStrategy::MaxSizeBytes(804))) + .plan_splits(Some(SplitOptions { + max_size_bytes: Some(804), + ..Default::default() + })) .await .unwrap(); assert_eq!(splits.len(), 2); From 300f9db458a17ff9c0a8342c304da31301675f67 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Tue, 27 Jan 2026 11:29:33 -0600 Subject: [PATCH 6/6] added to jni Signed-off-by: Daniel Rammer --- java/lance-jni/src/blocking_scanner.rs | 91 ++++++++++++++++++- .../main/java/org/lance/ipc/LanceScanner.java | 34 +++++++ java/src/main/java/org/lance/ipc/Split.java | 71 +++++++++++++++ .../java/org/lance/ipc/SplitFragment.java | 87 ++++++++++++++++++ .../main/java/org/lance/ipc/SplitOptions.java | 90 ++++++++++++++++++ java/src/test/java/org/lance/ScannerTest.java | 81 +++++++++++++++++ 6 files changed, 450 insertions(+), 4 deletions(-) create mode 100644 java/src/main/java/org/lance/ipc/Split.java create mode 100644 java/src/main/java/org/lance/ipc/SplitFragment.java create mode 100644 java/src/main/java/org/lance/ipc/SplitOptions.java diff --git a/java/lance-jni/src/blocking_scanner.rs b/java/lance-jni/src/blocking_scanner.rs index 8c6bc402544..9400b1111c6 100644 --- a/java/lance-jni/src/blocking_scanner.rs +++ b/java/lance-jni/src/blocking_scanner.rs @@ -5,14 +5,16 @@ use std::sync::Arc; use crate::error::{Error, Result}; use crate::ffi::JNIEnvExt; -use crate::traits::{import_vec_from_method, import_vec_to_rust}; +use crate::traits::{export_vec, import_vec_from_method, import_vec_to_rust, IntoJava}; use arrow::array::Float32Array; use arrow::{ffi::FFI_ArrowSchema, ffi_stream::FFI_ArrowArrayStream}; use arrow_schema::SchemaRef; -use jni::objects::{JObject, JString}; +use jni::objects::{JObject, JString, JValueGen}; use jni::sys::{jboolean, jint, JNI_TRUE}; use jni::{sys::jlong, JNIEnv}; -use lance::dataset::scanner::{ColumnOrdering, DatasetRecordBatchStream, Scanner}; +use lance::dataset::scanner::{ + ColumnOrdering, DatasetRecordBatchStream, Scanner, Split, SplitFragment, SplitOptions, +}; use lance_index::scalar::inverted::query::{ BooleanQuery as FtsBooleanQuery, BoostQuery as FtsBoostQuery, FtsQuery, MatchQuery as FtsMatchQuery, MultiMatchQuery as FtsMultiMatchQuery, Occur as FtsOccur, @@ -24,7 +26,6 @@ use lance_linalg::distance::DistanceType; use crate::{ blocking_dataset::{BlockingDataset, NATIVE_DATASET}, - traits::IntoJava, RT, }; @@ -56,6 +57,11 @@ impl BlockingScanner { let res = RT.block_on(self.inner.count_rows())?; Ok(res) } + + pub fn plan_splits(&self, options: 
Option) -> Result> { + let res = RT.block_on(self.inner.plan_splits(options))?; + Ok(res) + } } fn build_full_text_search_query<'a>(env: &mut JNIEnv<'a>, java_obj: JObject) -> Result { @@ -481,3 +487,80 @@ fn inner_count_rows(env: &mut JNIEnv, j_scanner: JObject) -> Result { unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }?; scanner_guard.count_rows() } + +#[no_mangle] +pub extern "system" fn Java_org_lance_ipc_LanceScanner_nativePlanSplits<'local>( + mut env: JNIEnv<'local>, + j_scanner: JObject, + options_obj: JObject, // Optional +) -> JObject<'local> { + ok_or_throw!(env, inner_plan_splits(&mut env, j_scanner, options_obj)) +} + +fn inner_plan_splits<'local>( + env: &mut JNIEnv<'local>, + j_scanner: JObject, + options_obj: JObject, +) -> Result> { + let options = extract_split_options(env, &options_obj)?; + let splits = { + let scanner_guard = + unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }?; + scanner_guard.plan_splits(options)? + }; + export_vec(env, &splits) +} + +fn extract_split_options(env: &mut JNIEnv, options_obj: &JObject) -> Result> { + if options_obj.is_null() { + return Ok(None); + } + + let is_present = env.call_method(options_obj, "isPresent", "()Z", &[])?.z()?; + + if !is_present { + return Ok(None); + } + + let options_inner = env + .call_method(options_obj, "get", "()Ljava/lang/Object;", &[])? + .l()?; + + let max_size_bytes = env.get_optional_i64_from_method(&options_inner, "getMaxSizeBytes")?; + let max_row_count = env.get_optional_i64_from_method(&options_inner, "getMaxRowCount")?; + + Ok(Some(SplitOptions { + max_size_bytes: max_size_bytes.map(|v| v as usize), + max_row_count: max_row_count.map(|v| v as usize), + })) +} + +const SPLIT_CLASS: &str = "org/lance/ipc/Split"; +const SPLIT_CONSTRUCTOR_SIG: &str = "(Ljava/util/List;)V"; +const SPLIT_FRAGMENT_CLASS: &str = "org/lance/ipc/SplitFragment"; +const SPLIT_FRAGMENT_CONSTRUCTOR_SIG: &str = "(Lorg/lance/FragmentMetadata;J)V"; + +impl IntoJava for &SplitFragment { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { + let fragment = self.fragment.into_java(env)?; + Ok(env.new_object( + SPLIT_FRAGMENT_CLASS, + SPLIT_FRAGMENT_CONSTRUCTOR_SIG, + &[ + JValueGen::Object(&fragment), + JValueGen::Long(self.max_row_count as i64), + ], + )?) + } +} + +impl IntoJava for &Split { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { + let fragments = export_vec(env, &self.fragments)?; + Ok(env.new_object( + SPLIT_CLASS, + SPLIT_CONSTRUCTOR_SIG, + &[JValueGen::Object(&fragments)], + )?) + } +} diff --git a/java/src/main/java/org/lance/ipc/LanceScanner.java b/java/src/main/java/org/lance/ipc/LanceScanner.java index 804b7ea22f3..caab3ee74c3 100644 --- a/java/src/main/java/org/lance/ipc/LanceScanner.java +++ b/java/src/main/java/org/lance/ipc/LanceScanner.java @@ -164,4 +164,38 @@ public long countRows() { } private native long nativeCountRows(); + + /** + * Plan splits for parallel scanning of the dataset. + * + *
<p>
Splits can be used to distribute scanning work across multiple workers or threads. Each + * split contains one or more fragments that can be scanned together. + * + * @return a list of splits for parallel scanning + */ + public List planSplits() { + return planSplits(Optional.empty()); + } + + /** + * Plan splits for parallel scanning of the dataset with custom options. + * + *
<p>
Splits can be used to distribute scanning work across multiple workers or threads. Each + * split contains one or more fragments that can be scanned together. + * + * @param options options for configuring split generation + * @return a list of splits for parallel scanning + */ + public List planSplits(SplitOptions options) { + return planSplits(Optional.ofNullable(options)); + } + + private List planSplits(Optional options) { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + Preconditions.checkArgument(nativeScannerHandle != 0, "Scanner is closed"); + return nativePlanSplits(options); + } + } + + private native List nativePlanSplits(Optional options); } diff --git a/java/src/main/java/org/lance/ipc/Split.java b/java/src/main/java/org/lance/ipc/Split.java new file mode 100644 index 00000000000..b053c59536e --- /dev/null +++ b/java/src/main/java/org/lance/ipc/Split.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.ipc; + +import com.google.common.base.MoreObjects; + +import java.io.Serializable; +import java.util.List; +import java.util.Objects; + +/** + * Represents a split for parallel scanning of fragments. + * + *

diff --git a/java/src/main/java/org/lance/ipc/Split.java b/java/src/main/java/org/lance/ipc/Split.java
new file mode 100644
index 00000000000..b053c59536e
--- /dev/null
+++ b/java/src/main/java/org/lance/ipc/Split.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.ipc;
+
+import com.google.common.base.MoreObjects;
+
+import java.io.Serializable;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Represents a split for parallel scanning of fragments.
+ *
+ * <p>A split contains one or more fragments that can be scanned together. Splits can be used to
+ * distribute scanning work across multiple workers or threads.
+ */
+public class Split implements Serializable {
+  private static final long serialVersionUID = 1L;
+  private final List<SplitFragment> fragments;
+
+  /**
+   * Creates a new Split.
+   *
+   * @param fragments the list of fragments in this split
+   */
+  public Split(List<SplitFragment> fragments) {
+    this.fragments = fragments;
+  }
+
+  /**
+   * Returns the list of fragments in this split.
+   *
+   * @return the list of split fragments
+   */
+  public List<SplitFragment> getFragments() {
+    return fragments;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    Split split = (Split) o;
+    return Objects.equals(fragments, split.fragments);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(fragments);
+  }
+
+  @Override
+  public String toString() {
+    return MoreObjects.toStringHelper(this).add("fragments", fragments).toString();
+  }
+}
diff --git a/java/src/main/java/org/lance/ipc/SplitFragment.java b/java/src/main/java/org/lance/ipc/SplitFragment.java
new file mode 100644
index 00000000000..63ec1c981b2
--- /dev/null
+++ b/java/src/main/java/org/lance/ipc/SplitFragment.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.ipc;
+
+import org.lance.FragmentMetadata;
+
+import com.google.common.base.MoreObjects;
+
+import java.io.Serializable;
+import java.util.Objects;
+
+/**
+ * A fragment within a {@link Split}, along with metadata about the expected number of rows that
+ * will be scanned from it.
+ */
+public class SplitFragment implements Serializable {
+  private static final long serialVersionUID = 1L;
+  private final FragmentMetadata fragment;
+  private final long maxRowCount;
+
+  /**
+   * Creates a new SplitFragment.
+   *
+   * @param fragment the fragment metadata
+   * @param maxRowCount an upper bound on the number of rows that will be read from this fragment
+   *     after applying any filters or index pruning
+   */
+  public SplitFragment(FragmentMetadata fragment, long maxRowCount) {
+    this.fragment = fragment;
+    this.maxRowCount = maxRowCount;
+  }
+
+  /**
+   * Returns the fragment metadata.
+   *
+   * @return the fragment metadata
+   */
+  public FragmentMetadata getFragment() {
+    return fragment;
+  }
+
+  /**
+   * Returns an upper bound on the number of rows that will be read from this fragment after
+   * applying any filters or index pruning.
+   *
+   * @return the maximum row count
+   */
+  public long getMaxRowCount() {
+    return maxRowCount;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    SplitFragment that = (SplitFragment) o;
+    return maxRowCount == that.maxRowCount && Objects.equals(fragment, that.fragment);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(fragment, maxRowCount);
+  }
+
+  @Override
+  public String toString() {
+    return MoreObjects.toStringHelper(this)
+        .add("fragment", fragment)
+        .add("maxRowCount", maxRowCount)
+        .toString();
+  }
+}
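Since getMaxRowCount() is an upper bound rather than an exact count, a scheduler can size the work for a whole split by summing the per-fragment bounds, roughly like this (sketch only):

    long splitUpperBound =
        split.getFragments().stream().mapToLong(SplitFragment::getMaxRowCount).sum();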

diff --git a/java/src/main/java/org/lance/ipc/SplitOptions.java b/java/src/main/java/org/lance/ipc/SplitOptions.java
new file mode 100644
index 00000000000..bd9645bd33b
--- /dev/null
+++ b/java/src/main/java/org/lance/ipc/SplitOptions.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.ipc;
+
+import java.util.Optional;
+
+/**
+ * Options for configuring split generation in a scanner.
+ *
+ * <p>This class allows specifying constraints on the maximum size and row count for splits. Both
+ * fields are optional; if neither is set, default behavior will be used.
+ */
+public class SplitOptions {
+  private final Optional<Long> maxSizeBytes;
+  private final Optional<Long> maxRowCount;
+
+  private SplitOptions(Builder builder) {
+    this.maxSizeBytes = builder.maxSizeBytes;
+    this.maxRowCount = builder.maxRowCount;
+  }
+
+  /**
+   * Returns the maximum size in bytes per split.
+   *
+   * @return the maximum size in bytes, or empty if not set
+   */
+  public Optional<Long> getMaxSizeBytes() {
+    return maxSizeBytes;
+  }
+
+  /**
+   * Returns the maximum number of rows per split.
+   *
+   * @return the maximum row count, or empty if not set
+   */
+  public Optional<Long> getMaxRowCount() {
+    return maxRowCount;
+  }
+
+  /** Builder for {@link SplitOptions}. */
+  public static class Builder {
+    private Optional<Long> maxSizeBytes = Optional.empty();
+    private Optional<Long> maxRowCount = Optional.empty();
+
+    /**
+     * Sets the maximum size in bytes per split.
+     *
+     * <p>The scanner estimates the row size from the output schema and calculates how many rows
+     * fit within this budget.
+     *
+     * @param maxSizeBytes the maximum size in bytes
+     * @return this builder
+     */
+    public Builder maxSizeBytes(long maxSizeBytes) {
+      this.maxSizeBytes = Optional.of(maxSizeBytes);
+      return this;
+    }
+
+    /**
+     * Sets the maximum number of rows per split.
+     *
+     * @param maxRowCount the maximum number of rows
+     * @return this builder
+     */
+    public Builder maxRowCount(long maxRowCount) {
+      this.maxRowCount = Optional.of(maxRowCount);
+      return this;
+    }
+
+    /**
+     * Builds a new {@link SplitOptions} instance.
+     *
+     * @return a new SplitOptions instance
+     */
+    public SplitOptions build() {
+      return new SplitOptions(this);
+    }
+  }
+}
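A small sketch of the builder in use; the limits below are illustrative, and either constraint can be omitted:

    SplitOptions options =
        new SplitOptions.Builder()
            .maxRowCount(250_000) // cap the estimated rows per split
            .maxSizeBytes(128L * 1024 * 1024) // cap the estimated output bytes per split
            .build();
    List<Split> splits = scanner.planSplits(options);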
diff --git a/java/src/test/java/org/lance/ScannerTest.java b/java/src/test/java/org/lance/ScannerTest.java
index 9fe844e9334..b2991d956e8 100644
--- a/java/src/test/java/org/lance/ScannerTest.java
+++ b/java/src/test/java/org/lance/ScannerTest.java
@@ -16,6 +16,9 @@
 import org.lance.ipc.ColumnOrdering;
 import org.lance.ipc.LanceScanner;
 import org.lance.ipc.ScanOptions;
+import org.lance.ipc.Split;
+import org.lance.ipc.SplitFragment;
+import org.lance.ipc.SplitOptions;
 
 import org.apache.arrow.dataset.scanner.Scanner;
 import org.apache.arrow.memory.BufferAllocator;
@@ -554,4 +557,82 @@ private void validScanResult(Dataset dataset, int fragmentId, int rowCount) thro
         }
       }
     }
   }
+
+  @Test
+  void testPlanSplits(@TempDir Path tempDir) throws Exception {
+    String datasetPath = tempDir.resolve("plan_splits").toString();
+    try (BufferAllocator allocator = new RootAllocator()) {
+      TestUtils.SimpleTestDataset testDataset =
+          new TestUtils.SimpleTestDataset(allocator, datasetPath);
+      testDataset.createEmptyDataset().close();
+      int totalRows = 100;
+      try (Dataset dataset = testDataset.write(1, totalRows)) {
+        try (LanceScanner scanner = dataset.newScan(new ScanOptions.Builder().build())) {
+          // Test planSplits without options
+          List<Split> splits = scanner.planSplits();
+          assertFalse(splits.isEmpty(), "Should return at least one split");
+
+          // Verify each split has fragments with valid metadata
+          for (Split split : splits) {
+            List<SplitFragment> fragments = split.getFragments();
+            assertFalse(fragments.isEmpty(), "Each split should have at least one fragment");
+            for (SplitFragment sf : fragments) {
+              FragmentMetadata fm = sf.getFragment();
+              assertTrue(fm.getId() >= 0, "Fragment ID should be non-negative");
+              assertTrue(sf.getMaxRowCount() > 0, "Max row count should be positive");
+            }
+          }
+        }
+      }
+    }
+  }
+
+  @Test
+  void testPlanSplitsWithOptions(@TempDir Path tempDir) throws Exception {
+    String datasetPath = tempDir.resolve("plan_splits_options").toString();
+    try (BufferAllocator allocator = new RootAllocator()) {
+      TestUtils.SimpleTestDataset testDataset =
+          new TestUtils.SimpleTestDataset(allocator, datasetPath);
+      testDataset.createEmptyDataset().close();
+      int totalRows = 100;
+      try (Dataset dataset = testDataset.write(1, totalRows)) {
+        try (LanceScanner scanner = dataset.newScan(new ScanOptions.Builder().build())) {
+          // Test planSplits with maxRowCount option
+          SplitOptions options = new SplitOptions.Builder().maxRowCount(20).build();
+          List<Split> splits = scanner.planSplits(options);
+          assertFalse(splits.isEmpty(), "Should return at least one split");
+
+          // With max 20 rows per split but all 100 rows in a single fragment,
+          // fragments are not subdivided, so we only require at least one split
+          assertTrue(splits.size() >= 1, "Should have at least one split");
+
+          // Verify each split fragment respects the row count
+          for (Split split : splits) {
+            for (SplitFragment sf : split.getFragments()) {
+              assertTrue(sf.getMaxRowCount() > 0, "Max row count should be positive");
+            }
+          }
+        }
+      }
+    }
+  }
+
+  @Test
+  void testPlanSplitsWithMaxSizeBytes(@TempDir Path tempDir) throws Exception {
+    String datasetPath = tempDir.resolve("plan_splits_max_bytes").toString();
+    try (BufferAllocator allocator = new RootAllocator()) {
+      TestUtils.SimpleTestDataset testDataset =
+          new TestUtils.SimpleTestDataset(allocator, datasetPath);
+      testDataset.createEmptyDataset().close();
+      int totalRows = 100;
+      try (Dataset dataset = testDataset.write(1, totalRows)) {
+        try (LanceScanner scanner = dataset.newScan(new ScanOptions.Builder().build())) {
+          // Test planSplits with maxSizeBytes option
+          SplitOptions options = new SplitOptions.Builder().maxSizeBytes(1024).build();
+          List<Split> splits = scanner.planSplits(options);
+          assertFalse(splits.isEmpty(), "Should return at least one split");
+        }
+      }
+    }
+  }
 }
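Beyond the unit tests, the intended fan-out could look like the sketch below (assuming the enclosing method throws Exception). The per-split scan body is a placeholder, since wiring a Split back into a scan is outside this change:

    ExecutorService pool = Executors.newFixedThreadPool(4);
    List<Future<Long>> pending = new ArrayList<>();
    for (Split split : scanner.planSplits()) {
      pending.add(
          pool.submit(
              () -> {
                long rows = 0;
                for (SplitFragment sf : split.getFragments()) {
                  // Placeholder: a real worker would scan sf.getFragment() here.
                  rows += sf.getMaxRowCount();
                }
                return rows;
              }));
    }
    for (Future<Long> f : pending) {
      f.get(); // surface any worker failure
    }
    pool.shutdown();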