diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java index 122460f54298..177ac65054d3 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java @@ -66,6 +66,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -78,9 +79,12 @@ import static io.trino.spi.HostAddress.fromParts; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; +import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.SmallintType.SMALLINT; import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS; import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; @@ -508,6 +512,18 @@ private static Optional translateValue(Object trinoNativeValue, Type typ return Optional.of(trinoNativeValue); } + if (type == DATE) { + return Optional.of(new Date(TimeUnit.DAYS.toMillis((Long) trinoNativeValue))); + } + + if (type == TIMESTAMP_MILLIS) { + return Optional.of(new Date(TimeUnit.MILLISECONDS.convert((Long) trinoNativeValue, TimeUnit.MICROSECONDS))); + } + + if (type == TIMESTAMP_TZ_MILLIS) { + return Optional.of(new Date(unpackMillisUtc((long) trinoNativeValue))); + } + if (type instanceof ObjectIdType) { return Optional.of(new ObjectId(((Slice) trinoNativeValue).getBytes())); } diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorTest.java index 783ce29534fd..83f770f89918 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorTest.java @@ -18,6 +18,7 @@ import com.mongodb.DBRef; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; +import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.testing.BaseConnectorTest; import io.trino.testing.MaterializedResult; @@ -589,6 +590,41 @@ protected void verifyTableNameLengthFailurePermissible(Throwable e) assertThat(e).hasMessageMatching(".*fully qualified namespace .* is too long.*"); } + @Test + public void testPredicatePushdown() + { + // TODO test that that predicate is actually pushed down (here we test only correctness) + // varchar equality + assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE name = 'ROMANIA'")) + .matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25)))") + .isNotFullyPushedDown(FilterNode.class); + + // varchar range + assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE name BETWEEN 'POLAND' AND 'RPA'")) + .matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25)))") + .isNotFullyPushedDown(FilterNode.class); + + // varchar different case + assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE name = 'romania'")) + .returnsEmptyResult() + .isNotFullyPushedDown(FilterNode.class); + + // bigint equality + assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE nationkey = 19")) + .matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25)))") + .isNotFullyPushedDown(FilterNode.class); + + // bigint range, with decimal to bigint simplification + assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE nationkey BETWEEN 18.5 AND 19.5")) + .matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25)))") + .isNotFullyPushedDown(FilterNode.class); + + // date equality + assertThat(query("SELECT orderkey FROM orders WHERE orderdate = DATE '1992-09-29'")) + .matches("VALUES BIGINT '1250', 34406, 38436, 57570") + .isNotFullyPushedDown(FilterNode.class); + } + private void assertOneNotNullResult(String query) { MaterializedResult results = getQueryRunner().execute(getSession(), query).toTestTypes(); diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoTypeMapping.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoTypeMapping.java index e90073461967..6602d082e276 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoTypeMapping.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoTypeMapping.java @@ -251,6 +251,9 @@ public void testTime() .addRoundTrip("time", "TIME '23:59:59.9'", createTimeType(3), "TIME '23:59:59.900'") .addRoundTrip("time", "TIME '23:59:59.99'", createTimeType(3), "TIME '23:59:59.990'") .addRoundTrip("time", "TIME '23:59:59.999'", createTimeType(3), "TIME '23:59:59.999'") + .addRoundTrip("time", "TIME '23:59:59.9999'", createTimeType(4), "TIME '23:59:59.9999'") + .addRoundTrip("time", "TIME '23:59:59.99999'", createTimeType(5), "TIME '23:59:59.99999'") + .addRoundTrip("time", "TIME '23:59:59.999999'", createTimeType(6), "TIME '23:59:59.999999'") .execute(getQueryRunner(), trinoCreateAsSelect("test_time")) .execute(getQueryRunner(), trinoCreateAndInsert("test_time")); }