diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ProjectFunctionRewriter.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ProjectFunctionRewriter.java index 5c96c81ca833..660e4d132a31 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ProjectFunctionRewriter.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ProjectFunctionRewriter.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.projection.ProjectFunctionRule.RewriteContext; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.expression.ConnectorExpression; import java.util.Iterator; @@ -39,8 +40,9 @@ public ProjectFunctionRewriter(ConnectorExpressionRewriter con this.rules = ImmutableSet.copyOf(requireNonNull(rules, "rules is null")); } - public Optional rewrite(ConnectorSession session, ConnectorExpression projectionExpression, Map assignments) + public Optional rewrite(ConnectorSession session, ConnectorTableHandle handle, ConnectorExpression projectionExpression, Map assignments) { + requireNonNull(handle, "handle is null"); requireNonNull(projectionExpression, "projectionExpression is null"); requireNonNull(assignments, "assignments is null"); @@ -70,7 +72,7 @@ public Optional rewriteExpression(ConnectorExpression expressi Iterator matches = rule.getPattern().match(projectionExpression, context).iterator(); while (matches.hasNext()) { Match match = matches.next(); - Optional rewritten = rule.rewrite(projectionExpression, match.captures(), context); + Optional rewritten = rule.rewrite(handle, projectionExpression, match.captures(), context); if (rewritten.isPresent()) { return rewritten; } diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ProjectFunctionRule.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ProjectFunctionRule.java index 45bbe4067f7a..ed714132a51a 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ProjectFunctionRule.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ProjectFunctionRule.java @@ -17,6 +17,7 @@ import io.trino.matching.Pattern; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.expression.ConnectorExpression; import java.util.Map; @@ -34,7 +35,7 @@ default boolean isEnabled(ConnectorSession session) return true; } - Optional rewrite(ConnectorExpression projectionExpression, Captures captures, RewriteContext context); + Optional rewrite(ConnectorTableHandle handle, ConnectorExpression projectionExpression, Captures captures, RewriteContext context); interface RewriteContext { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java index a5595c27b567..44f1a81ecba8 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java @@ -240,9 +240,9 @@ public Optional convertPredicate(ConnectorSession sessi } @Override - public Optional convertProjection(ConnectorSession session, ConnectorExpression expression, Map assignments) + public Optional convertProjection(ConnectorSession session, JdbcTableHandle handle, ConnectorExpression expression, Map assignments) { - return delegate.convertProjection(session, expression, assignments); + return delegate.convertProjection(session, handle, expression, assignments); } @Override diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java index 79ad30eddadd..76847d2b9070 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java @@ -387,7 +387,7 @@ private Optional> applyProject Set translatedExpressions = new HashSet<>(); for (ConnectorExpression projection : ImmutableSet.copyOf(projections)) { - RewrittenExpression rewrittenExpression = rewriteExpression(session, nextSyntheticColumnId, projection, assignments, translatedExpressions); + RewrittenExpression rewrittenExpression = rewriteExpression(session, handle, nextSyntheticColumnId, projection, assignments, translatedExpressions); nextSyntheticColumnId = rewrittenExpression.nextSyntheticColumnId(); newVariablesBuilder.putAll(rewrittenExpression.syntheticVariables()); columnExpressionsBuilder.putAll(rewrittenExpression.columnExpressions()); @@ -430,6 +430,7 @@ private Optional> applyProject private RewrittenExpression rewriteExpression( ConnectorSession session, + JdbcTableHandle handle, int nextSyntheticColumnId, ConnectorExpression projection, Map assignments, @@ -453,7 +454,7 @@ private RewrittenExpression rewriteExpression( ImmutableSet.of((JdbcColumnHandle) assignments.get(variable.getName())), ImmutableList.of(new Assignment(variable.getName(), assignments.get(variable.getName()), variable.getType()))); } - Optional convertedExpression = jdbcClient.convertProjection(session, projection, assignments); + Optional convertedExpression = jdbcClient.convertProjection(session, handle, projection, assignments); if (convertedExpression.isPresent()) { String columnName = SYNTHETIC_COLUMN_NAME_PREFIX + nextSyntheticColumnId; JdbcColumnHandle newColumn = JdbcColumnHandle.builder() @@ -481,6 +482,7 @@ private RewrittenExpression rewriteExpression( for (ConnectorExpression child : projection.getChildren()) { RewrittenExpression rewrittenExpression = rewriteExpression( session, + handle, nextSyntheticColumnId, child, assignments, diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java index 9848a40d74b7..0439c3f037e2 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java @@ -163,9 +163,9 @@ public Optional convertPredicate(ConnectorSession sessi } @Override - public Optional convertProjection(ConnectorSession session, ConnectorExpression expression, Map assignments) + public Optional convertProjection(ConnectorSession session, JdbcTableHandle handle, ConnectorExpression expression, Map assignments) { - return delegate().convertProjection(session, expression, assignments); + return delegate().convertProjection(session, handle, expression, assignments); } @Override diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java index e3b9ababf698..426065b3c93a 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java @@ -100,7 +100,7 @@ default Optional convertPredicate(ConnectorSession sess return Optional.empty(); } - default Optional convertProjection(ConnectorSession session, ConnectorExpression expression, Map assignments) + default Optional convertProjection(ConnectorSession session, JdbcTableHandle handle, ConnectorExpression expression, Map assignments) { return Optional.empty(); } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java index 3cc2b9c96bb1..95f350c78f80 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java @@ -185,9 +185,9 @@ public Optional convertPredicate(ConnectorSession sessi } @Override - public Optional convertProjection(ConnectorSession session, ConnectorExpression expression, Map assignments) + public Optional convertProjection(ConnectorSession session, JdbcTableHandle handle, ConnectorExpression expression, Map assignments) { - return stats.getConvertProjection().wrap(() -> delegate().convertProjection(session, expression, assignments)); + return stats.getConvertProjection().wrap(() -> delegate().convertProjection(session, handle, expression, assignments)); } @Override diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 9c1ec4d13752..0a3b092e6b4f 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -840,9 +840,9 @@ public Optional convertPredicate(ConnectorSession sessi } @Override - public Optional convertProjection(ConnectorSession session, ConnectorExpression expression, Map assignments) + public Optional convertProjection(ConnectorSession session, JdbcTableHandle handle, ConnectorExpression expression, Map assignments) { - return projectFunctionRewriter.rewrite(session, expression, assignments); + return projectFunctionRewriter.rewrite(session, handle, expression, assignments); } private static Optional toTypeHandle(DecimalType decimalType) diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteDotProductFunction.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteDotProductFunction.java index 700ca8b58174..6c2fd5908e3c 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteDotProductFunction.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteDotProductFunction.java @@ -22,6 +22,7 @@ import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.QueryParameter; import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.FunctionName; @@ -60,7 +61,7 @@ public Pattern getPattern() } @Override - public Optional rewrite(ConnectorExpression projectionExpression, Captures captures, RewriteContext context) + public Optional rewrite(ConnectorTableHandle handle, ConnectorExpression projectionExpression, Captures captures, RewriteContext context) { ConnectorExpression call = captures.get(CALL); diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteStringReverseFunction.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteStringReverseFunction.java index a630a763f293..3428a0bc7162 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteStringReverseFunction.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteStringReverseFunction.java @@ -22,6 +22,7 @@ import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.FunctionName; @@ -57,7 +58,7 @@ public Pattern getPattern() } @Override - public Optional rewrite(ConnectorExpression projectionExpression, Captures captures, RewriteContext context) + public Optional rewrite(ConnectorTableHandle handle, ConnectorExpression projectionExpression, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcTypeHandle typeHandle = ((JdbcColumnHandle) context.getAssignment(argument.getName())).getJdbcTypeHandle(); diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteVectorDistanceFunction.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteVectorDistanceFunction.java index 91cde9783fc2..2c277159ae1c 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteVectorDistanceFunction.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/rule/RewriteVectorDistanceFunction.java @@ -24,6 +24,7 @@ import io.trino.plugin.jdbc.QueryParameter; import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.block.Block; +import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Constant; @@ -76,7 +77,7 @@ public Pattern getPattern() } @Override - public Optional rewrite(ConnectorExpression projectionExpression, Captures captures, RewriteContext context) + public Optional rewrite(ConnectorTableHandle handle, ConnectorExpression projectionExpression, Captures captures, RewriteContext context) { Optional leftExpression = rewrite(captures.get(LEFT_ARGUMENT), context); if (leftExpression.isEmpty()) {