diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index ed18dda01c07..9c1ec4d13752 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -73,6 +73,7 @@ import io.trino.plugin.jdbc.expression.RewriteIn; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping; +import io.trino.plugin.postgresql.rule.RewriteDotProductFunction; import io.trino.plugin.postgresql.rule.RewriteStringReverseFunction; import io.trino.plugin.postgresql.rule.RewriteVectorDistanceFunction; import io.trino.spi.TrinoException; @@ -348,7 +349,7 @@ public PostgreSqlClient( .add(new RewriteStringReverseFunction()) .add(new RewriteVectorDistanceFunction("euclidean_distance", "<->")) .add(new RewriteVectorDistanceFunction("cosine_distance", "<=>")) - // TODO Rewrite Trino -dot_product to pgvector <#> operator + .add(new RewriteDotProductFunction()) .build()); JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteDotProductFunction.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteDotProductFunction.java new file mode 100644 index 000000000000..700ca8b58174 --- /dev/null +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteDotProductFunction.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.postgresql.rule; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.projection.ProjectFunctionRule; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.FunctionName; + +import java.sql.Types; +import java.util.Optional; + +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; +import static io.trino.plugin.postgresql.rule.RewriteVectorDistanceFunction.isArrayTypeWithRealOrDouble; +import static io.trino.spi.type.DoubleType.DOUBLE; + +public final class RewriteDotProductFunction + implements ProjectFunctionRule +{ + private static final Capture CALL = newCapture(); + + private static final Pattern PATTERN = call() + .with(functionName().equalTo(new FunctionName("$negate"))) + .with(type().matching(type -> type == DOUBLE)) + .with(argumentCount().equalTo(1)) + .with(argument(0).matching(expression().capturedAs(CALL).matching(expression -> expression instanceof Call call + && call.getFunctionName().equals(new FunctionName("dot_product")) + && call.getArguments().size() == 2 + && call.getArguments().stream().allMatch(argument -> isArrayTypeWithRealOrDouble(argument.getType()))))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional rewrite(ConnectorExpression projectionExpression, Captures captures, RewriteContext context) + { + ConnectorExpression call = captures.get(CALL); + + Optional leftExpression = RewriteVectorDistanceFunction.rewrite(call.getChildren().getFirst(), context); + if (leftExpression.isEmpty()) { + return Optional.empty(); + } + + Optional rightExpression = RewriteVectorDistanceFunction.rewrite(call.getChildren().get(1), context); + if (rightExpression.isEmpty()) { + return Optional.empty(); + } + + return Optional.of(new JdbcExpression( + "%s <#> %s".formatted(leftExpression.get().expression(), rightExpression.get().expression()), + ImmutableList.builder() + .addAll(leftExpression.get().parameters()) + .addAll(rightExpression.get().parameters()) + .build(), + new JdbcTypeHandle( + Types.DOUBLE, + Optional.of("double"), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()))); + } +} diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteVectorDistanceFunction.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteVectorDistanceFunction.java index 72cc7371eb60..e6001793c2e2 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteVectorDistanceFunction.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteVectorDistanceFunction.java @@ -103,7 +103,7 @@ public Optional rewrite(ConnectorExpression projectionExpression Optional.empty()))); } - private static Optional rewrite(ConnectorExpression expression, RewriteContext context) + public static Optional rewrite(ConnectorExpression expression, RewriteContext context) { if (expression instanceof Constant constant) { Type elementType = ((ArrayType) constant.getType()).getElementType(); @@ -140,7 +140,7 @@ private static Optional rewrite(ConnectorExpression exp return Optional.of(translatedArgument.orElseThrow()); } - private static boolean isArrayTypeWithRealOrDouble(Type type) + public static boolean isArrayTypeWithRealOrDouble(Type type) { return type instanceof ArrayType arrayType && (arrayType.getElementType() == REAL || arrayType.getElementType() == DOUBLE); } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlVectorType.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlVectorType.java index a28fe9214d6f..38e96900f896 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlVectorType.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlVectorType.java @@ -255,14 +255,13 @@ void testDotProductCompatibility() TestView view = new TestView(postgreSqlServer::execute, "test_dot_product", "SELECT v <#> '[7,8,9]' FROM " + table.getName())) { postgreSqlServer.execute("INSERT INTO " + table.getName() + " VALUES (1, '[1,2,3]'), (2, '[4,5,6]')"); - // TODO Add support for projection pushdown with dot_product function // The minus sign is needed because <#> returns the negative inner product. Postgres only supports ASC order index scans on operators. assertThat(query("SELECT -dot_product(v, ARRAY[7,8,9]) FROM " + table.getName())) .matches("SELECT * FROM tpch." + view.getName()) - .isNotFullyPushedDown(ProjectNode.class); + .isFullyPushedDown(); assertThat(query("SELECT id FROM " + table.getName() + " ORDER BY -dot_product(v, ARRAY[7,8,9]) LIMIT 1")) - .isNotFullyPushedDown(ProjectNode.class); + .isFullyPushedDown(); } }