From 2edc065b85ab8318a910be4e0ac7627b3c0dd884 Mon Sep 17 00:00:00 2001 From: shml Date: Mon, 13 Apr 2026 16:40:22 +0800 Subject: [PATCH] perf: use dataset-level scan for indexed vector search to avoid per-fragment redundancy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Lance's IVF index is built globally across all fragments. When each fragment maps to a separate Spark partition, indexed vector search runs once per fragment instead of once per query — incurring N-fold task scheduling overhead and lower recall than a single global IVF search. Changes: - Add `LanceSplit.isIndexedVectorSearch()` to distinguish indexed vector search (nearest + useIndex=true) from brute-force KNN (useIndex=false). - For indexed search, merge all fragments into a single split and use `Dataset.newScan()` instead of `Fragment.newScan()` to execute a single global index search. Guard against empty datasets (no fragments). - For brute-force KNN, keep per-fragment splits for parallel scan and set `prefilter=true` on fragment scanners for correctness. - Skip SPJ partition key computation only for indexed vector search; brute-force KNN retains per-fragment splits so its partition key remains valid and SPJ can proceed normally. - Add tests covering planScan() split count: indexed search produces one split; brute-force KNN produces one split per fragment. 
--- .../spark/internal/LanceFragmentScanner.java | 54 +++++--- .../java/org/lance/spark/read/LanceScan.java | 6 +- .../java/org/lance/spark/read/LanceSplit.java | 33 ++++- ...parkConnectorReadWithVectorSearchTest.java | 69 +++++++---- .../read/LanceSplitVectorSearchTest.java | 117 ++++++++++++++++++ 5 files changed, 236 insertions(+), 43 deletions(-) create mode 100644 lance-spark-base_2.12/src/test/java/org/lance/spark/read/LanceSplitVectorSearchTest.java diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java index 5401fd9a..0f929b71 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java @@ -19,6 +19,7 @@ import org.lance.spark.LanceConstant; import org.lance.spark.LanceSparkReadOptions; import org.lance.spark.read.LanceInputPartition; +import org.lance.spark.read.LanceSplit; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.spark.sql.types.StructField; @@ -49,13 +50,6 @@ private LanceFragmentScanner( public static LanceFragmentScanner create(int fragmentId, LanceInputPartition inputPartition) { try { LanceSparkReadOptions readOptions = inputPartition.getReadOptions(); - Fragment fragment = - LanceDatasetCache.getFragment( - readOptions, - fragmentId, - inputPartition.getInitialStorageOptions(), - inputPartition.getNamespaceImpl(), - inputPartition.getNamespaceProperties()); ScanOptions.Builder scanOptions = new ScanOptions.Builder(); List projectedColumns = getColumnNames(inputPartition.getSchema()); if (projectedColumns.isEmpty() && inputPartition.getSchema().isEmpty()) { @@ -72,16 +66,6 @@ public static LanceFragmentScanner create(int fragmentId, LanceInputPartition in scanOptions.filter(inputPartition.getWhereCondition().get()); } 
scanOptions.batchSize(readOptions.getBatchSize()); - if (readOptions.getNearest() != null) { - scanOptions.nearest(readOptions.getNearest()); - // We strictly set `prefilter = true` here to ensure query correctness. - // This is necessary due to the combination of two factors: - // 1. Spark currently performs the vector search by individually scanning each fragment. - // 2. Lance mandates that `prefilter` must be enabled for fragmented vector queries. - // If Spark's execution model or Lance's search functionality changes in the future, - // we need to revisit this. - scanOptions.prefilter(true); - } if (inputPartition.getLimit().isPresent()) { scanOptions.limit(inputPartition.getLimit().get()); } @@ -93,8 +77,40 @@ public static LanceFragmentScanner create(int fragmentId, LanceInputPartition in } boolean withFragmentId = inputPartition.getSchema().getFieldIndex(LanceConstant.FRAGMENT_ID).nonEmpty(); - return new LanceFragmentScanner( - fragment.newScan(scanOptions.build()), fragmentId, withFragmentId, inputPartition); + + LanceScanner scanner; + if (LanceSplit.isIndexedVectorSearch(readOptions)) { + // Indexed vector search: use dataset-level scan to leverage the global index, + // avoiding N redundant global searches from per-fragment scans. + // Note: fragmentId is not used to scope the scan here — it is only read later + // via fragmentId() for the _fragid virtual column, whose actual values come + // from the scan results, not the split's representative fragment ID. + scanOptions.nearest(readOptions.getNearest()); + LanceDatasetCache.CachedDataset cached = + LanceDatasetCache.getDataset( + readOptions, + inputPartition.getInitialStorageOptions(), + inputPartition.getNamespaceImpl(), + inputPartition.getNamespaceProperties()); + scanner = cached.getDataset().newScan(scanOptions.build()); + } else { + // Regular scan or brute-force KNN (useIndex=false): per-fragment path. 
+ // For brute-force KNN, we still pass the nearest query to the fragment scanner + // so it performs per-fragment KNN search (with prefilter for correctness). + if (readOptions.getNearest() != null) { + scanOptions.nearest(readOptions.getNearest()); + scanOptions.prefilter(true); + } + Fragment fragment = + LanceDatasetCache.getFragment( + readOptions, + fragmentId, + inputPartition.getInitialStorageOptions(), + inputPartition.getNamespaceImpl(), + inputPartition.getNamespaceProperties()); + scanner = fragment.newScan(scanOptions.build()); + } + return new LanceFragmentScanner(scanner, fragmentId, withFragmentId, inputPartition); } catch (Throwable throwable) { throw new RuntimeException(throwable); } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java index e2ec05a7..feca3087 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java @@ -167,7 +167,11 @@ public InputPartition[] planInputPartitions() { i -> { LanceSplit split = finalSplits.get(i); InternalRow partKeyRow = null; - if (partitionInfo != null) { + // Skip partition key only for indexed vector search: that single split uses + // a dataset-level scan and doesn't map to any one fragment's partition value. + // Brute-force KNN (useIndex=false) keeps per-fragment splits, so its + // partition key is valid and SPJ can proceed normally. 
+ if (partitionInfo != null && !LanceSplit.isIndexedVectorSearch(readOptions)) { int fragId = split.getFragments().get(0); partKeyRow = partitionInfo.partitionKeyForFragment(fragId); } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceSplit.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceSplit.java index db9cd5bd..6715e3d2 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceSplit.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceSplit.java @@ -15,6 +15,7 @@ import org.lance.Dataset; import org.lance.Fragment; +import org.lance.ipc.Query; import org.lance.spark.LanceSparkReadOptions; import org.lance.spark.utils.Utils; @@ -76,13 +77,30 @@ public Map getFragmentRowCounts() { public static ScanPlanResult planScan(LanceSparkReadOptions readOptions) { try (Dataset dataset = Utils.openDatasetBuilder(readOptions).build()) { List fragments = dataset.getFragments(); - List splits = new ArrayList<>(fragments.size()); + List splits; Map fragmentRowCounts = new HashMap<>(fragments.size()); + List allFragmentIds = new ArrayList<>(fragments.size()); for (Fragment fragment : fragments) { int id = fragment.getId(); - splits.add(new LanceSplit(Collections.singletonList(id))); + allFragmentIds.add(id); fragmentRowCounts.put(id, fragment.metadata().getNumRows()); } + if (isIndexedVectorSearch(readOptions)) { + // Indexed vector search: merge into a single split with one fragment ID. + // Dataset-level scan is used for indexed search (see LanceFragmentScanner), + // so we don't need all fragment IDs — just one is enough to create the scanner. + // Brute-force KNN (useIndex=false) keeps per-fragment splits for parallelism. 
+ splits = new ArrayList<>(1); + if (!allFragmentIds.isEmpty()) { + splits.add(new LanceSplit(Collections.singletonList(allFragmentIds.get(0)))); + } + } else { + // Regular scan or brute-force KNN: keep per-fragment parallelism + splits = new ArrayList<>(fragments.size()); + for (int id : allFragmentIds) { + splits.add(new LanceSplit(Collections.singletonList(id))); + } + } long resolvedVersion = dataset.getVersion().getId(); return new ScanPlanResult(splits, resolvedVersion, fragmentRowCounts); } @@ -95,4 +113,15 @@ public static ScanPlanResult planScan(LanceSparkReadOptions readOptions) { public static List generateLanceSplits(LanceSparkReadOptions readOptions) { return planScan(readOptions).getSplits(); } + + /** + * Returns true when the read options specify an indexed vector search (nearest query with + * useIndex=true). In this mode, a single split with dataset-level scan is used to leverage the + * global IVF index. Brute-force KNN (useIndex=false) returns false and retains per-fragment + * splits for parallelism. 
+ */ + public static boolean isIndexedVectorSearch(LanceSparkReadOptions readOptions) { + Query nearest = readOptions.getNearest(); + return nearest != null && nearest.isUseIndex(); + } } diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/BaseSparkConnectorReadWithVectorSearchTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/BaseSparkConnectorReadWithVectorSearchTest.java index 6c116ec6..963fc5ff 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/BaseSparkConnectorReadWithVectorSearchTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/BaseSparkConnectorReadWithVectorSearchTest.java @@ -34,28 +34,20 @@ import static org.junit.jupiter.api.Assertions.*; -/* - *The test logic is same with org.lance.VectorSearchTest.test_knn - */ - public abstract class BaseSparkConnectorReadWithVectorSearchTest { private static SparkSession spark; private static String dbPath; - private static Dataset data; + + // test_dataset5 has 5 fragments and no pre-built vector index. 
+ private static Dataset indexedData; // useIndex=true → single dataset-level scan + private static Dataset bruteForceData; // useIndex=false → per-fragment parallel scan @BeforeAll static void setup() { - - Query.Builder builder = new Query.Builder(); float[] key = new float[32]; for (int i = 0; i < 32; i++) { key[i] = (float) (i + 32); } - builder.setK(1); - builder.setColumn("vec"); - builder.setKey(key); - builder.setUseIndex(true); - builder.setDistanceType(DistanceType.L2); spark = SparkSession.builder() @@ -64,16 +56,39 @@ static void setup() { .config("spark.sql.catalog.lance", "org.lance.spark.LanceNamespaceSparkCatalog") .getOrCreate(); dbPath = TestUtils.TestTable1Config.dbPath; - data = + String datasetUri = TestUtils.getDatasetUri(dbPath, "test_dataset5"); + + Query.Builder indexedBuilder = new Query.Builder(); + indexedBuilder.setK(1); + indexedBuilder.setColumn("vec"); + indexedBuilder.setKey(key); + indexedBuilder.setUseIndex(true); + indexedBuilder.setDistanceType(DistanceType.L2); + indexedData = spark .read() .format(LanceDataSource.name) - .option(LanceSparkReadOptions.CONFIG_NEAREST, QueryUtils.queryToString(builder.build())) .option( - LanceSparkReadOptions.CONFIG_DATASET_URI, - TestUtils.getDatasetUri(dbPath, "test_dataset5")) + LanceSparkReadOptions.CONFIG_NEAREST, + QueryUtils.queryToString(indexedBuilder.build())) + .option(LanceSparkReadOptions.CONFIG_DATASET_URI, datasetUri) + .load(); + + Query.Builder bruteForceBuilder = new Query.Builder(); + bruteForceBuilder.setK(1); + bruteForceBuilder.setColumn("vec"); + bruteForceBuilder.setKey(key); + bruteForceBuilder.setUseIndex(false); + bruteForceBuilder.setDistanceType(DistanceType.L2); + bruteForceData = + spark + .read() + .format(LanceDataSource.name) + .option( + LanceSparkReadOptions.CONFIG_NEAREST, + QueryUtils.queryToString(bruteForceBuilder.build())) + .option(LanceSparkReadOptions.CONFIG_DATASET_URI, datasetUri) .load(); - data.createOrReplaceTempView("test_dataset5"); } @AfterAll 
@@ -84,12 +99,24 @@ static void tearDown() { } @Test - public void validateData() { + public void testIndexedSearchReturnsGlobalTopK() { + // useIndex=true uses a single dataset-level scan, so k=1 returns exactly 1 row + // globally — the nearest neighbor across all fragments combined. + List rows = indexedData.collectAsList(); + assertEquals(1, rows.size(), "Indexed k=1 search must return exactly 1 row globally"); + assertEquals(1, rows.get(0).getInt(0), "Unexpected value in 'i' column"); + } + + @Test + public void testBruteForceSearchReturnsPerFragmentCandidates() { + // useIndex=false keeps per-fragment splits for parallel brute-force scan. + // With k=1 and 5 fragments, each fragment returns its local top-1, + // yielding 5 candidate rows for the caller to aggregate. Set expectedI = new HashSet<>(Arrays.asList(1, 81, 161, 241, 321)); Set actualI = new HashSet<>(); - List rows = data.collectAsList(); - for (int i = 0; i < rows.size(); i++) { - actualI.add(rows.get(i).getInt(0)); + List rows = bruteForceData.collectAsList(); + for (Row row : rows) { + actualI.add(row.getInt(0)); } assertEquals(expectedI, actualI, "Unexpected values in 'i' column"); } diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/LanceSplitVectorSearchTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/LanceSplitVectorSearchTest.java new file mode 100644 index 00000000..7e3b2c6f --- /dev/null +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/LanceSplitVectorSearchTest.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.read; + +import org.lance.index.DistanceType; +import org.lance.ipc.Query; +import org.lance.spark.LanceSparkReadOptions; +import org.lance.spark.TestUtils; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +public class LanceSplitVectorSearchTest { + + private static final int TEST_DATASET5_FRAGMENT_COUNT = 5; + + private static LanceSparkReadOptions optionsWithNearest(boolean useIndex) { + Query.Builder builder = new Query.Builder(); + builder.setK(10); + builder.setColumn("vector"); + builder.setKey(new float[] {1.0f, 2.0f, 3.0f}); + builder.setDistanceType(DistanceType.L2); + builder.setUseIndex(useIndex); + return LanceSparkReadOptions.builder() + .datasetUri("s3://bucket/path") + .nearest(builder.build()) + .build(); + } + + private static LanceSparkReadOptions optionsWithoutNearest() { + return LanceSparkReadOptions.builder().datasetUri("s3://bucket/path").build(); + } + + /** Builds read options backed by test_dataset5 (5 fragments, has 'vec' column). 
*/ + private static LanceSparkReadOptions dataset5OptionsWithNearest(boolean useIndex) { + float[] key = new float[32]; + for (int i = 0; i < 32; i++) { + key[i] = (float) (i + 32); + } + Query.Builder builder = new Query.Builder(); + builder.setK(1); + builder.setColumn("vec"); + builder.setKey(key); + builder.setDistanceType(DistanceType.L2); + builder.setUseIndex(useIndex); + String datasetUri = TestUtils.getDatasetUri(TestUtils.TestTable1Config.dbPath, "test_dataset5"); + return LanceSparkReadOptions.builder().datasetUri(datasetUri).nearest(builder.build()).build(); + } + + // --- isIndexedVectorSearch --- + + @Test + public void testIsIndexedVectorSearchWithUseIndexTrue() { + LanceSparkReadOptions options = optionsWithNearest(true); + assertTrue(LanceSplit.isIndexedVectorSearch(options)); + } + + @Test + public void testIsIndexedVectorSearchWithUseIndexFalse() { + LanceSparkReadOptions options = optionsWithNearest(false); + assertFalse(LanceSplit.isIndexedVectorSearch(options)); + } + + @Test + public void testIsIndexedVectorSearchWithoutNearest() { + LanceSparkReadOptions options = optionsWithoutNearest(); + assertFalse(LanceSplit.isIndexedVectorSearch(options)); + } + + // --- planScan split strategy --- + + @Test + public void testIndexedVectorSearchProducesSingleSplit() { + // Indexed vector search (useIndex=true) must produce exactly one split so that + // LanceFragmentScanner runs a single dataset-level scan instead of N redundant + // per-fragment global index searches. 
+ LanceSplit.ScanPlanResult result = + LanceSplit.planScan(dataset5OptionsWithNearest(/* useIndex= */ true)); + List splits = result.getSplits(); + assertEquals(1, splits.size(), "Indexed vector search must produce exactly one split"); + assertEquals( + 1, + splits.get(0).getFragments().size(), + "The single split must carry exactly one representative fragment ID"); + } + + @Test + public void testBruteForceKNNProducesPerFragmentSplits() { + // Brute-force KNN (useIndex=false) must keep per-fragment splits so that + // each executor can scan its fragment in parallel. + LanceSplit.ScanPlanResult result = + LanceSplit.planScan(dataset5OptionsWithNearest(/* useIndex= */ false)); + List splits = result.getSplits(); + assertEquals( + TEST_DATASET5_FRAGMENT_COUNT, + splits.size(), + "Brute-force KNN must produce one split per fragment"); + for (LanceSplit split : splits) { + assertEquals( + 1, split.getFragments().size(), "Each brute-force KNN split must map to one fragment"); + } + } +}