Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.lance.spark.LanceConstant;
import org.lance.spark.LanceSparkReadOptions;
import org.lance.spark.read.LanceInputPartition;
import org.lance.spark.read.LanceSplit;

import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.spark.sql.types.StructField;
Expand Down Expand Up @@ -49,13 +50,6 @@ private LanceFragmentScanner(
public static LanceFragmentScanner create(int fragmentId, LanceInputPartition inputPartition) {
try {
LanceSparkReadOptions readOptions = inputPartition.getReadOptions();
Fragment fragment =
LanceDatasetCache.getFragment(
readOptions,
fragmentId,
inputPartition.getInitialStorageOptions(),
inputPartition.getNamespaceImpl(),
inputPartition.getNamespaceProperties());
ScanOptions.Builder scanOptions = new ScanOptions.Builder();
List<String> projectedColumns = getColumnNames(inputPartition.getSchema());
if (projectedColumns.isEmpty() && inputPartition.getSchema().isEmpty()) {
Expand All @@ -72,16 +66,6 @@ public static LanceFragmentScanner create(int fragmentId, LanceInputPartition in
scanOptions.filter(inputPartition.getWhereCondition().get());
}
scanOptions.batchSize(readOptions.getBatchSize());
if (readOptions.getNearest() != null) {
scanOptions.nearest(readOptions.getNearest());
// We strictly set `prefilter = true` here to ensure query correctness.
// This is necessary due to the combination of two factors:
// 1. Spark currently performs the vector search by individually scanning each fragment.
// 2. Lance mandates that `prefilter` must be enabled for fragmented vector queries.
// If Spark's execution model or Lance's search functionality changes in the future,
// we need to revisit this.
scanOptions.prefilter(true);
}
if (inputPartition.getLimit().isPresent()) {
scanOptions.limit(inputPartition.getLimit().get());
}
Expand All @@ -93,8 +77,40 @@ public static LanceFragmentScanner create(int fragmentId, LanceInputPartition in
}
boolean withFragmentId =
inputPartition.getSchema().getFieldIndex(LanceConstant.FRAGMENT_ID).nonEmpty();
return new LanceFragmentScanner(
fragment.newScan(scanOptions.build()), fragmentId, withFragmentId, inputPartition);

LanceScanner scanner;
if (LanceSplit.isIndexedVectorSearch(readOptions)) {
// Indexed vector search: use dataset-level scan to leverage the global index,
// avoiding N redundant global searches from per-fragment scans.
// Note: fragmentId is not used to scope the scan here — it is only read later
// via fragmentId() for the _fragid virtual column, whose actual values come
// from the scan results, not the split's representative fragment ID.
scanOptions.nearest(readOptions.getNearest());
LanceDatasetCache.CachedDataset cached =
LanceDatasetCache.getDataset(
readOptions,
inputPartition.getInitialStorageOptions(),
inputPartition.getNamespaceImpl(),
inputPartition.getNamespaceProperties());
scanner = cached.getDataset().newScan(scanOptions.build());
} else {
// Regular scan or brute-force KNN (useIndex=false): per-fragment path.
// For brute-force KNN, we still pass the nearest query to the fragment scanner
// so it performs per-fragment KNN search (with prefilter for correctness).
if (readOptions.getNearest() != null) {
scanOptions.nearest(readOptions.getNearest());
scanOptions.prefilter(true);
}
Fragment fragment =
LanceDatasetCache.getFragment(
readOptions,
fragmentId,
inputPartition.getInitialStorageOptions(),
inputPartition.getNamespaceImpl(),
inputPartition.getNamespaceProperties());
scanner = fragment.newScan(scanOptions.build());
}
return new LanceFragmentScanner(scanner, fragmentId, withFragmentId, inputPartition);
} catch (Throwable throwable) {
throw new RuntimeException(throwable);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,11 @@ public InputPartition[] planInputPartitions() {
i -> {
LanceSplit split = finalSplits.get(i);
InternalRow partKeyRow = null;
if (partitionInfo != null) {
// Skip partition key only for indexed vector search: that single split uses
// a dataset-level scan and doesn't map to any one fragment's partition value.
// Brute-force KNN (useIndex=false) keeps per-fragment splits, so its
// partition key is valid and SPJ can proceed normally.
if (partitionInfo != null && !LanceSplit.isIndexedVectorSearch(readOptions)) {
int fragId = split.getFragments().get(0);
partKeyRow = partitionInfo.partitionKeyForFragment(fragId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import org.lance.Dataset;
import org.lance.Fragment;
import org.lance.ipc.Query;
import org.lance.spark.LanceSparkReadOptions;
import org.lance.spark.utils.Utils;

Expand Down Expand Up @@ -76,13 +77,30 @@ public Map<Integer, Long> getFragmentRowCounts() {
public static ScanPlanResult planScan(LanceSparkReadOptions readOptions) {
try (Dataset dataset = Utils.openDatasetBuilder(readOptions).build()) {
List<Fragment> fragments = dataset.getFragments();
List<LanceSplit> splits = new ArrayList<>(fragments.size());
List<LanceSplit> splits;
Map<Integer, Long> fragmentRowCounts = new HashMap<>(fragments.size());
List<Integer> allFragmentIds = new ArrayList<>(fragments.size());
for (Fragment fragment : fragments) {
int id = fragment.getId();
splits.add(new LanceSplit(Collections.singletonList(id)));
allFragmentIds.add(id);
fragmentRowCounts.put(id, fragment.metadata().getNumRows());
}
if (isIndexedVectorSearch(readOptions)) {
// Indexed vector search: merge into a single split with one fragment ID.
// Dataset-level scan is used for indexed search (see LanceFragmentScanner),
// so we don't need all fragment IDs — just one is enough to create the scanner.
// Brute-force KNN (useIndex=false) keeps per-fragment splits for parallelism.
splits = new ArrayList<>(1);
if (!allFragmentIds.isEmpty()) {
splits.add(new LanceSplit(Collections.singletonList(allFragmentIds.get(0))));
}
} else {
// Non-vector scan: keep per-fragment parallelism
splits = new ArrayList<>(fragments.size());
for (int id : allFragmentIds) {
splits.add(new LanceSplit(Collections.singletonList(id)));
}
}
long resolvedVersion = dataset.getVersion().getId();
return new ScanPlanResult(splits, resolvedVersion, fragmentRowCounts);
}
Expand All @@ -95,4 +113,15 @@ public static ScanPlanResult planScan(LanceSparkReadOptions readOptions) {
public static List<LanceSplit> generateLanceSplits(LanceSparkReadOptions readOptions) {
return planScan(readOptions).getSplits();
}

/**
 * Determines whether the supplied read options describe an indexed vector search —
 * that is, a nearest-neighbor query configured with useIndex=true. Indexed searches
 * are planned as a single split with a dataset-level scan so the global IVF index is
 * consulted once; brute-force KNN (useIndex=false) instead keeps per-fragment splits
 * for parallelism.
 *
 * @param readOptions the Spark read options to inspect
 * @return true if a nearest query is present and requests index usage
 */
public static boolean isIndexedVectorSearch(LanceSparkReadOptions readOptions) {
  final Query query = readOptions.getNearest();
  if (query == null) {
    return false;
  }
  return query.isUseIndex();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,20 @@

import static org.junit.jupiter.api.Assertions.*;

/*
*The test logic is same with org.lance.VectorSearchTest.test_knn
*/

public abstract class BaseSparkConnectorReadWithVectorSearchTest {
private static SparkSession spark;
private static String dbPath;
private static Dataset<Row> data;

// test_dataset5 has 5 fragments and no pre-built vector index.
private static Dataset<Row> indexedData; // useIndex=true → single dataset-level scan
private static Dataset<Row> bruteForceData; // useIndex=false → per-fragment parallel scan

@BeforeAll
static void setup() {

Query.Builder builder = new Query.Builder();
float[] key = new float[32];
for (int i = 0; i < 32; i++) {
key[i] = (float) (i + 32);
}
builder.setK(1);
builder.setColumn("vec");
builder.setKey(key);
builder.setUseIndex(true);
builder.setDistanceType(DistanceType.L2);

spark =
SparkSession.builder()
Expand All @@ -64,16 +56,39 @@ static void setup() {
.config("spark.sql.catalog.lance", "org.lance.spark.LanceNamespaceSparkCatalog")
.getOrCreate();
dbPath = TestUtils.TestTable1Config.dbPath;
data =
String datasetUri = TestUtils.getDatasetUri(dbPath, "test_dataset5");

Query.Builder indexedBuilder = new Query.Builder();
indexedBuilder.setK(1);
indexedBuilder.setColumn("vec");
indexedBuilder.setKey(key);
indexedBuilder.setUseIndex(true);
indexedBuilder.setDistanceType(DistanceType.L2);
indexedData =
spark
.read()
.format(LanceDataSource.name)
.option(LanceSparkReadOptions.CONFIG_NEAREST, QueryUtils.queryToString(builder.build()))
.option(
LanceSparkReadOptions.CONFIG_DATASET_URI,
TestUtils.getDatasetUri(dbPath, "test_dataset5"))
LanceSparkReadOptions.CONFIG_NEAREST,
QueryUtils.queryToString(indexedBuilder.build()))
.option(LanceSparkReadOptions.CONFIG_DATASET_URI, datasetUri)
.load();

Query.Builder bruteForceBuilder = new Query.Builder();
bruteForceBuilder.setK(1);
bruteForceBuilder.setColumn("vec");
bruteForceBuilder.setKey(key);
bruteForceBuilder.setUseIndex(false);
bruteForceBuilder.setDistanceType(DistanceType.L2);
bruteForceData =
spark
.read()
.format(LanceDataSource.name)
.option(
LanceSparkReadOptions.CONFIG_NEAREST,
QueryUtils.queryToString(bruteForceBuilder.build()))
.option(LanceSparkReadOptions.CONFIG_DATASET_URI, datasetUri)
.load();
data.createOrReplaceTempView("test_dataset5");
}

@AfterAll
Expand All @@ -84,12 +99,24 @@ static void tearDown() {
}

@Test
public void validateData() {
public void testIndexedSearchReturnsGlobalTopK() {
  // With useIndex=true the query is served by a single dataset-level scan, so a
  // k=1 search yields exactly one row: the nearest neighbor over all fragments.
  List<Row> result = indexedData.collectAsList();
  assertEquals(1, result.size(), "Indexed k=1 search must return exactly 1 row globally");
  Row top = result.get(0);
  assertEquals(1, top.getInt(0), "Unexpected value in 'i' column");
}

@Test
public void testBruteForceSearchReturnsPerFragmentCandidates() {
// useIndex=false keeps per-fragment splits for parallel brute-force scan.
// With k=1 and 5 fragments, each fragment returns its local top-1,
// yielding 5 candidate rows for the caller to aggregate.
Set<Integer> expectedI = new HashSet<>(Arrays.asList(1, 81, 161, 241, 321));
Set<Integer> actualI = new HashSet<>();
List<Row> rows = data.collectAsList();
for (int i = 0; i < rows.size(); i++) {
actualI.add(rows.get(i).getInt(0));
List<Row> rows = bruteForceData.collectAsList();
for (Row row : rows) {
actualI.add(row.getInt(0));
}
assertEquals(expectedI, actualI, "Unexpected values in 'i' column");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.lance.spark.read;

import org.lance.index.DistanceType;
import org.lance.ipc.Query;
import org.lance.spark.LanceSparkReadOptions;
import org.lance.spark.TestUtils;

import org.junit.jupiter.api.Test;

import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

public class LanceSplitVectorSearchTest {

  // Expected fragment count of test_dataset5 — presumably fixed by the shared
  // test fixtures; confirm against TestUtils if the dataset layout changes.
  private static final int TEST_DATASET5_FRAGMENT_COUNT = 5;

  /**
   * Builds read options for a synthetic URI carrying a nearest query; only the
   * useIndex flag varies. Used to exercise the isIndexedVectorSearch predicate.
   */
  private static LanceSparkReadOptions optionsWithNearest(boolean useIndex) {
    Query.Builder queryBuilder = new Query.Builder();
    queryBuilder.setK(10);
    queryBuilder.setColumn("vector");
    queryBuilder.setKey(new float[] {1.0f, 2.0f, 3.0f});
    queryBuilder.setDistanceType(DistanceType.L2);
    queryBuilder.setUseIndex(useIndex);
    Query nearest = queryBuilder.build();
    return LanceSparkReadOptions.builder()
        .datasetUri("s3://bucket/path")
        .nearest(nearest)
        .build();
  }

  /** Read options with no nearest query at all — a plain, non-vector scan. */
  private static LanceSparkReadOptions optionsWithoutNearest() {
    return LanceSparkReadOptions.builder().datasetUri("s3://bucket/path").build();
  }

  /** Builds read options backed by test_dataset5 (5 fragments, has 'vec' column). */
  private static LanceSparkReadOptions dataset5OptionsWithNearest(boolean useIndex) {
    float[] searchKey = new float[32];
    for (int dim = 0; dim < searchKey.length; dim++) {
      searchKey[dim] = (float) (dim + 32);
    }
    Query.Builder queryBuilder = new Query.Builder();
    queryBuilder.setK(1);
    queryBuilder.setColumn("vec");
    queryBuilder.setKey(searchKey);
    queryBuilder.setDistanceType(DistanceType.L2);
    queryBuilder.setUseIndex(useIndex);
    String datasetUri = TestUtils.getDatasetUri(TestUtils.TestTable1Config.dbPath, "test_dataset5");
    return LanceSparkReadOptions.builder()
        .datasetUri(datasetUri)
        .nearest(queryBuilder.build())
        .build();
  }

  // --- isIndexedVectorSearch ---

  @Test
  public void testIsIndexedVectorSearchWithUseIndexTrue() {
    // nearest present and useIndex=true → indexed search.
    assertTrue(LanceSplit.isIndexedVectorSearch(optionsWithNearest(true)));
  }

  @Test
  public void testIsIndexedVectorSearchWithUseIndexFalse() {
    // nearest present but useIndex=false → brute-force KNN, not indexed.
    assertFalse(LanceSplit.isIndexedVectorSearch(optionsWithNearest(false)));
  }

  @Test
  public void testIsIndexedVectorSearchWithoutNearest() {
    // No nearest query at all → plain scan, never an indexed search.
    assertFalse(LanceSplit.isIndexedVectorSearch(optionsWithoutNearest()));
  }

  // --- planScan split strategy ---

  @Test
  public void testIndexedVectorSearchProducesSingleSplit() {
    // Indexed vector search (useIndex=true) must produce exactly one split so that
    // LanceFragmentScanner runs a single dataset-level scan instead of N redundant
    // per-fragment global index searches.
    LanceSplit.ScanPlanResult plan = LanceSplit.planScan(dataset5OptionsWithNearest(true));
    List<LanceSplit> splits = plan.getSplits();
    assertEquals(1, splits.size(), "Indexed vector search must produce exactly one split");
    assertEquals(
        1,
        splits.get(0).getFragments().size(),
        "The single split must carry exactly one representative fragment ID");
  }

  @Test
  public void testBruteForceKNNProducesPerFragmentSplits() {
    // Brute-force KNN (useIndex=false) must keep per-fragment splits so that
    // each executor can scan its fragment in parallel.
    LanceSplit.ScanPlanResult plan = LanceSplit.planScan(dataset5OptionsWithNearest(false));
    List<LanceSplit> splits = plan.getSplits();
    assertEquals(
        TEST_DATASET5_FRAGMENT_COUNT,
        splits.size(),
        "Brute-force KNN must produce one split per fragment");
    for (LanceSplit fragmentSplit : splits) {
      assertEquals(
          1,
          fragmentSplit.getFragments().size(),
          "Each brute-force KNN split must map to one fragment");
    }
  }
}
Loading