Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@
import org.lance.compaction.CompactionTask;
import org.lance.compaction.RewriteResult;
import org.lance.fragment.FragmentMergeResult;
import org.lance.index.Index;
import org.lance.index.IndexType;
import org.lance.operation.Merge;
import org.lance.operation.Update;
import org.lance.schema.LanceField;
import org.lance.schema.LanceSchema;
import org.lance.spark.LanceConfig;
import org.lance.spark.SparkOptions;
import org.lance.spark.read.LanceInputPartition;
Expand All @@ -45,7 +49,10 @@
import org.apache.spark.sql.util.LanceArrowUtils;

import java.time.ZoneId;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

public class LanceDatasetAdapter {
Expand Down Expand Up @@ -293,4 +300,95 @@ public static void dropDataset(LanceConfig config) {
ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
Dataset.drop(uri, options.getStorageOptions());
}

/**
 * Get all indexes defined on the dataset.
 *
 * <p>Thin delegation to {@code Dataset#getIndexes()}; exposed here so Spark-side code
 * interacts with the dataset only through this adapter.
 *
 * @param dataset the opened Dataset
 * @return list of Index objects with full metadata
 */
public static List<Index> getIndexes(Dataset dataset) {
  return dataset.getIndexes();
}

/**
 * Get indexes that cover the specified column.
 *
 * @param dataset the opened Dataset
 * @param columnName the name of the column
 * @return list of indexes covering the column; empty when the column is unknown or uncovered
 */
public static List<Index> getIndexesForColumn(Dataset dataset, String columnName) {
  Integer targetFieldId = findFieldIdByName(dataset.getLanceSchema(), columnName);
  if (targetFieldId == null) {
    // Unknown column name: no field id to match against, so no index can cover it.
    return new ArrayList<>();
  }

  List<Index> covering = new ArrayList<>();
  for (Index candidate : dataset.getIndexes()) {
    if (candidate.fields().contains(targetFieldId)) {
      covering.add(candidate);
    }
  }
  return covering;
}

/**
 * Resolve a column name to its Lance field id.
 *
 * @return the id of the first field whose name matches, or {@code null} if none does
 */
private static Integer findFieldIdByName(LanceSchema schema, String columnName) {
  for (LanceField candidate : schema.fields()) {
    if (candidate.getName().equals(columnName)) {
      return candidate.getId();
    }
  }
  return null;
}

/**
 * Check if an index is an exact index (BTREE or BITMAP). Exact indexes can be used for precise
 * counting without scanning.
 *
 * @param index the index to check
 * @return true if the index is an exact index
 */
public static boolean isExactIndex(Index index) {
  IndexType kind = index.indexType();
  boolean isBtree = kind == IndexType.BTREE;
  boolean isBitmap = kind == IndexType.BITMAP;
  return isBtree || isBitmap;
}

/**
 * Get fragment IDs not covered by an index.
 *
 * <p>If the index does not report its covered fragments (empty Optional), nothing is removed,
 * so every fragment is treated as unindexed.
 *
 * @param dataset the opened Dataset
 * @param index the index to check coverage for
 * @return list of fragment IDs not covered by the index
 */
public static List<Integer> getUnindexedFragments(Dataset dataset, Index index) {
  Set<Integer> remaining = new HashSet<>(getFragmentIds(dataset));
  index.fragments().ifPresent(covered -> remaining.removeAll(covered));
  return new ArrayList<>(remaining);
}

/**
 * Get fragment IDs covered by an index.
 *
 * @param index the index
 * @return list of fragment IDs covered by the index, or an empty list if unknown
 */
public static List<Integer> getIndexedFragments(Index index) {
  // orElseGet defers the fallback allocation to the absent case only; the original
  // orElse(new ArrayList<>()) built a throwaway list even when fragments are present.
  return index.fragments().orElseGet(ArrayList::new);
}

/**
 * Count rows matching a filter using a specific scalar index, without scanning data files.
 *
 * @param dataset the opened Dataset
 * @param indexName the name of the index to query
 * @param filter the filter expression; passed through verbatim to the dataset
 * @param fragmentIds optional list of fragment IDs to restrict counting to
 * @return the count of matching rows
 */
public static long countIndexedRows(
    Dataset dataset, String indexName, String filter, Optional<List<Integer>> fragmentIds) {
  // Uses the imported java.util.Optional for consistency with the file's import style
  // instead of a fully-qualified name in the signature.
  return dataset.countIndexedRows(indexName, filter, fragmentIds);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class FilterPushDown {
Expand Down Expand Up @@ -192,4 +194,49 @@ private static String compileValue(Object value) {
return value.toString();
}
}

/**
 * Extract column names from filters. Returns columns that are directly compared (EqualTo,
 * GreaterThan, etc.).
 *
 * @param filters the filters to extract columns from
 * @return set of column names used in the filters
 */
public static Set<String> extractFilterColumns(Filter[] filters) {
  Set<String> referenced = new HashSet<>();
  Arrays.stream(filters).forEach(filter -> extractColumnsFromFilter(filter, referenced));
  return referenced;
}

/**
 * Recursively collect every column attribute referenced by {@code filter} into {@code columns}.
 *
 * <p>Handles the simple comparison filters (EqualTo, GreaterThan[OrEqual], LessThan[OrEqual],
 * In, IsNull, IsNotNull) and recurses through And/Or/Not composition. Any other filter type is
 * silently ignored.
 *
 * <p>NOTE(review): EqualNullSafe and the String* filters (StartsWith/EndsWith/Contains) are not
 * handled, so columns referenced only by those filters are never reported — confirm this
 * omission is intentional, since optimizations keyed off this set won't trigger for them.
 */
private static void extractColumnsFromFilter(Filter filter, Set<String> columns) {
  if (filter instanceof EqualTo) {
    columns.add(((EqualTo) filter).attribute());
  } else if (filter instanceof GreaterThan) {
    columns.add(((GreaterThan) filter).attribute());
  } else if (filter instanceof GreaterThanOrEqual) {
    columns.add(((GreaterThanOrEqual) filter).attribute());
  } else if (filter instanceof LessThan) {
    columns.add(((LessThan) filter).attribute());
  } else if (filter instanceof LessThanOrEqual) {
    columns.add(((LessThanOrEqual) filter).attribute());
  } else if (filter instanceof In) {
    columns.add(((In) filter).attribute());
  } else if (filter instanceof IsNull) {
    columns.add(((IsNull) filter).attribute());
  } else if (filter instanceof IsNotNull) {
    columns.add(((IsNotNull) filter).attribute());
  } else if (filter instanceof And) {
    // Composite filters: both sides may reference columns.
    And f = (And) filter;
    extractColumnsFromFilter(f.left(), columns);
    extractColumnsFromFilter(f.right(), columns);
  } else if (filter instanceof Or) {
    Or f = (Or) filter;
    extractColumnsFromFilter(f.left(), columns);
    extractColumnsFromFilter(f.right(), columns);
  } else if (filter instanceof Not) {
    extractColumnsFromFilter(((Not) filter).child(), columns);
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.lance.spark.read;

import org.lance.Dataset;
import org.lance.ReadOptions;
import org.lance.spark.SparkOptions;
import org.lance.spark.internal.LanceDatasetAdapter;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.spark.sql.connector.read.PartitionReader;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.LanceArrowUtils;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.vectorized.LanceArrowColumnVector;

import java.io.IOException;
import java.util.Optional;

/**
 * Partition reader that counts rows using direct index queries. All indexed fragments are processed
 * in a single reader using the scalar index.
 *
 * <p>Emits exactly one single-row batch with one long column named "count": {@link #next()}
 * returns true once, after which {@link #get()} runs the index count and wraps the result in an
 * Arrow-backed {@link ColumnarBatch}.
 */
public class LanceIndexedCountPartitionReader implements PartitionReader<ColumnarBatch> {
  private final LanceSplitCountScan.LanceIndexedCountPartition partition;
  private final BufferAllocator allocator;
  private boolean finished = false;
  private ColumnarBatch currentBatch;

  public LanceIndexedCountPartitionReader(
      LanceSplitCountScan.LanceIndexedCountPartition partition) {
    this.partition = partition;
    // Shared adapter-level allocator: this reader does not own it and must not close it.
    this.allocator = LanceDatasetAdapter.allocator;
  }

  /** Returns true exactly once; this reader produces a single count batch. */
  @Override
  public boolean next() throws IOException {
    if (!finished) {
      finished = true;
      return true;
    }
    return false;
  }

  /**
   * Opens the dataset and counts matching rows via the scalar index, restricted to this
   * partition's fragment list.
   */
  private long computeIndexedCount() {
    String uri = partition.getConfig().getDatasetUri();
    ReadOptions options = SparkOptions.genReadOptionFromConfig(partition.getConfig());

    try (Dataset dataset = Dataset.open(allocator, uri, options)) {
      // Absent where-condition is passed as "" — assumes countIndexedRows treats the empty
      // string as "no extra predicate"; TODO(review) confirm against the Dataset API.
      String filter = partition.getWhereCondition().orElse("");
      Optional<java.util.List<Integer>> fragmentIds = Optional.of(partition.getFragmentIds());

      return LanceDatasetAdapter.countIndexedRows(
          dataset, partition.getIndexName(), filter, fragmentIds);
    }
  }

  /** Builds a one-row Arrow-backed batch holding {@code count} in a "count" long column. */
  private ColumnarBatch createCountResultBatch(long count) {
    StructType countSchema = new StructType().add("count", DataTypes.LongType);
    VectorSchemaRoot root =
        VectorSchemaRoot.create(
            LanceArrowUtils.toArrowSchema(countSchema, "UTC", false, false), allocator);

    root.allocateNew();
    BigIntVector countVector = (BigIntVector) root.getVector("count");
    countVector.setSafe(0, count);
    root.setRowCount(1);

    // Closing the returned batch closes these column vectors, which releases the Arrow buffers.
    LanceArrowColumnVector[] columns =
        root.getFieldVectors().stream()
            .map(LanceArrowColumnVector::new)
            .toArray(LanceArrowColumnVector[]::new);

    return new ColumnarBatch(columns, 1);
  }

  @Override
  public ColumnarBatch get() {
    long count = computeIndexedCount();
    currentBatch = createCountResultBatch(count);
    return currentBatch;
  }

  @Override
  public void close() throws IOException {
    if (currentBatch != null) {
      currentBatch.close();
      // Null out after closing so a repeated close() cannot release the Arrow buffers twice.
      currentBatch = null;
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package org.lance.spark.read;

import org.lance.Dataset;
import org.lance.index.Index;
import org.lance.ipc.ColumnOrdering;
import org.lance.spark.LanceConfig;
import org.lance.spark.SparkOptions;
Expand Down Expand Up @@ -43,7 +44,10 @@
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class LanceScanBuilder
implements SupportsPushDownRequiredColumns,
Expand All @@ -61,6 +65,7 @@ public class LanceScanBuilder
private Optional<List<ColumnOrdering>> topNSortOrders = Optional.empty();
private Optional<Aggregation> pushedAggregation = Optional.empty();
private LanceLocalScan localScan = null;
private LanceSplitCountScan splitCountScan = null;

// Lazily opened dataset for reuse during scan building
private Dataset lazyDataset = null;
Expand Down Expand Up @@ -98,6 +103,10 @@ public Scan build() {
if (localScan != null) {
return localScan;
}
// Return SplitCountScan if we have partially indexed fragments
if (splitCountScan != null) {
return splitCountScan;
}
Optional<String> whereCondition = FilterPushDown.compileFiltersToSqlWhereClause(pushedFilters);
return new LanceScan(
schema, config, whereCondition, limit, offset, topNSortOrders, pushedAggregation);
Expand Down Expand Up @@ -204,7 +213,58 @@ public boolean pushAggregation(Aggregation aggregation) {
this.localScan = new LanceLocalScan(countSchema, rows, config.getDatasetUri());
return true;
}
} else {
// Check for indexed column optimization when filters are present
Set<String> filterColumns = FilterPushDown.extractFilterColumns(pushedFilters);
Dataset dataset = getOrOpenDataset();
List<Integer> allFragments = LanceDatasetAdapter.getFragmentIds(dataset);

for (String column : filterColumns) {
List<Index> indexes = LanceDatasetAdapter.getIndexesForColumn(dataset, column);
for (Index index : indexes) {
if (LanceDatasetAdapter.isExactIndex(index)) {
List<Integer> indexedFragments = LanceDatasetAdapter.getIndexedFragments(index);

if (!indexedFragments.isEmpty()) {
Set<Integer> indexedSet = new HashSet<>(indexedFragments);
List<Integer> unindexedFragments =
allFragments.stream()
.filter(f -> !indexedSet.contains(f))
.collect(Collectors.toList());

Optional<String> whereCondition =
FilterPushDown.compileFiltersToSqlWhereClause(pushedFilters);

if (unindexedFragments.isEmpty()) {
// All fragments indexed - use local scan with index count
long count =
LanceDatasetAdapter.countIndexedRows(
dataset,
index.name(),
whereCondition.orElse(""),
java.util.Optional.empty());
StructType countSchema = new StructType().add("count", DataTypes.LongType);
InternalRow[] rows = new InternalRow[1];
rows[0] = new GenericInternalRow(new Object[] {count});
this.localScan = new LanceLocalScan(countSchema, rows, config.getDatasetUri());
return true;
} else {
// Mixed - use split scan
this.splitCountScan =
new LanceSplitCountScan(
new ArrayList<>(indexedSet),
unindexedFragments,
index.name(),
whereCondition,
config);
return true;
}
}
}
}
}
}

// Fall back to scan-based count (with filters or metadata unavailable)
this.pushedAggregation = Optional.of(aggregation);
return true;
Expand Down
Loading
Loading