Skip to content

Commit c69402e

Browse files
committed
Add Top-N pushdown support for the MongoDB connector
Add Top-N (ORDER BY + LIMIT) pushdown support for the MongoDB connector. MongoDB considers null values to be less than any other value. We can handle sort items with SortOrder.ASC_NULLS_LAST or SortOrder.DESC_NULLS_FIRST with a MongoDB aggregation pipeline, where we add computed fields to sort correctly. We add sort fields to the projection, so that they are considered by MongoDB. We add Top-N pushdown support only for simple nested queries (that work on the same collection).
1 parent e0aaffb commit c69402e

File tree

5 files changed

+176
-6
lines changed

5 files changed

+176
-6
lines changed

plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,12 @@
5050
import io.trino.spi.connector.SaveMode;
5151
import io.trino.spi.connector.SchemaTableName;
5252
import io.trino.spi.connector.SchemaTablePrefix;
53+
import io.trino.spi.connector.SortItem;
54+
import io.trino.spi.connector.SortOrder;
5355
import io.trino.spi.connector.SortingProperty;
5456
import io.trino.spi.connector.TableFunctionApplicationResult;
5557
import io.trino.spi.connector.TableNotFoundException;
58+
import io.trino.spi.connector.TopNApplicationResult;
5659
import io.trino.spi.expression.ConnectorExpression;
5760
import io.trino.spi.expression.FieldDereference;
5861
import io.trino.spi.expression.Variable;
@@ -72,6 +75,8 @@
7275
import org.bson.Document;
7376

7477
import java.io.IOException;
78+
import java.util.ArrayList;
79+
import java.util.Arrays;
7580
import java.util.Collection;
7681
import java.util.HashMap;
7782
import java.util.Iterator;
@@ -132,6 +137,9 @@ public class MongoMetadata
132137

133138
private static final int MAX_QUALIFIED_IDENTIFIER_BYTE_LENGTH = 120;
134139

140+
public static final int MONGO_SORT_ASC = 1;
141+
private static final int MONGO_SORT_DESC = -1;
142+
135143
private final MongoSession mongoSession;
136144

137145
private final AtomicReference<Runnable> rollbackAction = new AtomicReference<>();
@@ -622,11 +630,75 @@ public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(Connect
622630
handle.filter(),
623631
handle.constraint(),
624632
handle.projectedColumns(),
633+
handle.sort(),
625634
OptionalInt.of(toIntExact(limit))),
626635
true,
627636
false));
628637
}
629638

639+
@Override
640+
public Optional<TopNApplicationResult<ConnectorTableHandle>> applyTopN(
641+
ConnectorSession session,
642+
ConnectorTableHandle table,
643+
long topNCount,
644+
List<SortItem> sortItems,
645+
Map<String, ColumnHandle> assignments)
646+
{
647+
MongoTableHandle handle = (MongoTableHandle) table;
648+
649+
// MongoDB doesn't support topN number greater than integer max
650+
if (topNCount > Integer.MAX_VALUE) {
651+
return Optional.empty();
652+
}
653+
654+
for (Map.Entry<String, ColumnHandle> columnHandleEntry : assignments.entrySet()) {
655+
MongoColumnHandle columnHandle = (MongoColumnHandle) columnHandleEntry.getValue();
656+
if (!columnHandleEntry.getKey().equals(columnHandle.baseName())) {
657+
// We don't support complex nested queries
658+
return Optional.empty();
659+
}
660+
}
661+
662+
// Convert Trino sort items to a BSON sort document for MongoDB.
663+
Document sortDocument = new Document();
664+
Document sortNullFieldsDocument = null;
665+
for (SortItem sortItem : sortItems) {
666+
String columnName = sortItem.getName();
667+
int direction = (sortItem.getSortOrder() == SortOrder.ASC_NULLS_FIRST || sortItem.getSortOrder() == SortOrder.ASC_NULLS_LAST) ? MONGO_SORT_ASC : MONGO_SORT_DESC;
668+
669+
// MongoDB considers null values to be less than any other value.
670+
// When we have sort items with SortOrder.ASC_NULLS_LAST or SortOrder.DESC_NULLS_FIRST,
671+
// we need to add computed fields to sort correctly.
672+
if (sortItem.getSortOrder() == SortOrder.ASC_NULLS_LAST || sortItem.getSortOrder() == SortOrder.DESC_NULLS_FIRST) {
673+
String sortColumnName = "_sortNulls_" + columnName;
674+
Document condition = new Document();
675+
condition.append("$cond", ImmutableList.of(new Document("$eq", Arrays.asList("$" + columnName, null)), 1, 0));
676+
if (sortNullFieldsDocument == null) {
677+
sortNullFieldsDocument = new Document();
678+
}
679+
sortNullFieldsDocument.append(sortColumnName, condition);
680+
sortDocument.append(sortColumnName, direction);
681+
}
682+
683+
sortDocument.append(columnName, direction);
684+
}
685+
List<MongoTableSort> tableSortList = handle.sort().orElse(new ArrayList<>());
686+
MongoTableSort tableSort = new MongoTableSort(sortDocument, Optional.ofNullable(sortNullFieldsDocument), toIntExact(topNCount));
687+
tableSortList.add(tableSort);
688+
689+
return Optional.of(new TopNApplicationResult<>(
690+
new MongoTableHandle(
691+
handle.schemaTableName(),
692+
handle.remoteTableName(),
693+
handle.filter(),
694+
handle.constraint(),
695+
handle.projectedColumns(),
696+
Optional.of(tableSortList),
697+
OptionalInt.empty()),
698+
true,
699+
false));
700+
}
701+
630702
@Override
631703
public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint)
632704
{
@@ -670,6 +742,7 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
670742
handle.filter(),
671743
newDomain,
672744
handle.projectedColumns(),
745+
handle.sort(),
673746
handle.limit());
674747

