From 8b3405f45432cbcf8c3185ca32be1d70269c285c Mon Sep 17 00:00:00 2001 From: Yuya Ebihara Date: Thu, 29 Aug 2024 12:58:59 +0900 Subject: [PATCH] Fix failure when rewriting distance functions in PostgreSQL --- .../rule/RewriteVectorDistanceFunction.java | 5 +++-- .../postgresql/TestPostgreSqlVectorType.java | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteVectorDistanceFunction.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteVectorDistanceFunction.java index e6001793c2e21..91cde9783fc26 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteVectorDistanceFunction.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteVectorDistanceFunction.java @@ -124,12 +124,13 @@ public static Optional rewrite(ConnectorExpression expr if (expression instanceof Call call && call.getFunctionName().equals(CAST_FUNCTION_NAME)) { ConnectorExpression argument = getOnlyElement(call.getArguments()); if (argument instanceof Variable variable) { - JdbcTypeHandle typeHandle = ((JdbcColumnHandle) context.getAssignment(variable.getName())).getJdbcTypeHandle(); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(variable.getName()); + JdbcTypeHandle typeHandle = columnHandle.getJdbcTypeHandle(); // TODO type.equals("vector") should be improved to support pushdown on vector type which is installed in other schemas if (!typeHandle.jdbcTypeName().map(type -> type.equals("vector")).orElse(false)) { return Optional.empty(); } - return Optional.of(new ParameterizedExpression(quoted(variable.getName()), ImmutableList.of())); + return Optional.of(new ParameterizedExpression(quoted(columnHandle.getColumnName()), ImmutableList.of())); } return Optional.empty(); } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlVectorType.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlVectorType.java index 38e96900f8963..3b0b21ef57788 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlVectorType.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlVectorType.java @@ -20,6 +20,7 @@ import io.trino.testing.QueryRunner; import io.trino.testing.sql.TestTable; import io.trino.testing.sql.TestView; +import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; import java.util.stream.IntStream; @@ -378,4 +379,18 @@ void testPgVectorUnsupportedCosineDistance() .hasMessageContaining("invalid input"); } } + + @RepeatedTest(10) // Regression test for https://github.com/trinodb/trino/issues/23152 + void testDuplicateColumnWithUnion() + { + try (TestTable table = new TestTable(postgreSqlServer::execute, "test_union", "(id int, v vector(3))")) { + postgreSqlServer.execute("INSERT INTO " + table.getName() + " VALUES (1, '[1,2,3]'), (2, '[4,5,6]')"); + + assertThat(query("" + + "SELECT id FROM " + table.getName() + + " UNION ALL " + + "(SELECT id FROM " + table.getName() + " ORDER BY cosine_distance(v, ARRAY[4,5,6]) LIMIT 1)")) + .matches("VALUES 1, 2, 2"); + } + } }