diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java index 1256c716e09b..9a1fdc5e7829 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java @@ -15,6 +15,7 @@ import com.google.common.primitives.Shorts; import com.google.common.primitives.SignedBytes; +import com.mongodb.DBRef; import com.mongodb.client.MongoCursor; import io.airlift.slice.Slice; import io.trino.spi.Page; @@ -317,6 +318,18 @@ else if (isRowType(type)) { output.closeEntry(); return; } + else if (value instanceof DBRef) { + DBRef dbRefValue = (DBRef) value; + BlockBuilder builder = output.beginBlockEntry(); + + checkState(type.getTypeParameters().size() == 3, "DBRef should have 3 fields : %s", type); + appendTo(type.getTypeParameters().get(0), dbRefValue.getDatabaseName(), builder); + appendTo(type.getTypeParameters().get(1), dbRefValue.getCollectionName(), builder); + appendTo(type.getTypeParameters().get(2), dbRefValue.getId(), builder); + + output.closeEntry(); + return; + } else if (value instanceof List) { List listValue = (List) value; BlockBuilder builder = output.beginBlockEntry(); diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java index 0c3f4a850a12..dcf3cdb8d2aa 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java @@ -24,6 +24,7 @@ import com.google.common.primitives.Shorts; import com.google.common.primitives.SignedBytes; import com.google.common.util.concurrent.UncheckedExecutionException; +import com.mongodb.DBRef; import com.mongodb.MongoClient; import com.mongodb.client.FindIterable; import com.mongodb.client.MongoCollection; @@ -77,6 +78,7 @@ import static io.trino.spi.type.SmallintType.SMALLINT; import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; import static java.lang.Math.toIntExact; import static java.lang.String.format; @@ -108,6 +110,10 @@ public class MongoSession private static final String LTE_OP = "$lte"; private static final String IN_OP = "$in"; + private static final String DATABASE_NAME = "databaseName"; + private static final String COLLECTION_NAME = "collectionName"; + private static final String ID = "id"; + private final TypeManager typeManager; private final MongoClient client; @@ -640,6 +646,18 @@ else if (value instanceof Document) { typeSignature = new TypeSignature(StandardTypes.ROW, parameters); } } + else if (value instanceof DBRef) { + List parameters = new ArrayList<>(); + + TypeSignature idFieldType = guessFieldType(((DBRef) value).getId()) + .orElseThrow(() -> new UnsupportedOperationException("Unable to guess $id field type of DBRef from: " + ((DBRef) value).getId())); + + parameters.add(TypeSignatureParameter.namedTypeParameter(new NamedTypeSignature(Optional.of(new RowFieldName(DATABASE_NAME)), VARCHAR.getTypeSignature()))); + parameters.add(TypeSignatureParameter.namedTypeParameter(new NamedTypeSignature(Optional.of(new RowFieldName(COLLECTION_NAME)), VARCHAR.getTypeSignature()))); + parameters.add(TypeSignatureParameter.namedTypeParameter(new NamedTypeSignature(Optional.of(new RowFieldName(ID)), idFieldType))); + + typeSignature = new TypeSignature(StandardTypes.ROW, parameters); + } return Optional.ofNullable(typeSignature); } diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoIntegrationSmokeTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoIntegrationSmokeTest.java index 70b998ca7c2c..73ac7a1d8a48 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoIntegrationSmokeTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoIntegrationSmokeTest.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.mongodb.DBRef; import com.mongodb.MongoClient; import com.mongodb.client.MongoCollection; import io.trino.sql.planner.plan.LimitNode; @@ -23,6 +24,7 @@ import io.trino.testing.MaterializedRow; import io.trino.testing.QueryRunner; import org.bson.Document; +import org.bson.types.ObjectId; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; @@ -191,6 +193,25 @@ public void testSkipUnknownTypes() assertQueryReturnsEmptyResult("SHOW COLUMNS FROM test.tmp_guess_schema2"); } + @Test + public void testDBRef() + { + Document document = Document.parse("{\"_id\":ObjectId(\"5126bbf64aed4daf9e2ab771\"),\"col1\":\"foo\"}"); + + ObjectId objectId = new ObjectId("5126bc054aed4daf9e2ab772"); + DBRef dbRef = new DBRef("test", "creators", objectId); + document.append("creator", dbRef); + + client.getDatabase("test").getCollection("test_dbref").insertOne(document); + + assertQuery( + "SELECT creator.databaseName, creator.collectionName, CAST(creator.id AS VARCHAR) FROM test.test_dbref", + "SELECT 'test', 'creators', '5126bc054aed4daf9e2ab772'"); + assertQuery( + "SELECT typeof(creator) FROM test.test_dbref", + "SELECT 'row(databaseName varchar, collectionName varchar, id ObjectId)'"); + } + @Test public void testMaps() {