diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentColumnarBatchScanner.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentColumnarBatchScanner.java
index 34b4d610..7e555c9c 100644
--- a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentColumnarBatchScanner.java
+++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentColumnarBatchScanner.java
@@ -93,7 +93,7 @@
   public void close() throws IOException {
     fragmentScanner.close();
   }
 
-  private void addBlobVirtualColumns(
+  public static void addBlobVirtualColumns(
       List<LanceArrowColumnVector> fieldVectors, VectorSchemaRoot root, LanceInputPartition inputPartition) {
     StructType schema = inputPartition.getSchema();
diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java
index f5de7a4c..79f71df7 100644
--- a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java
+++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java
@@ -160,7 +160,7 @@ public LanceInputPartition getInputPartition() {
     return inputPartition;
   }
 
-  private static List<String> getColumnNames(StructType schema) {
+  public static List<String> getColumnNames(StructType schema) {
     return Arrays.stream(schema.fields())
         .map(StructField::name)
         .filter(
@@ -173,13 +173,13 @@ private static List<String> getColumnNames(StructType schema) {
         .collect(Collectors.toList());
   }
 
-  private static boolean getWithRowId(StructType schema) {
+  public static boolean getWithRowId(StructType schema) {
     return Arrays.stream(schema.fields())
         .map(StructField::name)
         .anyMatch(name -> name.equals(LanceConstant.ROW_ID));
   }
 
-  private static boolean getWithRowAddress(StructType schema) {
+  public static boolean getWithRowAddress(StructType schema) {
     return Arrays.stream(schema.fields())
         .map(StructField::name)
         .anyMatch(name -> name.equals(LanceConstant.ROW_ADDRESS));
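Reviewer note: the helpers above are widened from `private` to `public static` so the new filtered-read-plan path can reuse them. A minimal sketch of how a caller might derive scan parameters from a Spark read schema (the schema contents are illustrative, and it assumes `getColumnNames`' elided filter drops virtual columns such as `LanceConstant.ROW_ID`, as its usage alongside `withRowId` suggests):

```java
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

import java.util.List;

// Hypothetical caller of the now-public helpers.
StructType schema =
    new StructType()
        .add("id", DataTypes.LongType)
        .add(LanceConstant.ROW_ID, DataTypes.LongType); // virtual row-id column
// Assuming the filter excludes virtual columns, only "id" remains.
List<String> columns = LanceFragmentScanner.getColumnNames(schema);
boolean withRowId = LanceFragmentScanner.getWithRowId(schema); // true
boolean withRowAddress = LanceFragmentScanner.getWithRowAddress(schema); // false
```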
diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceColumnarPartitionReader.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceColumnarPartitionReader.java
index df0addfd..ec35f7a1 100644
--- a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceColumnarPartitionReader.java
+++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceColumnarPartitionReader.java
@@ -13,26 +13,138 @@
  */
 package org.lance.spark.read;
 
+import org.lance.Dataset;
+import org.lance.ReadOptions;
+import org.lance.ipc.LanceScanner;
+import org.lance.ipc.ScanOptions;
+import org.lance.namespace.LanceNamespaceStorageOptionsProvider;
+import org.lance.spark.LanceRuntime;
+import org.lance.spark.LanceSparkReadOptions;
 import org.lance.spark.internal.LanceFragmentColumnarBatchScanner;
+import org.lance.spark.internal.LanceFragmentScanner;
+import org.lance.spark.vectorized.LanceArrowColumnVector;
 
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowReader;
 import org.apache.spark.sql.connector.read.PartitionReader;
+import org.apache.spark.sql.vectorized.ColumnVector;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 
 import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
 
 public class LanceColumnarPartitionReader implements PartitionReader<ColumnarBatch> {
   private final LanceInputPartition inputPartition;
+  private final boolean useFilteredReadPlan;
+
+  // Fragment-based path fields
   private int fragmentIndex;
   private LanceFragmentColumnarBatchScanner fragmentReader;
+
+  // Filtered read plan path fields
+  private Dataset dataset;
+  private LanceScanner scanner;
+  private ArrowReader planArrowReader;
+  private boolean initialized;
+
   private ColumnarBatch currentBatch;
 
   public LanceColumnarPartitionReader(LanceInputPartition inputPartition) {
     this.inputPartition = inputPartition;
+    this.useFilteredReadPlan = inputPartition.getLanceSplit().getFilteredReadPlan().isPresent();
     this.fragmentIndex = 0;
+    this.initialized = false;
   }
 
   @Override
   public boolean next() throws IOException {
+    if (useFilteredReadPlan) {
+      return nextFilteredReadPlan();
+    } else {
+      return nextFragmentBased();
+    }
+  }
+
+  private boolean nextFilteredReadPlan() throws IOException {
+    if (!initialized) {
+      initializeFilteredReadPlan();
+      initialized = true;
+    }
+    if (planArrowReader.loadNextBatch()) {
+      VectorSchemaRoot root = planArrowReader.getVectorSchemaRoot();
+      List<LanceArrowColumnVector> fieldVectors =
+          root.getFieldVectors().stream()
+              .map(LanceArrowColumnVector::new)
+              .collect(Collectors.toList());
+
+      LanceFragmentColumnarBatchScanner.addBlobVirtualColumns(fieldVectors, root, inputPartition);
+
+      currentBatch =
+          new ColumnarBatch(fieldVectors.toArray(new ColumnVector[] {}), root.getRowCount());
+      return true;
+    }
+    return false;
+  }
+
+  private void initializeFilteredReadPlan() {
+    try {
+      LanceSparkReadOptions readOptions = inputPartition.getReadOptions();
+
+      Map<String, String> merged =
+          LanceRuntime.mergeStorageOptions(
+              readOptions.getStorageOptions(), inputPartition.getInitialStorageOptions());
+      LanceNamespaceStorageOptionsProvider provider =
+          LanceRuntime.getOrCreateStorageOptionsProvider(
+              inputPartition.getNamespaceImpl(),
+              inputPartition.getNamespaceProperties(),
+              readOptions.getTableId());
+
+      ReadOptions.Builder readOptionsBuilder = new ReadOptions.Builder().setStorageOptions(merged);
+      if (provider != null) {
+        readOptionsBuilder.setStorageOptionsProvider(provider);
+      }
+
+      dataset =
+          Dataset.open()
+              .allocator(LanceRuntime.allocator())
+              .uri(readOptions.getDatasetUri())
+              .readOptions(readOptionsBuilder.build())
+              .build();
+
+      ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
+      scanOptionsBuilder.columns(LanceFragmentScanner.getColumnNames(inputPartition.getSchema()));
+      if (inputPartition.getWhereCondition().isPresent()) {
+        scanOptionsBuilder.filter(inputPartition.getWhereCondition().get());
+      }
+      scanOptionsBuilder.batchSize(readOptions.getBatchSize());
+      scanOptionsBuilder.withRowId(LanceFragmentScanner.getWithRowId(inputPartition.getSchema()));
+      scanOptionsBuilder.withRowAddress(
+          LanceFragmentScanner.getWithRowAddress(inputPartition.getSchema()));
+      if (readOptions.getNearest() != null) {
+        scanOptionsBuilder.nearest(readOptions.getNearest());
+      }
+      if (inputPartition.getLimit().isPresent()) {
+        scanOptionsBuilder.limit(inputPartition.getLimit().get());
+      }
+      if (inputPartition.getOffset().isPresent()) {
+        scanOptionsBuilder.offset(inputPartition.getOffset().get());
+      }
+      if (inputPartition.getTopNSortOrders().isPresent()) {
+        scanOptionsBuilder.setColumnOrderings(inputPartition.getTopNSortOrders().get());
+      }
+
+      scanner = dataset.newScan(scanOptionsBuilder.build());
+      planArrowReader =
+          scanner.executeFilteredReadPlan(
+              inputPartition.getLanceSplit().getFilteredReadPlan().get());
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to initialize filtered read plan", e);
+    }
+  }
+
+  private boolean nextFragmentBased() throws IOException {
     if (loadNextBatchFromCurrentReader()) {
       return true;
     }
@@ -73,5 +185,18 @@ public void close() throws IOException {
         throw new IOException(e);
       }
     }
+    if (planArrowReader != null) {
+      planArrowReader.close();
+    }
+    if (scanner != null) {
+      try {
+        scanner.close();
+      } catch (Exception e) {
+        throw new IOException(e);
+      }
+    }
+    if (dataset != null) {
+      dataset.close();
+    }
   }
 }
diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceCountStarPartitionReader.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceCountStarPartitionReader.java
index 17dbed21..31d34e81 100644
--- a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceCountStarPartitionReader.java
+++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceCountStarPartitionReader.java
@@ -14,8 +14,11 @@
 package org.lance.spark.read;
 
 import org.lance.Dataset;
+import org.lance.ReadOptions;
+import org.lance.ipc.FilteredReadPlan;
 import org.lance.ipc.LanceScanner;
 import org.lance.ipc.ScanOptions;
+import org.lance.namespace.LanceNamespaceStorageOptionsProvider;
 import org.lance.spark.LanceRuntime;
 import org.lance.spark.LanceSparkReadOptions;
 import org.lance.spark.vectorized.LanceArrowColumnVector;
@@ -32,6 +35,7 @@
 
 import java.io.IOException;
 import java.util.List;
+import java.util.Map;
 
 /**
  * Partition reader for pushed down aggregates. This reader computes the aggregate result directly
@@ -59,6 +63,59 @@ public boolean next() throws IOException {
   }
 
   private long computeCount() {
+    if (inputPartition.getLanceSplit().getFilteredReadPlan().isPresent()) {
+      return computeCountFilteredReadPlan();
+    }
+    return computeCountFragmentBased();
+  }
+
+  private long computeCountFilteredReadPlan() {
+    LanceSparkReadOptions readOptions = inputPartition.getReadOptions();
+    long totalCount = 0;
+
+    Map<String, String> merged =
+        LanceRuntime.mergeStorageOptions(
+            readOptions.getStorageOptions(), inputPartition.getInitialStorageOptions());
+    LanceNamespaceStorageOptionsProvider provider =
+        LanceRuntime.getOrCreateStorageOptionsProvider(
+            inputPartition.getNamespaceImpl(),
+            inputPartition.getNamespaceProperties(),
+            readOptions.getTableId());
+
+    ReadOptions.Builder readOptionsBuilder = new ReadOptions.Builder().setStorageOptions(merged);
+    if (provider != null) {
+      readOptionsBuilder.setStorageOptionsProvider(provider);
+    }
+
+    try (Dataset dataset =
+        Dataset.open()
+            .allocator(allocator)
+            .uri(readOptions.getDatasetUri())
+            .readOptions(readOptionsBuilder.build())
+            .build()) {
+      ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
+      if (inputPartition.getWhereCondition().isPresent()) {
+        scanOptionsBuilder.filter(inputPartition.getWhereCondition().get());
+      }
+      scanOptionsBuilder.withRowId(true);
+      scanOptionsBuilder.columns(Lists.newArrayList());
+
+      try (LanceScanner scanner = dataset.newScan(scanOptionsBuilder.build())) {
+        FilteredReadPlan plan = inputPartition.getLanceSplit().getFilteredReadPlan().get();
+        try (ArrowReader reader = scanner.executeFilteredReadPlan(plan)) {
+          while (reader.loadNextBatch()) {
+            totalCount += reader.getVectorSchemaRoot().getRowCount();
+          }
+        }
+      } catch (Exception e) {
+        throw new RuntimeException("Failed to execute filtered read plan", e);
+      }
+    }
+
+    return totalCount;
+  }
+
+  private long computeCountFragmentBased() {
     // This reader is only used when there are filters (metadata-based count uses LocalScan)
     LanceSparkReadOptions readOptions = inputPartition.getReadOptions();
    long totalCount = 0;
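Reviewer note: the count path above never materializes column data; it scans with an empty column list plus row ids and sums per-batch row counts. The consuming loop is plain Arrow `ArrowReader` API, isolated here as a sketch (the `countRows` helper is hypothetical; in the reader above, the `ArrowReader` comes from `scanner.executeFilteredReadPlan(plan)`):

```java
import org.apache.arrow.vector.ipc.ArrowReader;

import java.io.IOException;

// Sum row counts across all batches exposed by an ArrowReader.
static long countRows(ArrowReader reader) throws IOException {
  long total = 0;
  while (reader.loadNextBatch()) {
    total += reader.getVectorSchemaRoot().getRowCount();
  }
  return total;
}
```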
diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java
index b2793aa8..e57cc1ea 100644
--- a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java
+++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java
@@ -96,7 +96,9 @@ public Batch toBatch() {
 
   @Override
   public InputPartition[] planInputPartitions() {
-    List<LanceSplit> splits = LanceSplit.generateLanceSplits(readOptions);
+    List<LanceSplit> splits =
+        LanceSplit.generateLanceSplits(
+            readOptions, schema, whereConditions, limit, offset, topNSortOrders);
     return IntStream.range(0, splits.size())
         .mapToObj(
             i ->
diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceSplit.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceSplit.java
index 8ec4461e..1ba61d33 100644
--- a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceSplit.java
+++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceSplit.java
@@ -14,11 +14,20 @@
 package org.lance.spark.read;
 
 import org.lance.Dataset;
-import org.lance.Fragment;
+import org.lance.ipc.ColumnOrdering;
+import org.lance.ipc.FilteredReadPlan;
+import org.lance.ipc.LanceScanner;
+import org.lance.ipc.ScanOptions;
+import org.lance.ipc.Splits;
 import org.lance.spark.LanceRuntime;
 import org.lance.spark.LanceSparkReadOptions;
+import org.lance.spark.internal.LanceFragmentScanner;
+import org.lance.spark.utils.Optional;
+
+import org.apache.spark.sql.types.StructType;
 
 import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.stream.Collectors;
@@ -27,21 +36,71 @@
 public class LanceSplit implements Serializable {
   private static final long serialVersionUID = 2983749283749283749L;
 
   private final List<Integer> fragments;
+  private final FilteredReadPlan filteredReadPlan;
 
   public LanceSplit(List<Integer> fragments) {
     this.fragments = fragments;
+    this.filteredReadPlan = null;
+  }
+
+  public LanceSplit(FilteredReadPlan filteredReadPlan) {
+    this.fragments = new ArrayList<>(filteredReadPlan.getFragmentRanges().keySet());
+    this.filteredReadPlan = filteredReadPlan;
   }
 
   public List<Integer> getFragments() {
     return fragments;
   }
 
-  public static List<LanceSplit> generateLanceSplits(LanceSparkReadOptions readOptions) {
+  public Optional<FilteredReadPlan> getFilteredReadPlan() {
+    return Optional.ofNullable(filteredReadPlan);
+  }
+
+  public static List<LanceSplit> generateLanceSplits(
+      LanceSparkReadOptions readOptions,
+      StructType schema,
+      Optional<String> whereCondition,
+      Optional<Integer> limit,
+      Optional<Integer> offset,
+      Optional<List<ColumnOrdering>> topNSortOrders) {
     try (Dataset dataset = openDataset(readOptions)) {
-      return dataset.getFragments().stream()
-          .map(Fragment::getId)
-          .map(id -> new LanceSplit(Collections.singletonList(id)))
-          .collect(Collectors.toList());
+      ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
+      scanOptionsBuilder.columns(LanceFragmentScanner.getColumnNames(schema));
+      if (whereCondition.isPresent()) {
+        scanOptionsBuilder.filter(whereCondition.get());
+      }
+      scanOptionsBuilder.batchSize(readOptions.getBatchSize());
+      scanOptionsBuilder.withRowId(LanceFragmentScanner.getWithRowId(schema));
+      scanOptionsBuilder.withRowAddress(LanceFragmentScanner.getWithRowAddress(schema));
+      if (readOptions.getNearest() != null) {
+        scanOptionsBuilder.nearest(readOptions.getNearest());
+      }
+      if (limit.isPresent()) {
+        scanOptionsBuilder.limit(limit.get());
+      }
+      if (offset.isPresent()) {
+        scanOptionsBuilder.offset(offset.get());
+      }
+      if (topNSortOrders.isPresent()) {
+        scanOptionsBuilder.setColumnOrderings(topNSortOrders.get());
+      }
+
+      try (LanceScanner scanner = dataset.newScan(scanOptionsBuilder.build())) {
+        Splits splits = scanner.planSplits(null);
+        if (splits.getFilteredReadPlans().isPresent()) {
+          return splits.getFilteredReadPlans().get().stream()
+              .map(LanceSplit::new)
+              .collect(Collectors.toList());
+        } else if (splits.getFragments().isPresent()) {
+          return splits.getFragments().get().stream()
+              .map(id -> new LanceSplit(Collections.singletonList(id)))
+              .collect(Collectors.toList());
+        } else {
+          return Collections.emptyList();
+        }
+      } catch (Exception e) {
+        throw new RuntimeException("Failed to plan splits", e);
+      }
     }
   }
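Reviewer note: split planning now delegates to the scanner. `planSplits` may return filtered read plans (one `LanceSplit` per plan) or plain fragment ids (one `LanceSplit` per fragment), and partition readers branch on `getFilteredReadPlan().isPresent()`. A hedged sketch of the new entry point, assuming `org.lance.spark.utils.Optional` mirrors `java.util.Optional`'s `of`/`empty` factories (the filter and limit values are illustrative):

```java
// Illustrative planning call; arguments mirror what LanceScan.planInputPartitions passes.
List<LanceSplit> splits =
    LanceSplit.generateLanceSplits(
        readOptions,              // LanceSparkReadOptions for the target dataset
        schema,                   // pruned Spark read schema
        Optional.of("id > 100"),  // pushed-down where condition
        Optional.of(10),          // limit
        Optional.empty(),         // offset
        Optional.empty());        // top-N sort orders
for (LanceSplit split : splits) {
  // Each split becomes one InputPartition; readers pick the execution path per split.
  boolean usesFilteredPlan = split.getFilteredReadPlan().isPresent();
}
```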