diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java
index 4124c0193435..d4cf422c71e2 100644
--- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java
+++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java
@@ -50,9 +50,12 @@
 import io.trino.spi.connector.SaveMode;
 import io.trino.spi.connector.SchemaTableName;
 import io.trino.spi.connector.SchemaTablePrefix;
+import io.trino.spi.connector.SortItem;
+import io.trino.spi.connector.SortOrder;
 import io.trino.spi.connector.SortingProperty;
 import io.trino.spi.connector.TableFunctionApplicationResult;
 import io.trino.spi.connector.TableNotFoundException;
+import io.trino.spi.connector.TopNApplicationResult;
 import io.trino.spi.expression.ConnectorExpression;
 import io.trino.spi.expression.FieldDereference;
 import io.trino.spi.expression.Variable;
@@ -72,6 +75,8 @@
 import org.bson.Document;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.Iterator;
@@ -132,6 +137,10 @@ public class MongoMetadata
 
     private static final int MAX_QUALIFIED_IDENTIFIER_BYTE_LENGTH = 120;
 
+    // Sort directions as written into the BSON sort document
+    public static final int MONGO_SORT_ASC = 1;
+    private static final int MONGO_SORT_DESC = -1;
+
     private final MongoSession mongoSession;
     private final AtomicReference<Runnable> rollbackAction = new AtomicReference<>();
 
@@ -622,11 +631,84 @@ public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(Connect
                         handle.filter(),
                         handle.constraint(),
                         handle.projectedColumns(),
+                        handle.sort(),
                         OptionalInt.of(toIntExact(limit))),
                 true,
                 false));
     }
 
+    /**
+     * Pushes an {@code ORDER BY ... LIMIT n} into MongoDB by appending a sort stage to the table handle.
+     *
+     * @return the updated handle, or empty when the count exceeds {@code Integer.MAX_VALUE}, a LIMIT
+     * was already pushed down, or an assignment name differs from its column's base name
+     */
+    @Override
+    public Optional<TopNApplicationResult<ConnectorTableHandle>> applyTopN(
+            ConnectorSession session,
+            ConnectorTableHandle table,
+            long topNCount,
+            List<SortItem> sortItems,
+            Map<String, ColumnHandle> assignments)
+    {
+        MongoTableHandle handle = (MongoTableHandle) table;
+
+        // MongoDB doesn't support topN number greater than integer max. A LIMIT that is already
+        // pushed down has to run before this sort, which the handle cannot express, so bail out too.
+        if (topNCount > Integer.MAX_VALUE || handle.limit().isPresent()) {
+            return Optional.empty();
+        }
+
+        for (Map.Entry<String, ColumnHandle> columnHandleEntry : assignments.entrySet()) {
+            MongoColumnHandle columnHandle = (MongoColumnHandle) columnHandleEntry.getValue();
+            if (!columnHandleEntry.getKey().equals(columnHandle.baseName())) {
+                // We don't support complex nested queries
+                return Optional.empty();
+            }
+        }
+
+        // Convert Trino sort items to a BSON sort document for MongoDB.
+        Document sortDocument = new Document();
+        Document sortNullFieldsDocument = null;
+        for (SortItem sortItem : sortItems) {
+            String columnName = sortItem.getName();
+            int direction = (sortItem.getSortOrder() == SortOrder.ASC_NULLS_FIRST || sortItem.getSortOrder() == SortOrder.ASC_NULLS_LAST) ? MONGO_SORT_ASC : MONGO_SORT_DESC;
+
+            // MongoDB considers null values to be less than any other value.
+            // When we have sort items with SortOrder.ASC_NULLS_LAST or SortOrder.DESC_NULLS_FIRST,
+            // we need to add computed fields to sort correctly.
+            if (sortItem.getSortOrder() == SortOrder.ASC_NULLS_LAST || sortItem.getSortOrder() == SortOrder.DESC_NULLS_FIRST) {
+                String sortColumnName = "_sortNulls_" + columnName;
+                // $ifNull maps a missing field to null so it is flagged the same way as an explicit null
+                Document condition = new Document();
+                condition.append("$cond", ImmutableList.of(new Document("$eq", Arrays.asList(new Document("$ifNull", Arrays.asList("$" + columnName, null)), null)), 1, 0));
+                if (sortNullFieldsDocument == null) {
+                    sortNullFieldsDocument = new Document();
+                }
+                sortNullFieldsDocument.append(sortColumnName, condition);
+                sortDocument.append(sortColumnName, direction);
+            }
+
+            sortDocument.append(columnName, direction);
+        }
+        // Copy the existing sort stages: the list owned by the incoming handle must not be modified
+        List<MongoTableSort> tableSortList = new ArrayList<>(handle.sort().orElse(ImmutableList.of()));
+        MongoTableSort tableSort = new MongoTableSort(sortDocument, Optional.ofNullable(sortNullFieldsDocument), toIntExact(topNCount));
+        tableSortList.add(tableSort);
+
+        return Optional.of(new TopNApplicationResult<>(
+                new MongoTableHandle(
+                        handle.schemaTableName(),
+                        handle.remoteTableName(),
+                        handle.filter(),
+                        handle.constraint(),
+                        handle.projectedColumns(),
+                        Optional.of(tableSortList),
+                        OptionalInt.empty()),
+                true,
+                false));
+    }
+
     @Override
     public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint)
     {
@@ -670,6 +752,7 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
                 handle.filter(),
                 newDomain,
                 handle.projectedColumns(),
+                handle.sort(),
                 handle.limit());
 
         return Optional.of(new ConstraintApplicationResult<>(handle, remainingFilter, constraint.getExpression(), false));
diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java
index a5ccbe0d3581..f017c3b1d909 100644
--- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java
+++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java
@@ -28,6 +28,7 @@
 import com.mongodb.DBRef;
 import com.mongodb.MongoCommandException;
 import com.mongodb.MongoNamespace;
+import com.mongodb.client.AggregateIterable;
 import com.mongodb.client.FindIterable;
 import com.mongodb.client.MongoClient;
 import com.mongodb.client.MongoCollection;
@@ -533,15 +534,61 @@ public MongoCursor<Document> execute(MongoTableHandle tableHandle, List<MongoCol
 
         MongoCollection<Document> collection = getCollection(tableHandle.remoteTableName());
         Document filter = buildFilter(tableHandle);
-        FindIterable<Document> iterable = collection.find(filter).projection(projection).collation(SIMPLE_COLLATION);
-        tableHandle.limit().ifPresent(iterable::limit);
+
         log.debug("Find documents: collection: %s, filter: %s, projection: %s", tableHandle.schemaTableName(), filter, projection);
 
-        if (cursorBatchSize != 0) {
-            iterable.batchSize(cursorBatchSize);
+        // find() is enough when there is no sort, or a single sort stage without computed null-ordering fields
+        boolean useFind = tableHandle.sort().isEmpty() ||
+                (tableHandle.sort().get().size() == 1 && tableHandle.sort().get().getFirst().sortNullFields().isEmpty());
+
+        if (useFind) {
+            FindIterable<Document> iterable = collection.find(filter).projection(projection).collation(SIMPLE_COLLATION);
+            tableHandle.sort().ifPresent(sortList -> {
+                iterable.sort(sortList.getFirst().sort());
+                iterable.limit(sortList.getFirst().limit());
+            });
+            // FindIterable.limit replaces any earlier value, so a pushed-down LIMIT must not raise the TopN count
+            tableHandle.limit().ifPresent(limit -> iterable.limit(
+                    tableHandle.sort().map(sortList -> Math.min(limit, sortList.getFirst().limit())).orElse(limit)));
+
+            if (cursorBatchSize != 0) {
+                iterable.batchSize(cursorBatchSize);
+            }
+
+            return iterable.iterator();
         }
+        else {
+            // MongoDB considers null values to be less than any other value.
+            // We can handle sort items with SortOrder.ASC_NULLS_LAST or SortOrder.DESC_NULLS_FIRST
+            // with an aggregation pipeline.
+            List<MongoTableSort> tableSortList = tableHandle.sort().get();
+            for (MongoTableSort tableSort : tableSortList) {
+                for (String sortField : tableSort.sort().keySet()) {
+                    // Sorting on the field does not work unless we add it to the projection
+                    if (!projection.containsKey(sortField)) {
+                        projection.append(sortField, MongoMetadata.MONGO_SORT_ASC);
+                    }
+                }
+            }
+
+            List<Document> aggregateList = new ArrayList<>();
+            aggregateList.add(new Document("$match", filter));
+            for (MongoTableSort tableSort : tableSortList) {
+                // The computed null-indicator fields must be added before the $sort stage that references them
+                tableSort.sortNullFields().ifPresent(sortNullFields -> aggregateList.add(new Document("$addFields", sortNullFields)));
+                aggregateList.add(new Document("$sort", tableSort.sort()));
+                aggregateList.add(new Document("$limit", tableSort.limit()));
+            }
+            tableHandle.limit().ifPresent(limit -> aggregateList.add(new Document("$limit", limit)));
+            aggregateList.add(new Document("$project", projection));
+            AggregateIterable<Document> aggregateIterable = collection.aggregate(aggregateList).collation(SIMPLE_COLLATION);
 
-        return iterable.iterator();
+            if (cursorBatchSize != 0) {
+                aggregateIterable.batchSize(cursorBatchSize);
+            }
+
+            return aggregateIterable.iterator();
+        }
     }
 
     @VisibleForTesting
diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoTableHandle.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoTableHandle.java
index 5319ec26393f..79b6a9a18f34 100644
--- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoTableHandle.java
+++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoTableHandle.java
@@ -19,6 +19,7 @@
 import io.trino.spi.connector.SchemaTableName;
 import io.trino.spi.predicate.TupleDomain;
 
+import java.util.List;
 import java.util.Optional;
 import java.util.OptionalInt;
 import java.util.Set;
@@ -32,12 +33,13 @@ public record MongoTableHandle(
         Optional<String> filter,
         TupleDomain<ColumnHandle> constraint,
         Set<MongoColumnHandle> projectedColumns,
+        Optional<List<MongoTableSort>> sort,
         OptionalInt limit)
         implements ConnectorTableHandle
 {
     public MongoTableHandle(SchemaTableName schemaTableName, RemoteTableName remoteTableName, Optional<String> filter)
     {
-        this(schemaTableName, remoteTableName, filter, TupleDomain.all(), ImmutableSet.of(), OptionalInt.empty());
+        this(schemaTableName, remoteTableName, filter, TupleDomain.all(), ImmutableSet.of(), Optional.empty(), OptionalInt.empty());
     }
 
     public MongoTableHandle
@@ -47,6 +49,7 @@ public MongoTableHandle(SchemaTableName schemaTableName, RemoteTableName remoteT
         requireNonNull(filter, "filter is null");
         requireNonNull(constraint, "constraint is null");
         projectedColumns = ImmutableSet.copyOf(requireNonNull(projectedColumns, "projectedColumns is null"));
+        sort = requireNonNull(sort, "sort is null").map(List::copyOf);
         requireNonNull(limit, "limit is null");
     }
 
@@ -68,6 +71,7 @@ else if (!constraint.isAll()) {
         if (!projectedColumns.isEmpty()) {
             builder.append(" columns=").append(projectedColumns);
         }
+        sort.ifPresent(value -> builder.append(" TopNPartial").append(value));
         limit.ifPresent(value -> builder.append(" limit=").append(value));
         return builder.toString();
     }
@@ -80,6 +84,7 @@ public MongoTableHandle withProjectedColumns(Set<MongoColumnHandle> projectedCol
                 filter,
                 constraint,
                 projectedColumns,
+                sort,
                 limit);
     }
 
@@ -91,6 +96,7 @@ public MongoTableHandle withConstraint(TupleDomain<ColumnHandle> constraint)
                 filter,
                 constraint,
                 projectedColumns,
+                sort,
                 limit);
     }
 }
diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoTableSort.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoTableSort.java
new file mode 100644
index 000000000000..b9bc2d9f533c
--- /dev/null
+++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoTableSort.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.mongodb;
+
+import com.google.common.collect.ImmutableList;
+import org.bson.Document;
+
+import java.util.List;
+import java.util.Optional;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * One pushed-down TopN stage: a MongoDB sort specification, optional computed fields
+ * that the sort depends on, and the number of rows to keep.
+ *
+ * @param sort sort specification document, field name to direction (1 ascending, -1 descending)
+ * @param sortNullFields computed fields to add before sorting, when null ordering requires them
+ * @param limit number of rows to keep after sorting
+ */
+public record MongoTableSort(Document sort, Optional<Document> sortNullFields, int limit)
+{
+    public MongoTableSort
+    {
+        requireNonNull(sort, "sort is null");
+        requireNonNull(sortNullFields, "sortNullFields is null");
+    }
+
+    @Override
+    public String toString()
+    {
+        StringBuilder builder = new StringBuilder();
+        builder.append("count = ").append(limit);
+        List<String> sortEntries = sort.entrySet().stream()
+                // Directions are stored as numbers, so compare numerically rather than against the string "1"
+                .map(entry -> entry.getKey() + " " + (entry.getValue() instanceof Number direction && direction.intValue() == 1 ? "ASC" : "DESC"))
+                .toList();
+        builder.append(", orderBy = ").append(sortEntries);
+        if (sortNullFields.isPresent()) {
+            List<String> nullFields = ImmutableList.copyOf(sortNullFields.get().keySet());
+            builder.append(", nullFields=").append(nullFields);
+        }
+        return builder.toString();
+    }
+}
diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java
index f97bf8691710..1acaa0c5d985 100644
--- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java
+++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java
@@ -112,7 +112,6 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
                  SUPPORTS_RENAME_FIELD,
                  SUPPORTS_RENAME_SCHEMA,
                  SUPPORTS_SET_FIELD_TYPE,
-                 SUPPORTS_TOPN_PUSHDOWN,
                  SUPPORTS_TRUNCATE,
                  SUPPORTS_UPDATE -> false;
             default -> super.hasBehavior(connectorBehavior);
diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoTableHandle.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoTableHandle.java
index 48ee627dc3ce..4e3956bcc673 100644
--- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoTableHandle.java
+++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoTableHandle.java
@@ -125,6 +125,7 @@ public void testRoundTripWithProjectedColumns()
                 Optional.empty(),
                 TupleDomain.all(),
                 projectedColumns,
+                Optional.empty(),
                 OptionalInt.empty());
 
         String json = codec.toJson(expected);