675748
return Optional.of(new ConstraintApplicationResult<>(handle, remainingFilter, constraint.getExpression(), false));

plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import com.mongodb.DBRef;
2828
import com.mongodb.MongoCommandException;
2929
import com.mongodb.MongoNamespace;
30+
import com.mongodb.client.AggregateIterable;
3031
import com.mongodb.client.FindIterable;
3132
import com.mongodb.client.MongoClient;
3233
import com.mongodb.client.MongoCollection;
@@ -528,15 +529,57 @@ public MongoCursor<Document> execute(MongoTableHandle tableHandle, List<MongoCol
528529

529530
MongoCollection<Document> collection = getCollection(tableHandle.remoteTableName());
530531
Document filter = buildFilter(tableHandle);
531-
FindIterable<Document> iterable = collection.find(filter).projection(projection).collation(SIMPLE_COLLATION);
532-
tableHandle.limit().ifPresent(iterable::limit);
532+
533533
log.debug("Find documents: collection: %s, filter: %s, projection: %s", tableHandle.schemaTableName(), filter, projection);
534534

535-
if (cursorBatchSize != 0) {
536-
iterable.batchSize(cursorBatchSize);
535+
boolean useFind = tableHandle.sort().isEmpty() ||
536+
(tableHandle.sort().get().size() == 1 && tableHandle.sort().get().getFirst().sortNullFields().isEmpty());
537+
538+
if (useFind) {
539+
FindIterable<Document> iterable = collection.find(filter).projection(projection).collation(SIMPLE_COLLATION);
540+
tableHandle.sort().ifPresent(sortList -> {
541+
iterable.sort(sortList.getFirst().sort());
542+
iterable.limit(toIntExact(sortList.getFirst().limit()));
543+
});
544+
tableHandle.limit().ifPresent(iterable::limit);
545+
546+
if (cursorBatchSize != 0) {
547+
iterable.batchSize(cursorBatchSize);
548+
}
549+
550+
return iterable.iterator();
537551
}
552+
else {
553+
// MongoDB considers null values to be less than any other value.
554+
// We can handle sort items with SortOrder.ASC_NULLS_LAST or SortOrder.DESC_NULLS_FIRST
555+
// with an aggregation pipeline.
556+
List<MongoTableSort> tableSortList = tableHandle.sort().get();
557+
for (MongoTableSort tableSort : tableSortList) {
558+
for (String sortField : tableSort.sort().keySet()) {
559+
// Sorting on the field does not work unless we add it to the projection
560+
if (!projection.containsKey(sortField)) {
561+
projection.append(sortField, MongoMetadata.MONGO_SORT_ASC);
562+
}
563+
}
564+
}
565+
566+
List<Document> aggregateList = new ArrayList<>();
567+
aggregateList.add(new Document("$match", filter));
568+
for (MongoTableSort tableSort : tableSortList) {
569+
// Actually append the $addFields stage: the computed _sortNulls_* keys must exist before the $sort stage that references them
tableSort.sortNullFields().ifPresent(sortNullFields -> aggregateList.add(new Document("$addFields", sortNullFields)));
570+
aggregateList.add(new Document("$sort", tableSort.sort()));
571+
aggregateList.add(new Document("$limit", tableSort.limit()));
572+
}
573+
tableHandle.limit().ifPresent(limit -> aggregateList.add(new Document("$limit", limit)));
574+
aggregateList.add(new Document("$project", projection));
575+
AggregateIterable<Document> aggregateIterable = collection.aggregate(aggregateList).collation(SIMPLE_COLLATION);
538576

539-
return iterable.iterator();
577+
if (cursorBatchSize != 0) {
578+
aggregateIterable.batchSize(cursorBatchSize);
579+
}
580+
581+
return aggregateIterable.iterator();
582+
}
540583
}
541584

542585
@VisibleForTesting

plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoTableHandle.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import io.trino.spi.connector.SchemaTableName;
2020
import io.trino.spi.predicate.TupleDomain;
2121

22+
import java.util.List;
2223
import java.util.Optional;
2324
import java.util.OptionalInt;
2425
import java.util.Set;
@@ -32,12 +33,13 @@ public record MongoTableHandle(
3233
Optional<String> filter,
3334
TupleDomain<ColumnHandle> constraint,
3435
Set<MongoColumnHandle> projectedColumns,
36+
Optional<List<MongoTableSort>> sort,
3537
OptionalInt limit)
3638
implements ConnectorTableHandle
3739
{
3840
public MongoTableHandle(SchemaTableName schemaTableName, RemoteTableName remoteTableName, Optional<String> filter)
3941
{
40-
this(schemaTableName, remoteTableName, filter, TupleDomain.all(), ImmutableSet.of(), OptionalInt.empty());
42+
this(schemaTableName, remoteTableName, filter, TupleDomain.all(), ImmutableSet.of(), Optional.empty(), OptionalInt.empty());
4143
}
4244

4345
public MongoTableHandle
@@ -47,6 +49,7 @@ public MongoTableHandle(SchemaTableName schemaTableName, RemoteTableName remoteT
4749
requireNonNull(filter, "filter is null");
4850
requireNonNull(constraint, "constraint is null");
4951
projectedColumns = ImmutableSet.copyOf(requireNonNull(projectedColumns, "projectedColumns is null"));
52+
requireNonNull(sort, "sort is null");
5053
requireNonNull(limit, "limit is null");
5154
}
5255

@@ -68,6 +71,7 @@ else if (!constraint.isAll()) {
6871
if (!projectedColumns.isEmpty()) {
6972
builder.append(" columns=").append(projectedColumns);
7073
}
74+
sort.ifPresent(value -> builder.append(" TopNPartial").append(value));
7175
limit.ifPresent(value -> builder.append(" limit=").append(value));
7276
return builder.toString();
7377
}
@@ -80,6 +84,7 @@ public MongoTableHandle withProjectedColumns(Set<MongoColumnHandle> projectedCol
8084
filter,
8185
constraint,
8286
projectedColumns,
87+
sort,
8388
limit);
8489
}
8590

@@ -91,6 +96,7 @@ public MongoTableHandle withConstraint(TupleDomain<ColumnHandle> constraint)
9196
filter,
9297
constraint,
9398
projectedColumns,
99+
sort,
94100
limit);
95101
}
96102
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.plugin.mongodb;
15+
16+
import com.google.common.collect.ImmutableList;
17+
import org.bson.Document;
18+
19+
import java.util.List;
20+
import java.util.Optional;
21+
22+
import static java.util.Objects.requireNonNull;
23+
24+
public record MongoTableSort(Document sort, Optional<Document> sortNullFields, int limit)
25+
{
26+
public MongoTableSort
27+
{
28+
requireNonNull(sort, "sort is null");
29+
requireNonNull(sortNullFields, "sortNullFields is null");
30+
}
31+
32+
@Override
33+
public String toString()
34+
{
35+
StringBuilder builder = new StringBuilder();
36+
builder.append("count = ").append(limit);
37+
List<String> sortEntries = sort.entrySet().stream()
38+
.map(entry -> entry.getKey() + " " + ("1".equals(entry.getValue()) ? "ASC" : "DESC"))
39+
.toList();
40+
builder.append(", orderBy = ").append(sortEntries);
41+
if (sortNullFields.isPresent()) {
42+
List<String> nullFields = ImmutableList.copyOf(sortNullFields.get().keySet());
43+
builder.append(", nullFields=").append(nullFields);
44+
}
45+
return builder.toString();
46+
}
47+
}

plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoTableHandle.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ public void testRoundTripWithProjectedColumns()
125125
Optional.empty(),
126126
TupleDomain.all(),
127127
projectedColumns,
128+
Optional.empty(),
128129
OptionalInt.empty());
129130

130131
String json = codec.toJson(expected);

0 commit comments

Comments
 (0)