@@ -93,7 +93,7 @@ public void close() throws IOException {
    fragmentScanner.close();
  }

-  private void addBlobVirtualColumns(
+  public static void addBlobVirtualColumns(
      List<ColumnVector> fieldVectors, VectorSchemaRoot root, LanceInputPartition inputPartition) {
    StructType schema = inputPartition.getSchema();


@@ -160,7 +160,7 @@ public LanceInputPartition getInputPartition() {
    return inputPartition;
  }

-  private static List<String> getColumnNames(StructType schema) {
+  public static List<String> getColumnNames(StructType schema) {
    return Arrays.stream(schema.fields())
        .map(StructField::name)
        .filter(
@@ -173,13 +173,13 @@ private static List<String> getColumnNames(StructType schema) {
        .collect(Collectors.toList());
  }

-  private static boolean getWithRowId(StructType schema) {
+  public static boolean getWithRowId(StructType schema) {
    return Arrays.stream(schema.fields())
        .map(StructField::name)
        .anyMatch(name -> name.equals(LanceConstant.ROW_ID));
  }

-  private static boolean getWithRowAddress(StructType schema) {
+  public static boolean getWithRowAddress(StructType schema) {
    return Arrays.stream(schema.fields())
        .map(StructField::name)
        .anyMatch(name -> name.equals(LanceConstant.ROW_ADDRESS));
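
Making these helpers public lets the new split planner and the filtered-read-plan reader below derive scan settings from the same schema logic instead of duplicating it. A minimal sketch of the shared pattern, assuming only the signatures visible in this diff (the helper itself is hypothetical, not part of the PR):

// Hypothetical helper illustrating the reuse this visibility change enables.
static ScanOptions.Builder scanOptionsFor(StructType schema) {
  ScanOptions.Builder builder = new ScanOptions.Builder();
  // Project only the physical columns; virtual columns are filtered out.
  builder.columns(LanceFragmentScanner.getColumnNames(schema));
  // Request row ids / row addresses only when the Spark schema asks for them.
  builder.withRowId(LanceFragmentScanner.getWithRowId(schema));
  builder.withRowAddress(LanceFragmentScanner.getWithRowAddress(schema));
  return builder;
}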

@@ -13,26 +13,138 @@
 */
package org.lance.spark.read;

import org.lance.Dataset;
import org.lance.ReadOptions;
import org.lance.ipc.LanceScanner;
import org.lance.ipc.ScanOptions;
import org.lance.namespace.LanceNamespaceStorageOptionsProvider;
import org.lance.spark.LanceRuntime;
import org.lance.spark.LanceSparkReadOptions;
import org.lance.spark.internal.LanceFragmentColumnarBatchScanner;
import org.lance.spark.internal.LanceFragmentScanner;
import org.lance.spark.vectorized.LanceArrowColumnVector;

import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.spark.sql.connector.read.PartitionReader;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class LanceColumnarPartitionReader implements PartitionReader<ColumnarBatch> {
  private final LanceInputPartition inputPartition;
  private final boolean useFilteredReadPlan;

  // Fragment-based path fields
  private int fragmentIndex;
  private LanceFragmentColumnarBatchScanner fragmentReader;

  // Filtered read plan path fields
  private Dataset dataset;
  private LanceScanner scanner;
  private ArrowReader planArrowReader;
  private boolean initialized;

  private ColumnarBatch currentBatch;

  public LanceColumnarPartitionReader(LanceInputPartition inputPartition) {
    this.inputPartition = inputPartition;
    this.useFilteredReadPlan = inputPartition.getLanceSplit().getFilteredReadPlan().isPresent();
    this.fragmentIndex = 0;
    this.initialized = false;
  }

  @Override
  public boolean next() throws IOException {
    if (useFilteredReadPlan) {
      return nextFilteredReadPlan();
    } else {
      return nextFragmentBased();
    }
  }

  private boolean nextFilteredReadPlan() throws IOException {
    if (!initialized) {
      initializeFilteredReadPlan();
      initialized = true;
    }
    if (planArrowReader.loadNextBatch()) {
      VectorSchemaRoot root = planArrowReader.getVectorSchemaRoot();
      List<ColumnVector> fieldVectors =
          root.getFieldVectors().stream()
              .map(LanceArrowColumnVector::new)
              .collect(Collectors.toList());

      LanceFragmentColumnarBatchScanner.addBlobVirtualColumns(fieldVectors, root, inputPartition);

      currentBatch =
          new ColumnarBatch(fieldVectors.toArray(new ColumnVector[] {}), root.getRowCount());
      return true;
    }
    return false;
  }

  private void initializeFilteredReadPlan() {
    try {
      LanceSparkReadOptions readOptions = inputPartition.getReadOptions();

      Map<String, String> merged =
          LanceRuntime.mergeStorageOptions(
              readOptions.getStorageOptions(), inputPartition.getInitialStorageOptions());
      LanceNamespaceStorageOptionsProvider provider =
          LanceRuntime.getOrCreateStorageOptionsProvider(
              inputPartition.getNamespaceImpl(),
              inputPartition.getNamespaceProperties(),
              readOptions.getTableId());

      ReadOptions.Builder readOptionsBuilder = new ReadOptions.Builder().setStorageOptions(merged);
      if (provider != null) {
        readOptionsBuilder.setStorageOptionsProvider(provider);
      }

      dataset =
          Dataset.open()
              .allocator(LanceRuntime.allocator())
              .uri(readOptions.getDatasetUri())
              .readOptions(readOptionsBuilder.build())
              .build();

      ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
      scanOptionsBuilder.columns(LanceFragmentScanner.getColumnNames(inputPartition.getSchema()));
      if (inputPartition.getWhereCondition().isPresent()) {
        scanOptionsBuilder.filter(inputPartition.getWhereCondition().get());
      }
      scanOptionsBuilder.batchSize(readOptions.getBatchSize());
      scanOptionsBuilder.withRowId(LanceFragmentScanner.getWithRowId(inputPartition.getSchema()));
      scanOptionsBuilder.withRowAddress(
          LanceFragmentScanner.getWithRowAddress(inputPartition.getSchema()));
      if (readOptions.getNearest() != null) {
        scanOptionsBuilder.nearest(readOptions.getNearest());
      }
      if (inputPartition.getLimit().isPresent()) {
        scanOptionsBuilder.limit(inputPartition.getLimit().get());
      }
      if (inputPartition.getOffset().isPresent()) {
        scanOptionsBuilder.offset(inputPartition.getOffset().get());
      }
      if (inputPartition.getTopNSortOrders().isPresent()) {
        scanOptionsBuilder.setColumnOrderings(inputPartition.getTopNSortOrders().get());
      }

      scanner = dataset.newScan(scanOptionsBuilder.build());
      planArrowReader =
          scanner.executeFilteredReadPlan(
              inputPartition.getLanceSplit().getFilteredReadPlan().get());
    } catch (Exception e) {
      throw new RuntimeException("Failed to initialize filtered read plan", e);
    }
  }

  private boolean nextFragmentBased() throws IOException {
    if (loadNextBatchFromCurrentReader()) {
      return true;
    }
@@ -73,5 +185,18 @@ public void close() throws IOException {
        throw new IOException(e);
      }
    }
    if (planArrowReader != null) {
      planArrowReader.close();
    }
    if (scanner != null) {
      try {
        scanner.close();
      } catch (Exception e) {
        throw new IOException(e);
      }
    }
    if (dataset != null) {
      dataset.close();
    }
  }
}
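
For reviewers tracing the lifecycle: both paths sit behind Spark's standard PartitionReader contract, so the executor-side loop is unchanged. A minimal consumption sketch, assuming a populated LanceInputPartition named partition and a hypothetical process() consumer (neither is part of this diff):

LanceColumnarPartitionReader reader = new LanceColumnarPartitionReader(partition);
try {
  while (reader.next()) {               // first call lazily opens the filtered-read-plan scanner
    ColumnarBatch batch = reader.get(); // wraps the current Arrow batch as Spark column vectors
    process(batch);                     // consume before the next call; the Arrow root is reused
  }
} finally {
  reader.close();                       // releases plan reader, scanner, and dataset in order
}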

@@ -14,8 +14,11 @@
package org.lance.spark.read;

import org.lance.Dataset;
import org.lance.ReadOptions;
import org.lance.ipc.FilteredReadPlan;
import org.lance.ipc.LanceScanner;
import org.lance.ipc.ScanOptions;
import org.lance.namespace.LanceNamespaceStorageOptionsProvider;
import org.lance.spark.LanceRuntime;
import org.lance.spark.LanceSparkReadOptions;
import org.lance.spark.vectorized.LanceArrowColumnVector;
@@ -32,6 +35,7 @@

import java.io.IOException;
import java.util.List;
import java.util.Map;

/**
 * Partition reader for pushed down aggregates. This reader computes the aggregate result directly
@@ -59,6 +63,59 @@ public boolean next() throws IOException {
  }

  private long computeCount() {
    if (inputPartition.getLanceSplit().getFilteredReadPlan().isPresent()) {
      return computeCountFilteredReadPlan();
    }
    return computeCountFragmentBased();
  }

  private long computeCountFilteredReadPlan() {
    LanceSparkReadOptions readOptions = inputPartition.getReadOptions();
    long totalCount = 0;

    Map<String, String> merged =
        LanceRuntime.mergeStorageOptions(
            readOptions.getStorageOptions(), inputPartition.getInitialStorageOptions());
    LanceNamespaceStorageOptionsProvider provider =
        LanceRuntime.getOrCreateStorageOptionsProvider(
            inputPartition.getNamespaceImpl(),
            inputPartition.getNamespaceProperties(),
            readOptions.getTableId());

    ReadOptions.Builder readOptionsBuilder = new ReadOptions.Builder().setStorageOptions(merged);
    if (provider != null) {
      readOptionsBuilder.setStorageOptionsProvider(provider);
    }

    try (Dataset dataset =
        Dataset.open()
            .allocator(allocator)
            .uri(readOptions.getDatasetUri())
            .readOptions(readOptionsBuilder.build())
            .build()) {
      ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
      if (inputPartition.getWhereCondition().isPresent()) {
        scanOptionsBuilder.filter(inputPartition.getWhereCondition().get());
      }
      scanOptionsBuilder.withRowId(true);
      scanOptionsBuilder.columns(Lists.newArrayList());

      try (LanceScanner scanner = dataset.newScan(scanOptionsBuilder.build())) {
        FilteredReadPlan plan = inputPartition.getLanceSplit().getFilteredReadPlan().get();
        try (ArrowReader reader = scanner.executeFilteredReadPlan(plan)) {
          while (reader.loadNextBatch()) {
            totalCount += reader.getVectorSchemaRoot().getRowCount();
          }
        }
      } catch (Exception e) {
        throw new RuntimeException("Failed to execute filtered read plan", e);
      }
    }

    return totalCount;
  }

  private long computeCountFragmentBased() {
    // This reader is only used when there are filters (metadata-based count uses LocalScan)
    LanceSparkReadOptions readOptions = inputPartition.getReadOptions();
    long totalCount = 0;
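
The query shape that exercises this path is a filtered count; an unfiltered COUNT(*) never reaches it, since that is answered from fragment metadata via LocalScan. A hedged end-to-end example (catalog, table, and column names are illustrative assumptions, not part of this change):

import org.apache.spark.sql.SparkSession;

public class FilteredCountExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("lance-filtered-count").getOrCreate();
    // The WHERE clause forces the pushed-down aggregate reader above; dropping
    // it would let the count be served from metadata without scanning any rows.
    long rows =
        spark.sql("SELECT COUNT(*) FROM demo.events WHERE status = 'ok'").first().getLong(0);
    System.out.println(rows);
    spark.stop();
  }
}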

@@ -96,7 +96,9 @@ public Batch toBatch() {

  @Override
  public InputPartition[] planInputPartitions() {
-    List<LanceSplit> splits = LanceSplit.generateLanceSplits(readOptions);
+    List<LanceSplit> splits =
+        LanceSplit.generateLanceSplits(
+            readOptions, schema, whereConditions, limit, offset, topNSortOrders);
    return IntStream.range(0, splits.size())
        .mapToObj(
            i ->

@@ -14,11 +14,20 @@
package org.lance.spark.read;

import org.lance.Dataset;
-import org.lance.Fragment;
+import org.lance.ipc.ColumnOrdering;
+import org.lance.ipc.FilteredReadPlan;
+import org.lance.ipc.LanceScanner;
+import org.lance.ipc.ScanOptions;
+import org.lance.ipc.Splits;
import org.lance.spark.LanceRuntime;
import org.lance.spark.LanceSparkReadOptions;
+import org.lance.spark.internal.LanceFragmentScanner;
+import org.lance.spark.utils.Optional;
+
+import org.apache.spark.sql.types.StructType;

import java.io.Serializable;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
@@ -27,21 +36,71 @@ public class LanceSplit implements Serializable {
  private static final long serialVersionUID = 2983749283749283749L;

  private final List<Integer> fragments;
+  private final FilteredReadPlan filteredReadPlan;

  public LanceSplit(List<Integer> fragments) {
    this.fragments = fragments;
+    this.filteredReadPlan = null;
  }

+  public LanceSplit(FilteredReadPlan filteredReadPlan) {
+    this.fragments = new ArrayList<>(filteredReadPlan.getFragmentRanges().keySet());
+    this.filteredReadPlan = filteredReadPlan;
+  }
+
  public List<Integer> getFragments() {
    return fragments;
  }

-  public static List<LanceSplit> generateLanceSplits(LanceSparkReadOptions readOptions) {
+  public Optional<FilteredReadPlan> getFilteredReadPlan() {
+    return Optional.ofNullable(filteredReadPlan);
+  }
+
+  public static List<LanceSplit> generateLanceSplits(
+      LanceSparkReadOptions readOptions,
+      StructType schema,
+      Optional<String> whereCondition,
+      Optional<Integer> limit,
+      Optional<Integer> offset,
+      Optional<List<ColumnOrdering>> topNSortOrders) {
    try (Dataset dataset = openDataset(readOptions)) {
-      return dataset.getFragments().stream()
-          .map(Fragment::getId)
-          .map(id -> new LanceSplit(Collections.singletonList(id)))
-          .collect(Collectors.toList());
+      ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
+      scanOptionsBuilder.columns(LanceFragmentScanner.getColumnNames(schema));
+      if (whereCondition.isPresent()) {
+        scanOptionsBuilder.filter(whereCondition.get());
+      }
+      scanOptionsBuilder.batchSize(readOptions.getBatchSize());
+      scanOptionsBuilder.withRowId(LanceFragmentScanner.getWithRowId(schema));
+      scanOptionsBuilder.withRowAddress(LanceFragmentScanner.getWithRowAddress(schema));
+      if (readOptions.getNearest() != null) {
+        scanOptionsBuilder.nearest(readOptions.getNearest());
+      }
+      if (limit.isPresent()) {
+        scanOptionsBuilder.limit(limit.get());
+      }
+      if (offset.isPresent()) {
+        scanOptionsBuilder.offset(offset.get());
+      }
+      if (topNSortOrders.isPresent()) {
+        scanOptionsBuilder.setColumnOrderings(topNSortOrders.get());
+      }
+
+      try (LanceScanner scanner = dataset.newScan(scanOptionsBuilder.build())) {
+        Splits splits = scanner.planSplits(null);
+        if (splits.getFilteredReadPlans().isPresent()) {
+          return splits.getFilteredReadPlans().get().stream()
+              .map(LanceSplit::new)
+              .collect(Collectors.toList());
+        } else if (splits.getFragments().isPresent()) {
+          return splits.getFragments().get().stream()
+              .map(id -> new LanceSplit(Collections.singletonList(id)))
+              .collect(Collectors.toList());
+        } else {
+          return Collections.emptyList();
+        }
+      } catch (Exception e) {
+        throw new RuntimeException("Failed to plan splits", e);
+      }
    }
  }
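
Taken together, split planning moves from enumerating fragments on the driver to letting the Lance scanner propose splits, which may now carry a FilteredReadPlan. A condensed sketch of the resulting flow, reusing only names from this diff (error handling and partition wiring elided):

// Driver: the scanner decides whether splits are plan-based or fragment-based.
List<LanceSplit> splits =
    LanceSplit.generateLanceSplits(
        readOptions, schema, whereConditions, limit, offset, topNSortOrders);

// Executors: each reader branches on the kind of split it receives.
for (LanceSplit split : splits) {
  if (split.getFilteredReadPlan().isPresent()) {
    // New path: execute the serialized FilteredReadPlan against a fresh scanner.
  } else {
    // Existing path: scan split.getFragments() one fragment at a time.
  }
}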