From 134ed063215ae6bec0222b2ed23be2a0561fdbde Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 17 Mar 2023 22:11:19 +0100 Subject: [PATCH 1/2] Remove duplicate code branch The null case is handled 3 lines below. --- .../plugin/jdbc/expression/RewriteExactNumericConstant.java | 4 ---- .../trino/plugin/jdbc/expression/RewriteVarcharConstant.java | 3 --- .../trino/plugin/sqlserver/RewriteUnicodeVarcharConstant.java | 3 --- 3 files changed, 10 deletions(-) diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteExactNumericConstant.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteExactNumericConstant.java index 8a9115f0d1bf..ef531ee4cc8a 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteExactNumericConstant.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteExactNumericConstant.java @@ -46,10 +46,6 @@ public Pattern getPattern() @Override public Optional rewrite(Constant constant, Captures captures, RewriteContext context) { - if (constant.getValue() == null) { - return Optional.empty(); - } - Type type = constant.getType(); if (constant.getValue() == null) { return Optional.empty(); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java index 38a0f694a631..94ada1e9ecd4 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java @@ -39,9 +39,6 @@ public Pattern getPattern() @Override public Optional rewrite(Constant constant, Captures captures, RewriteContext context) { - if (constant.getValue() == null) { - return Optional.empty(); - } Slice slice = (Slice) constant.getValue(); if (slice == null) { return Optional.empty(); diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/RewriteUnicodeVarcharConstant.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/RewriteUnicodeVarcharConstant.java index 71557c36f1ae..b31aa6bfa0c8 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/RewriteUnicodeVarcharConstant.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/RewriteUnicodeVarcharConstant.java @@ -41,9 +41,6 @@ public Pattern getPattern() @Override public Optional rewrite(Constant constant, Captures captures, RewriteContext context) { - if (constant.getValue() == null) { - return Optional.empty(); - } Slice slice = (Slice) constant.getValue(); if (slice == null) { return Optional.empty(); From dc58b7879258000d2815698d8a85616e44d6927e Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 21 Mar 2023 22:13:47 +0100 Subject: [PATCH 2/2] Use JDBC parameters in JDBC complex expression pushdown --- .../io/trino/plugin/jdbc/BaseJdbcClient.java | 16 +- .../trino/plugin/jdbc/CachingJdbcClient.java | 5 +- .../plugin/jdbc/DefaultJdbcMetadata.java | 13 +- .../plugin/jdbc/DefaultQueryBuilder.java | 53 ++-- .../plugin/jdbc/ForwardingJdbcClient.java | 5 +- .../java/io/trino/plugin/jdbc/JdbcClient.java | 5 +- .../io/trino/plugin/jdbc/JdbcExpression.java | 14 +- .../io/trino/plugin/jdbc/JdbcTableHandle.java | 7 +- .../io/trino/plugin/jdbc/QueryBuilder.java | 7 +- .../io/trino/plugin/jdbc/QueryParameter.java | 27 +- .../aggregation/BaseImplementAvgBigint.java | 9 +- .../jdbc/aggregation/ImplementAvgDecimal.java | 9 +- .../ImplementAvgFloatingPoint.java | 9 +- .../jdbc/aggregation/ImplementCorr.java | 15 +- .../jdbc/aggregation/ImplementCount.java | 9 +- .../jdbc/aggregation/ImplementCountAll.java | 11 +- .../aggregation/ImplementCountDistinct.java | 9 +- .../aggregation/ImplementCovariancePop.java | 15 +- .../aggregation/ImplementCovarianceSamp.java | 15 +- .../jdbc/aggregation/ImplementMinMax.java | 9 +- .../aggregation/ImplementRegrIntercept.java | 15 +- .../jdbc/aggregation/ImplementRegrSlope.java | 15 +- .../jdbc/aggregation/ImplementStddevPop.java | 9 +- .../jdbc/aggregation/ImplementStddevSamp.java | 9 +- .../plugin/jdbc/aggregation/ImplementSum.java | 9 +- .../aggregation/ImplementVariancePop.java | 9 +- .../aggregation/ImplementVarianceSamp.java | 9 +- .../jdbc/expression/GenericRewrite.java | 22 +- ...dbcConnectorExpressionRewriterBuilder.java | 6 +- .../expression/ParameterizedExpression.java | 30 ++ .../jdbc/expression/RewriteComparison.java | 17 +- .../RewriteExactNumericConstant.java | 23 +- .../plugin/jdbc/expression/RewriteIn.java | 20 +- .../expression/RewriteLogicalExpression.java | 18 +- .../expression/RewriteVarcharConstant.java | 14 +- .../jdbc/expression/RewriteVariable.java | 7 +- .../jdbc/jmx/StatisticsAwareJdbcClient.java | 5 +- .../jdbc/TestDefaultJdbcQueryBuilder.java | 5 +- .../jdbc/expression/TestGenericRewrite.java | 19 +- .../plugin/clickhouse/ClickHouseClient.java | 7 +- .../trino/plugin/druid/DruidJdbcClient.java | 10 +- .../io/trino/plugin/ignite/IgniteClient.java | 9 +- .../plugin/ignite/ImplementAvgDecimal.java | 10 +- .../trino/plugin/mariadb/MariaDbClient.java | 7 +- .../io/trino/plugin/mysql/MySqlClient.java | 7 +- .../io/trino/plugin/oracle/OracleClient.java | 7 +- .../plugin/postgresql/PostgreSqlClient.java | 9 +- .../postgresql/TestPostgreSqlClient.java | 279 ++++++++++-------- .../redshift/ImplementRedshiftAvgDecimal.java | 13 +- .../trino/plugin/redshift/RedshiftClient.java | 7 +- .../sqlserver/ImplementSqlServerCountBig.java | 9 +- .../ImplementSqlServerCountBigAll.java | 11 +- .../ImplementSqlServerStddevPop.java | 9 +- .../sqlserver/ImplementSqlServerStdev.java | 9 +- .../sqlserver/ImplementSqlServerVariance.java | 9 +- .../ImplementSqlServerVariancePop.java | 9 +- .../RewriteUnicodeVarcharConstant.java | 58 ---- .../plugin/sqlserver/SqlServerClient.java | 12 +- 58 files changed, 617 insertions(+), 397 deletions(-) create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ParameterizedExpression.java delete mode 100644 plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/RewriteUnicodeVarcharConstant.java diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java index 3b1ee8815dc6..76f595b01f56 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java @@ -20,6 +20,7 @@ import com.google.common.collect.ImmutableSortedSet; import com.google.common.io.Closer; import io.airlift.log.Logger; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; @@ -440,7 +441,7 @@ public PreparedQuery prepareQuery( JdbcTableHandle table, Optional>> groupingSets, List columns, - Map columnExpressions) + Map columnExpressions) { verify(table.getAuthorization().isEmpty(), "Unexpected authorization is required for table: %s".formatted(table)); try (Connection connection = connectionFactory.openConnection(session)) { @@ -465,7 +466,7 @@ protected PreparedQuery prepareQuery( JdbcTableHandle table, Optional>> groupingSets, List columns, - Map columnExpressions, + Map columnExpressions, Optional split) { return applyQueryTransformations(table, queryBuilder.prepareSelectQuery( @@ -480,15 +481,18 @@ protected PreparedQuery prepareQuery( getAdditionalPredicate(table.getConstraintExpressions(), split.flatMap(JdbcSplit::getAdditionalPredicate)))); } - protected static Optional getAdditionalPredicate(List constraintExpressions, Optional splitPredicate) + protected static Optional getAdditionalPredicate(List constraintExpressions, Optional splitPredicate) { if (constraintExpressions.isEmpty() && splitPredicate.isEmpty()) { return Optional.empty(); } - return Optional.of( - Stream.concat(constraintExpressions.stream(), splitPredicate.stream()) - .collect(joining(") AND (", "(", ")"))); + return Optional.of(new ParameterizedExpression( + Stream.concat(constraintExpressions.stream().map(ParameterizedExpression::expression), splitPredicate.stream()) + .collect(joining(") AND (", "(", ")")), + constraintExpressions.stream() + .flatMap(expressionRewrite -> expressionRewrite.parameters().stream()) + .collect(toImmutableList()))); } @Override diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java index 3dc3a6ab087d..2d6a6663bc85 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java @@ -23,6 +23,7 @@ import io.trino.collect.cache.EvictableCacheBuilder; import io.trino.plugin.base.session.SessionPropertiesProvider; import io.trino.plugin.jdbc.IdentityCacheMapping.IdentityCacheKey; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; @@ -224,7 +225,7 @@ public Optional implementAggregation(ConnectorSession session, A } @Override - public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) + public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) { return delegate.convertPredicate(session, expression, assignments); } @@ -255,7 +256,7 @@ public PreparedQuery prepareQuery( JdbcTableHandle table, Optional>> groupingSets, List columns, - Map columnExpressions) + Map columnExpressions) { return delegate.prepareQuery(session, table, groupingSets, columns, columnExpressions); } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java index 41496a7b2e26..ccb9222afa21 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import io.trino.plugin.jdbc.PredicatePushdownController.DomainPushdownResult; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.ptf.Query.QueryFunctionHandle; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; @@ -157,7 +158,7 @@ public Optional> applyFilter(C TupleDomain oldDomain = handle.getConstraint(); TupleDomain newDomain = oldDomain.intersect(constraint.getSummary()); - List newConstraintExpressions; + List newConstraintExpressions; TupleDomain remainingFilter; Optional remainingExpression; if (newDomain.isNone()) { @@ -190,10 +191,10 @@ public Optional> applyFilter(C remainingFilter = TupleDomain.withColumnDomains(unsupported); if (isComplexExpressionPushdown(session)) { - List newExpressions = new ArrayList<>(); + List newExpressions = new ArrayList<>(); List remainingExpressions = new ArrayList<>(); for (ConnectorExpression expression : extractConjuncts(constraint.getExpression())) { - Optional converted = jdbcClient.convertPredicate(session, expression, constraint.getAssignments()); + Optional converted = jdbcClient.convertPredicate(session, expression, constraint.getAssignments()); if (converted.isPresent()) { newExpressions.add(converted.get()); } @@ -201,7 +202,7 @@ public Optional> applyFilter(C remainingExpressions.add(expression); } } - newConstraintExpressions = ImmutableSet.builder() + newConstraintExpressions = ImmutableSet.builder() .addAll(handle.getConstraintExpressions()) .addAll(newExpressions) .build().asList(); @@ -337,7 +338,7 @@ public Optional> applyAggrega ImmutableList.Builder newColumns = ImmutableList.builder(); ImmutableList.Builder projections = ImmutableList.builder(); ImmutableList.Builder resultAssignments = ImmutableList.builder(); - ImmutableMap.Builder expressions = ImmutableMap.builder(); + ImmutableMap.Builder expressions = ImmutableMap.builder(); List> groupingSetsAsJdbcColumnHandles = groupingSets.stream() .map(groupingSet -> groupingSet.stream() @@ -374,7 +375,7 @@ public Optional> applyAggrega newColumns.add(newColumn); projections.add(new Variable(newColumn.getColumnName(), aggregate.getOutputType())); resultAssignments.add(new Assignment(newColumn.getColumnName(), newColumn, aggregate.getOutputType())); - expressions.put(columnName, expression.get().getExpression()); + expressions.put(columnName, new ParameterizedExpression(expression.get().getExpression(), expression.get().getParameters())); } List newColumnsList = newColumns.build(); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java index c6b382fe886f..f8e7ddf8eb69 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java @@ -19,6 +19,7 @@ import com.google.inject.Inject; import io.airlift.log.Logger; import io.airlift.slice.Slice; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; @@ -72,9 +73,9 @@ public PreparedQuery prepareSelectQuery( JdbcRelationHandle baseRelation, Optional>> groupingSets, List columns, - Map columnExpressions, + Map columnExpressions, TupleDomain tupleDomain, - Optional additionalPredicate) + Optional additionalPredicate) { if (!tupleDomain.isNone()) { Map domains = tupleDomain.getDomains().orElseThrow(); @@ -88,11 +89,14 @@ public PreparedQuery prepareSelectQuery( ImmutableList.Builder conjuncts = ImmutableList.builder(); ImmutableList.Builder accumulator = ImmutableList.builder(); - String sql = "SELECT " + getProjection(client, columns, columnExpressions); + String sql = "SELECT " + getProjection(client, columns, columnExpressions, accumulator::add); sql += getFrom(client, baseRelation, accumulator::add); toConjuncts(client, session, connection, tupleDomain, conjuncts, accumulator::add); - additionalPredicate.ifPresent(conjuncts::add); + additionalPredicate.ifPresent(predicate -> { + conjuncts.add(predicate.expression()); + accumulator.addAll(predicate.parameters()); + }); List clauses = conjuncts.build(); if (!clauses.isEmpty()) { sql += " WHERE " + Joiner.on(" AND ").join(clauses); @@ -150,7 +154,7 @@ public PreparedQuery prepareDeleteQuery( Connection connection, JdbcNamedRelationHandle baseRelation, TupleDomain tupleDomain, - Optional additionalPredicate) + Optional additionalPredicate) { String sql = "DELETE FROM " + getRelation(client, baseRelation.getRemoteTableName()); @@ -158,7 +162,10 @@ public PreparedQuery prepareDeleteQuery( ImmutableList.Builder accumulator = ImmutableList.builder(); toConjuncts(client, session, connection, tupleDomain, conjuncts, accumulator::add); - additionalPredicate.ifPresent(conjuncts::add); + additionalPredicate.ifPresent(predicate -> { + conjuncts.add(predicate.expression()); + accumulator.addAll(predicate.parameters()); + }); List clauses = conjuncts.build(); if (!clauses.isEmpty()) { sql += " WHERE " + Joiner.on(" AND ").join(clauses); @@ -182,7 +189,9 @@ public PreparedStatement prepareStatement( for (int i = 0; i < parameters.size(); i++) { QueryParameter parameter = parameters.get(i); int parameterIndex = i + 1; - WriteFunction writeFunction = getWriteFunction(client, session, connection, parameter.getJdbcType(), parameter.getType()); + WriteFunction writeFunction = parameter.getJdbcType() + .map(jdbcType -> getWriteFunction(client, session, connection, jdbcType, parameter.getType())) + .orElseGet(() -> getWriteFunction(client, session, parameter.getType())); Class javaType = writeFunction.getJavaType(); Object value = parameter.getValue() // The value must be present, since DefaultQueryBuilder never creates null parameters. Values coming from Domain's ValueSet are non-null, and @@ -251,21 +260,24 @@ protected String getRelation(JdbcClient client, RemoteTableName remoteTableName) return client.quoted(remoteTableName); } - protected String getProjection(JdbcClient client, List columns, Map columnExpressions) + protected String getProjection(JdbcClient client, List columns, Map columnExpressions, Consumer accumulator) { if (columns.isEmpty()) { return "1 x"; } - return columns.stream() - .map(jdbcColumnHandle -> { - String columnAlias = client.quoted(jdbcColumnHandle.getColumnName()); - String expression = columnExpressions.get(jdbcColumnHandle.getColumnName()); - if (expression == null) { - return columnAlias; - } - return format("%s AS %s", expression, columnAlias); - }) - .collect(joining(", ")); + List projections = new ArrayList<>(); + for (JdbcColumnHandle jdbcColumnHandle : columns) { + String columnAlias = client.quoted(jdbcColumnHandle.getColumnName()); + ParameterizedExpression expression = columnExpressions.get(jdbcColumnHandle.getColumnName()); + if (expression == null) { + projections.add(columnAlias); + } + else { + projections.add(format("%s AS %s", expression.expression(), columnAlias)); + expression.parameters().forEach(accumulator); + } + } + return String.join(", ", projections); } private String getFrom(JdbcClient client, JdbcRelationHandle baseRelation, Consumer accumulator) @@ -425,4 +437,9 @@ private static WriteFunction getWriteFunction(JdbcClient client, ConnectorSessio verify(writeFunction.getJavaType() == type.getJavaType(), "Java type mismatch: %s, %s", writeFunction, type); return writeFunction; } + + private static WriteFunction getWriteFunction(JdbcClient client, ConnectorSession session, Type type) + { + return client.toWriteMapping(session, type).getWriteFunction(); + } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java index 531353a15dd4..10fd7c99c8df 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.jdbc; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; @@ -127,7 +128,7 @@ public Optional implementAggregation(ConnectorSession session, A } @Override - public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) + public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) { return delegate().convertPredicate(session, expression, assignments); } @@ -158,7 +159,7 @@ public PreparedQuery prepareQuery( JdbcTableHandle table, Optional>> groupingSets, List columns, - Map columnExpressions) + Map columnExpressions) { return delegate().prepareQuery(session, table, groupingSets, columns, columnExpressions); } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java index d5e683d8c4bc..2da9e7917e4d 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.jdbc; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; @@ -78,7 +79,7 @@ default Optional implementAggregation(ConnectorSession session, return Optional.empty(); } - default Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) + default Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) { return Optional.empty(); } @@ -99,7 +100,7 @@ PreparedQuery prepareQuery( JdbcTableHandle table, Optional>> groupingSets, List columns, - Map columnExpressions); + Map columnExpressions); PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List columns) throws SQLException; diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcExpression.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcExpression.java index 15315f794d79..7c51231fb161 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcExpression.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcExpression.java @@ -13,17 +13,23 @@ */ package io.trino.plugin.jdbc; +import com.google.common.collect.ImmutableList; + +import java.util.List; + import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; public final class JdbcExpression { private final String expression; + private final List parameters; private final JdbcTypeHandle jdbcTypeHandle; - public JdbcExpression(String expression, JdbcTypeHandle jdbcTypeHandle) + public JdbcExpression(String expression, List parameters, JdbcTypeHandle jdbcTypeHandle) { this.expression = requireNonNull(expression, "expression is null"); + this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); this.jdbcTypeHandle = requireNonNull(jdbcTypeHandle, "jdbcTypeHandle is null"); } @@ -32,6 +38,11 @@ public String getExpression() return expression; } + public List getParameters() + { + return parameters; + } + public JdbcTypeHandle getJdbcTypeHandle() { return jdbcTypeHandle; @@ -42,6 +53,7 @@ public String toString() { return toStringHelper(this) .add("expression", expression) + .add("parameters", parameters) .add("jdbcTypeHandle", jdbcTypeHandle) .toString(); } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTableHandle.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTableHandle.java index 44112d9eb81a..c9fbf1256cc4 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTableHandle.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTableHandle.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.SchemaTableName; @@ -40,7 +41,7 @@ public final class JdbcTableHandle private final TupleDomain constraint; // Additional to constraint - private final List constraintExpressions; + private final List constraintExpressions; // semantically sort order is applied after constraint private final Optional> sortOrder; @@ -78,7 +79,7 @@ public JdbcTableHandle(SchemaTableName schemaTableName, RemoteTableName remoteTa public JdbcTableHandle( @JsonProperty("relationHandle") JdbcRelationHandle relationHandle, @JsonProperty("constraint") TupleDomain constraint, - @JsonProperty("constraintExpressions") List constraintExpressions, + @JsonProperty("constraintExpressions") List constraintExpressions, @JsonProperty("sortOrder") Optional> sortOrder, @JsonProperty("limit") OptionalLong limit, @JsonProperty("columns") Optional> columns, @@ -138,7 +139,7 @@ public TupleDomain getConstraint() } @JsonProperty - public List getConstraintExpressions() + public List getConstraintExpressions() { return constraintExpressions; } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java index d934c1d9387e..63338f6f72d9 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.jdbc; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.JoinType; @@ -34,9 +35,9 @@ PreparedQuery prepareSelectQuery( JdbcRelationHandle baseRelation, Optional>> groupingSets, List columns, - Map columnExpressions, + Map columnExpressions, TupleDomain tupleDomain, - Optional additionalPredicate); + Optional additionalPredicate); PreparedQuery prepareJoinQuery( JdbcClient client, @@ -55,7 +56,7 @@ PreparedQuery prepareDeleteQuery( Connection connection, JdbcNamedRelationHandle baseRelation, TupleDomain tupleDomain, - Optional additionalPredicate); + Optional additionalPredicate); PreparedStatement prepareStatement( JdbcClient client, diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryParameter.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryParameter.java index 8c69fe05c9d5..b2180e6eab3f 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryParameter.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryParameter.java @@ -22,17 +22,28 @@ import java.util.Objects; import java.util.Optional; +import static com.google.common.base.MoreObjects.toStringHelper; import static io.trino.spi.predicate.Utils.blockToNativeValue; import static io.trino.spi.predicate.Utils.nativeValueToBlock; import static java.util.Objects.requireNonNull; public final class QueryParameter { - private final JdbcTypeHandle jdbcType; + private final Optional jdbcType; private final Type type; private final Optional value; + public QueryParameter(Type type, Optional value) + { + this(Optional.empty(), type, value); + } + public QueryParameter(JdbcTypeHandle jdbcType, Type type, Optional value) + { + this(Optional.of(jdbcType), type, value); + } + + private QueryParameter(Optional jdbcType, Type type, Optional value) { this.jdbcType = requireNonNull(jdbcType, "jdbcType is null"); this.type = requireNonNull(type, "type is null"); @@ -40,7 +51,7 @@ public QueryParameter(JdbcTypeHandle jdbcType, Type type, Optional value } @JsonCreator - public static QueryParameter fromValueAsBlock(JdbcTypeHandle jdbcType, Type type, Block valueBlock) + public static QueryParameter fromValueAsBlock(Optional jdbcType, Type type, Block valueBlock) { requireNonNull(type, "type is null"); requireNonNull(valueBlock, "valueBlock is null"); @@ -49,7 +60,7 @@ public static QueryParameter fromValueAsBlock(JdbcTypeHandle jdbcType, Type type } @JsonProperty - public JdbcTypeHandle getJdbcType() + public Optional getJdbcType() { return jdbcType; } @@ -92,4 +103,14 @@ public int hashCode() { return Objects.hash(jdbcType, type, value); } + + @Override + public String toString() + { + return toStringHelper(this) + .add("jdbcType", jdbcType) + .add("type", type) + .add("value", value) + .toString(); + } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/BaseImplementAvgBigint.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/BaseImplementAvgBigint.java index 314f53af292a..50c33eb3c429 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/BaseImplementAvgBigint.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/BaseImplementAvgBigint.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; @@ -42,7 +43,7 @@ * can result in rounding of the output to a bigint. */ public abstract class BaseImplementAvgBigint - implements AggregateFunctionRule + implements AggregateFunctionRule { private final Capture argument; @@ -63,13 +64,15 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(this.argument); verify(aggregateFunction.getOutputType() == DOUBLE); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format(getRewriteFormatExpression(), context.rewriteExpression(argument).orElseThrow()), + format(getRewriteFormatExpression(), rewrittenArgument.expression()), + rewrittenArgument.parameters(), new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))); } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementAvgDecimal.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementAvgDecimal.java index 5dbfb3a5c595..4dba2d0740c0 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementAvgDecimal.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementAvgDecimal.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.DecimalType; @@ -38,7 +39,7 @@ * Implements {@code avg(decimal(p, s)} */ public class ImplementAvgDecimal - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -54,15 +55,17 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); DecimalType type = (DecimalType) columnHandle.getColumnType(); verify(aggregateFunction.getOutputType().equals(type)); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("CAST(avg(%s) AS decimal(%s, %s))", context.rewriteExpression(argument).orElseThrow(), type.getPrecision(), type.getScale()), + format("CAST(avg(%s) AS decimal(%s, %s))", rewrittenArgument.expression(), type.getPrecision(), type.getScale()), + rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementAvgFloatingPoint.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementAvgFloatingPoint.java index 1364cb128636..b95783889c9b 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementAvgFloatingPoint.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementAvgFloatingPoint.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; @@ -39,7 +40,7 @@ * Implements {@code avg(float)} */ public class ImplementAvgFloatingPoint - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -55,14 +56,16 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); verify(aggregateFunction.getOutputType() == columnHandle.getColumnType()); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("avg(%s)", context.rewriteExpression(argument).orElseThrow()), + format("avg(%s)", rewrittenArgument.expression()), + rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCorr.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCorr.java index 7b5c8a0b0066..ba75f13d2f0b 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCorr.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCorr.java @@ -13,12 +13,15 @@ */ package io.trino.plugin.jdbc.aggregation; +import com.google.common.collect.ImmutableList; import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; @@ -37,7 +40,7 @@ import static java.lang.String.format; public class ImplementCorr - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture> ARGUMENTS = newCapture(); @@ -53,7 +56,7 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { List arguments = captures.get(ARGUMENTS); verify(arguments.size() == 2); @@ -63,8 +66,14 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(argument1.getName()); verify(aggregateFunction.getOutputType().equals(columnHandle1.getColumnType())); + ParameterizedExpression rewrittenArgument1 = context.rewriteExpression(argument1).orElseThrow(); + ParameterizedExpression rewrittenArgument2 = context.rewriteExpression(argument2).orElseThrow(); return Optional.of(new JdbcExpression( - format("corr(%s, %s)", context.rewriteExpression(argument1).orElseThrow(), context.rewriteExpression(argument2).orElseThrow()), + format("corr(%s, %s)", rewrittenArgument1.expression(), rewrittenArgument2.expression()), + ImmutableList.builder() + .addAll(rewrittenArgument1.parameters()) + .addAll(rewrittenArgument2.parameters()) + .build(), columnHandle1.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCount.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCount.java index 219265a5d41e..8b6864a62e98 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCount.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCount.java @@ -20,6 +20,7 @@ import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.BigintType; @@ -40,7 +41,7 @@ * Implements {@code count(x)}. */ public class ImplementCount - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -63,13 +64,15 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); verify(aggregateFunction.getOutputType() == BIGINT); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("count(%s)", context.rewriteExpression(argument).orElseThrow()), + format("count(%s)", rewrittenArgument.expression()), + rewrittenArgument.parameters(), bigintTypeHandle)); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCountAll.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCountAll.java index 314d3a2e2a7d..8d38999e70d2 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCountAll.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCountAll.java @@ -13,12 +13,14 @@ */ package io.trino.plugin.jdbc.aggregation; +import com.google.common.collect.ImmutableList; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.type.BigintType; @@ -36,7 +38,7 @@ * Implements {@code count(*)}. */ public class ImplementCountAll - implements AggregateFunctionRule + implements AggregateFunctionRule { private final JdbcTypeHandle bigintTypeHandle; @@ -57,9 +59,12 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { verify(aggregateFunction.getOutputType() == BIGINT); - return Optional.of(new JdbcExpression("count(*)", bigintTypeHandle)); + return Optional.of(new JdbcExpression( + "count(*)", + ImmutableList.of(), + bigintTypeHandle)); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCountDistinct.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCountDistinct.java index 00d84fad5dcc..f661479c693b 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCountDistinct.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCountDistinct.java @@ -21,6 +21,7 @@ import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.BigintType; @@ -44,7 +45,7 @@ * Implements {@code count(DISTINCT x)}. */ public class ImplementCountDistinct - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -71,7 +72,7 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); @@ -83,8 +84,10 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap return Optional.empty(); } + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("count(DISTINCT %s)", context.rewriteExpression(argument).orElseThrow()), + format("count(DISTINCT %s)", rewrittenArgument.expression()), + rewrittenArgument.parameters(), bigintTypeHandle)); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCovariancePop.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCovariancePop.java index c7b8d5d78057..b135e5911297 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCovariancePop.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCovariancePop.java @@ -13,12 +13,15 @@ */ package io.trino.plugin.jdbc.aggregation; +import com.google.common.collect.ImmutableList; import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; @@ -37,7 +40,7 @@ import static java.lang.String.format; public class ImplementCovariancePop - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture> ARGUMENTS = newCapture(); @@ -53,7 +56,7 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { List arguments = captures.get(ARGUMENTS); verify(arguments.size() == 2); @@ -63,8 +66,14 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(argument1.getName()); verify(aggregateFunction.getOutputType().equals(columnHandle1.getColumnType())); + ParameterizedExpression rewrittenArgument1 = context.rewriteExpression(argument1).orElseThrow(); + ParameterizedExpression rewrittenArgument2 = context.rewriteExpression(argument2).orElseThrow(); return Optional.of(new JdbcExpression( - format("covar_pop(%s, %s)", context.rewriteExpression(argument1).orElseThrow(), context.rewriteExpression(argument2).orElseThrow()), + format("covar_pop(%s, %s)", rewrittenArgument1.expression(), rewrittenArgument2.expression()), + ImmutableList.builder() + .addAll(rewrittenArgument1.parameters()) + .addAll(rewrittenArgument2.parameters()) + .build(), columnHandle1.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCovarianceSamp.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCovarianceSamp.java index 2d8f71d43b7f..9118883c2274 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCovarianceSamp.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementCovarianceSamp.java @@ -13,12 +13,15 @@ */ package io.trino.plugin.jdbc.aggregation; +import com.google.common.collect.ImmutableList; import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; @@ -37,7 +40,7 @@ import static java.lang.String.format; public class ImplementCovarianceSamp - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture> ARGUMENTS = newCapture(); @@ -53,7 +56,7 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { List arguments = captures.get(ARGUMENTS); verify(arguments.size() == 2); @@ -63,8 +66,14 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(argument1.getName()); verify(aggregateFunction.getOutputType().equals(columnHandle1.getColumnType())); + ParameterizedExpression rewrittenArgument1 = context.rewriteExpression(argument1).orElseThrow(); + ParameterizedExpression rewrittenArgument2 = context.rewriteExpression(argument2).orElseThrow(); return Optional.of(new JdbcExpression( - format("covar_samp(%s, %s)", context.rewriteExpression(argument1).orElseThrow(), context.rewriteExpression(argument2).orElseThrow()), + format("covar_samp(%s, %s)", rewrittenArgument1.expression(), rewrittenArgument2.expression()), + ImmutableList.builder() + .addAll(rewrittenArgument1.parameters()) + .addAll(rewrittenArgument2.parameters()) + .build(), columnHandle1.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementMinMax.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementMinMax.java index a62ae9e04310..5e88b4aaf28a 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementMinMax.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementMinMax.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.CharType; @@ -39,7 +40,7 @@ * Implements {@code min(x)}, {@code max(x)}. */ public class ImplementMinMax - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -59,7 +60,7 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); @@ -70,8 +71,10 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap return Optional.empty(); } + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("%s(%s)", aggregateFunction.getFunctionName(), context.rewriteExpression(argument).orElseThrow()), + format("%s(%s)", aggregateFunction.getFunctionName(), rewrittenArgument.expression()), + rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementRegrIntercept.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementRegrIntercept.java index be3497ece1cd..89d6cd11b76b 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementRegrIntercept.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementRegrIntercept.java @@ -13,12 +13,15 @@ */ package io.trino.plugin.jdbc.aggregation; +import com.google.common.collect.ImmutableList; import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; @@ -37,7 +40,7 @@ import static java.lang.String.format; public class ImplementRegrIntercept - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture> ARGUMENTS = newCapture(); @@ -53,7 +56,7 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { List arguments = captures.get(ARGUMENTS); verify(arguments.size() == 2); @@ -63,8 +66,14 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(argument1.getName()); verify(aggregateFunction.getOutputType().equals(columnHandle1.getColumnType())); + ParameterizedExpression rewrittenArgument1 = context.rewriteExpression(argument1).orElseThrow(); + ParameterizedExpression rewrittenArgument2 = context.rewriteExpression(argument2).orElseThrow(); return Optional.of(new JdbcExpression( - format("regr_intercept(%s, %s)", context.rewriteExpression(argument1).orElseThrow(), context.rewriteExpression(argument2).orElseThrow()), + format("regr_intercept(%s, %s)", rewrittenArgument1.expression(), rewrittenArgument2.expression()), + ImmutableList.builder() + .addAll(rewrittenArgument1.parameters()) + .addAll(rewrittenArgument2.parameters()) + .build(), columnHandle1.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementRegrSlope.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementRegrSlope.java index f294042e31a5..cc9b2486dc3b 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementRegrSlope.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementRegrSlope.java @@ -13,12 +13,15 @@ */ package io.trino.plugin.jdbc.aggregation; +import com.google.common.collect.ImmutableList; import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; @@ -37,7 +40,7 @@ import static java.lang.String.format; public class ImplementRegrSlope - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture> ARGUMENTS = newCapture(); @@ -53,7 +56,7 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { List arguments = captures.get(ARGUMENTS); verify(arguments.size() == 2); @@ -63,8 +66,14 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(argument1.getName()); verify(aggregateFunction.getOutputType().equals(columnHandle1.getColumnType())); + ParameterizedExpression rewrittenArgument1 = context.rewriteExpression(argument1).orElseThrow(); + ParameterizedExpression rewrittenArgument2 = context.rewriteExpression(argument2).orElseThrow(); return Optional.of(new JdbcExpression( - format("regr_slope(%s, %s)", context.rewriteExpression(argument1).orElseThrow(), context.rewriteExpression(argument2).orElseThrow()), + format("regr_slope(%s, %s)", rewrittenArgument1.expression(), rewrittenArgument2.expression()), + ImmutableList.builder() + .addAll(rewrittenArgument1.parameters()) + .addAll(rewrittenArgument2.parameters()) + .build(), columnHandle1.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementStddevPop.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementStddevPop.java index c543ccf5a358..2ddd1bf36b72 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementStddevPop.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementStddevPop.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.DoubleType; @@ -35,7 +36,7 @@ import static java.lang.String.format; public class ImplementStddevPop - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -51,14 +52,16 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); verify(aggregateFunction.getOutputType() == columnHandle.getColumnType()); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("stddev_pop(%s)", context.rewriteExpression(argument).orElseThrow()), + format("stddev_pop(%s)", rewrittenArgument.expression()), + rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementStddevSamp.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementStddevSamp.java index 6696d091a27b..3d02d7c9e694 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementStddevSamp.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementStddevSamp.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.DoubleType; @@ -35,7 +36,7 @@ import static java.lang.String.format; public class ImplementStddevSamp - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -51,14 +52,16 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); verify(aggregateFunction.getOutputType() == columnHandle.getColumnType()); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("stddev_samp(%s)", context.rewriteExpression(argument).orElseThrow()), + format("stddev_samp(%s)", rewrittenArgument.expression()), + rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementSum.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementSum.java index 07f577823a48..10a2b2459fad 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementSum.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementSum.java @@ -20,6 +20,7 @@ import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.DecimalType; @@ -39,7 +40,7 @@ * Implements {@code sum(x)} */ public class ImplementSum - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -59,7 +60,7 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); @@ -79,8 +80,10 @@ else if (aggregateFunction.getOutputType() instanceof DecimalType) { return Optional.empty(); } + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("sum(%s)", context.rewriteExpression(argument).orElseThrow()), + format("sum(%s)", rewrittenArgument.expression()), + rewrittenArgument.parameters(), resultTypeHandle)); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementVariancePop.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementVariancePop.java index b996fd1a3a36..f3c78782de14 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementVariancePop.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementVariancePop.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.DoubleType; @@ -35,7 +36,7 @@ import static java.lang.String.format; public class ImplementVariancePop - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -51,14 +52,16 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); verify(aggregateFunction.getOutputType() == columnHandle.getColumnType()); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("var_pop(%s)", context.rewriteExpression(argument).orElseThrow()), + format("var_pop(%s)", rewrittenArgument.expression()), + rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementVarianceSamp.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementVarianceSamp.java index d1d3a40cb562..943a3d732710 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementVarianceSamp.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementVarianceSamp.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.DoubleType; @@ -35,7 +36,7 @@ import static java.lang.String.format; public class ImplementVarianceSamp - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -51,14 +52,16 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); verify(aggregateFunction.getOutputType() == columnHandle.getColumnType()); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("var_samp(%s)", context.rewriteExpression(argument).orElseThrow()), + format("var_samp(%s)", rewrittenArgument.expression()), + rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/GenericRewrite.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/GenericRewrite.java index 9876b3d713a7..f562ff2f1491 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/GenericRewrite.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/GenericRewrite.java @@ -13,8 +13,10 @@ */ package io.trino.plugin.jdbc.expression; +import com.google.common.collect.ImmutableList; import io.trino.matching.Captures; import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.QueryParameter; import io.trino.spi.expression.ConnectorExpression; import java.util.Map; @@ -28,7 +30,7 @@ import static java.util.regex.Matcher.quoteReplacement; public class GenericRewrite - implements ConnectorExpressionRule + implements ConnectorExpressionRule { // Matches words in the `rewritePattern` private static final Pattern REWRITE_TOKENS = Pattern.compile("(? getPattern() } @Override - public Optional rewrite(ConnectorExpression expression, Captures captures, RewriteContext context) + public Optional rewrite(ConnectorExpression expression, Captures captures, RewriteContext context) { MatchContext matchContext = new MatchContext(); expressionPattern.resolve(captures, matchContext); - StringBuilder rewritten = new StringBuilder(); + StringBuilder result = new StringBuilder(); + ImmutableList.Builder parameters = ImmutableList.builder(); Matcher matcher = REWRITE_TOKENS.matcher(rewritePattern); while (matcher.find()) { String identifier = matcher.group(0); @@ -69,11 +72,12 @@ public Optional rewrite(ConnectorExpression expression, Captures capture replacement = Long.toString((Long) value); } else if (value instanceof ConnectorExpression) { - Optional rewrittenExpression = context.defaultRewrite((ConnectorExpression) value); - if (rewrittenExpression.isEmpty()) { + Optional rewritten = context.defaultRewrite((ConnectorExpression) value); + if (rewritten.isEmpty()) { return Optional.empty(); } - replacement = format("(%s)", rewrittenExpression.get()); + replacement = format("(%s)", rewritten.get().expression()); + parameters.addAll(rewritten.get().parameters()); } else { throw new UnsupportedOperationException(format("Unsupported value: %s (%s)", value, value.getClass())); @@ -82,11 +86,11 @@ else if (value instanceof ConnectorExpression) { else { replacement = identifier; } - matcher.appendReplacement(rewritten, quoteReplacement(replacement)); + matcher.appendReplacement(result, quoteReplacement(replacement)); } - matcher.appendTail(rewritten); + matcher.appendTail(result); - return Optional.of(rewritten.toString()); + return Optional.of(new ParameterizedExpression(result.toString(), parameters.build())); } @Override diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/JdbcConnectorExpressionRewriterBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/JdbcConnectorExpressionRewriterBuilder.java index 3e41269befa1..8256ff3fa077 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/JdbcConnectorExpressionRewriterBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/JdbcConnectorExpressionRewriterBuilder.java @@ -33,7 +33,7 @@ public static JdbcConnectorExpressionRewriterBuilder newBuilder() return new JdbcConnectorExpressionRewriterBuilder(); } - private final ImmutableSet.Builder> rules = ImmutableSet.builder(); + private final ImmutableSet.Builder> rules = ImmutableSet.builder(); private final Map> typeClasses = new HashMap<>(); private JdbcConnectorExpressionRewriterBuilder() {} @@ -49,7 +49,7 @@ public JdbcConnectorExpressionRewriterBuilder addStandardRules(Function rule) + public JdbcConnectorExpressionRewriterBuilder add(ConnectorExpressionRule rule) { rules.add(rule); return this; @@ -77,7 +77,7 @@ public JdbcConnectorExpressionRewriterBuilder to(String rewritePattern) }; } - public ConnectorExpressionRewriter build() + public ConnectorExpressionRewriter build() { return new ConnectorExpressionRewriter<>(rules.build()); } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ParameterizedExpression.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ParameterizedExpression.java new file mode 100644 index 000000000000..eee8b8132c9c --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ParameterizedExpression.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import com.google.common.collect.ImmutableList; +import io.trino.plugin.jdbc.QueryParameter; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public record ParameterizedExpression(String expression, List parameters) +{ + public ParameterizedExpression + { + requireNonNull(expression, "expression is null"); + parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteComparison.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteComparison.java index e49d15995935..10db5e84ab6c 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteComparison.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteComparison.java @@ -13,10 +13,12 @@ */ package io.trino.plugin.jdbc.expression; +import com.google.common.collect.ImmutableList; import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.QueryParameter; import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.FunctionName; @@ -50,7 +52,7 @@ import static java.util.function.Function.identity; public class RewriteComparison - implements ConnectorExpressionRule + implements ConnectorExpressionRule { private static final Capture LEFT = newCapture(); private static final Capture RIGHT = newCapture(); @@ -117,18 +119,23 @@ public Pattern getPattern() } @Override - public Optional rewrite(Call call, Captures captures, RewriteContext context) + public Optional rewrite(Call call, Captures captures, RewriteContext context) { - Optional left = context.defaultRewrite(captures.get(LEFT)); + Optional left = context.defaultRewrite(captures.get(LEFT)); if (left.isEmpty()) { return Optional.empty(); } - Optional right = context.defaultRewrite(captures.get(RIGHT)); + Optional right = context.defaultRewrite(captures.get(RIGHT)); if (right.isEmpty()) { return Optional.empty(); } verify(call.getFunctionName().getCatalogSchema().isEmpty()); // filtered out by the pattern ComparisonOperator operator = ComparisonOperator.forFunctionName(call.getFunctionName()); - return Optional.of(format("(%s) %s (%s)", left.get(), operator.getOperator(), right.get())); + return Optional.of(new ParameterizedExpression( + format("(%s) %s (%s)", left.get().expression(), operator.getOperator(), right.get().expression()), + ImmutableList.builder() + .addAll(left.get().parameters()) + .addAll(right.get().parameters()) + .build())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteExactNumericConstant.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteExactNumericConstant.java index ef531ee4cc8a..83bccc20e7c8 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteExactNumericConstant.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteExactNumericConstant.java @@ -13,13 +13,13 @@ */ package io.trino.plugin.jdbc.expression; +import com.google.common.collect.ImmutableList; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.QueryParameter; import io.trino.spi.expression.Constant; import io.trino.spi.type.DecimalType; -import io.trino.spi.type.Decimals; -import io.trino.spi.type.Int128; import io.trino.spi.type.Type; import java.util.Optional; @@ -32,7 +32,7 @@ import static io.trino.spi.type.TinyintType.TINYINT; public class RewriteExactNumericConstant - implements ConnectorExpressionRule + implements ConnectorExpressionRule { private static final Pattern PATTERN = constant().with(type().matching(type -> type == TINYINT || type == SMALLINT || type == INTEGER || type == BIGINT || type instanceof DecimalType)); @@ -44,21 +44,16 @@ public Pattern getPattern() } @Override - public Optional rewrite(Constant constant, Captures captures, RewriteContext context) + public Optional rewrite(Constant constant, Captures captures, RewriteContext context) { Type type = constant.getType(); - if (constant.getValue() == null) { + Object value = constant.getValue(); + if (value == null) { + // TODO we could handle NULL values too return Optional.empty(); } - if (type == TINYINT || type == SMALLINT || type == INTEGER || type == BIGINT) { - return Optional.of(Long.toString((long) constant.getValue())); - } - - if (type instanceof DecimalType decimalType) { - if (decimalType.isShort()) { - return Optional.of(Decimals.toString((long) constant.getValue(), decimalType.getScale())); - } - return Optional.of(Decimals.toString((Int128) constant.getValue(), decimalType.getScale())); + if (type == TINYINT || type == SMALLINT || type == INTEGER || type == BIGINT || type instanceof DecimalType) { + return Optional.of(new ParameterizedExpression("?", ImmutableList.of(new QueryParameter(type, Optional.of(value))))); } throw new UnsupportedOperationException("Unsupported type: " + type); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java index 0b81f689b502..e686e7685fd3 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java @@ -19,6 +19,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.QueryParameter; import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; @@ -41,7 +42,7 @@ import static java.lang.String.format; public class RewriteIn - implements ConnectorExpressionRule + implements ConnectorExpressionRule { private static final Capture VALUE = newCapture(); private static final Capture> EXPRESSIONS = newCapture(); @@ -60,9 +61,9 @@ public Pattern getPattern() } @Override - public Optional rewrite(Call call, Captures captures, RewriteContext context) + public Optional rewrite(Call call, Captures captures, RewriteContext context) { - Optional value = context.defaultRewrite(captures.get(VALUE)); + Optional value = context.defaultRewrite(captures.get(VALUE)); if (value.isEmpty()) { return Optional.empty(); } @@ -73,17 +74,22 @@ public Optional rewrite(Call call, Captures captures, RewriteContext parameters = ImmutableList.builder(); + parameters.addAll(value.get().parameters()); ImmutableList.Builder rewrittenValues = ImmutableList.builderWithExpectedSize(expressions.size()); for (ConnectorExpression expression : expressions) { - Optional rewrittenExpression = context.defaultRewrite(expression); - if (rewrittenExpression.isEmpty()) { + Optional rewritten = context.defaultRewrite(expression); + if (rewritten.isEmpty()) { return Optional.empty(); } - rewrittenValues.add(rewrittenExpression.get()); + rewrittenValues.add(rewritten.get().expression()); + parameters.addAll(rewritten.get().parameters()); } List values = rewrittenValues.build(); verify(!values.isEmpty(), "Empty values"); - return Optional.of(format("(%s) IN (%s)", value.get(), Joiner.on(", ").join(values))); + return Optional.of(new ParameterizedExpression( + format("(%s) IN (%s)", value.get().expression(), Joiner.on(", ").join(values)), + parameters.build())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteLogicalExpression.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteLogicalExpression.java index 2d1ea801ea03..7c4507acfda4 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteLogicalExpression.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteLogicalExpression.java @@ -13,9 +13,11 @@ */ package io.trino.plugin.jdbc.expression; +import com.google.common.collect.ImmutableList; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.QueryParameter; import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.FunctionName; @@ -33,7 +35,7 @@ import static java.util.Objects.requireNonNull; class RewriteLogicalExpression - implements ConnectorExpressionRule + implements ConnectorExpressionRule { private final Pattern pattern; private final String operator; @@ -53,21 +55,25 @@ public Pattern getPattern() } @Override - public Optional rewrite(Call call, Captures captures, RewriteContext context) + public Optional rewrite(Call call, Captures captures, RewriteContext context) { List arguments = call.getArguments(); verify(!arguments.isEmpty(), "no arguments"); List terms = new ArrayList<>(arguments.size()); + ImmutableList.Builder parameters = ImmutableList.builder(); for (ConnectorExpression argument : arguments) { verify(argument.getType() == BOOLEAN, "Unexpected type of argument: %s", argument.getType()); - Optional rewritten = context.defaultRewrite(argument); + Optional rewritten = context.defaultRewrite(argument); if (rewritten.isEmpty()) { return Optional.empty(); } - terms.add(rewritten.get()); + terms.add(rewritten.get().expression()); + parameters.addAll(rewritten.get().parameters()); } - return Optional.of(terms.stream() - .collect(Collectors.joining(") " + operator + " (", "(", ")"))); + return Optional.of(new ParameterizedExpression( + terms.stream() + .collect(Collectors.joining(") " + operator + " (", "(", ")")), + parameters.build())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java index 94ada1e9ecd4..7209e787da63 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java @@ -13,10 +13,11 @@ */ package io.trino.plugin.jdbc.expression; -import io.airlift.slice.Slice; +import com.google.common.collect.ImmutableList; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.QueryParameter; import io.trino.spi.expression.Constant; import io.trino.spi.type.VarcharType; @@ -26,7 +27,7 @@ import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; public class RewriteVarcharConstant - implements ConnectorExpressionRule + implements ConnectorExpressionRule { private static final Pattern PATTERN = constant().with(type().matching(VarcharType.class::isInstance)); @@ -37,12 +38,13 @@ public Pattern getPattern() } @Override - public Optional rewrite(Constant constant, Captures captures, RewriteContext context) + public Optional rewrite(Constant constant, Captures captures, RewriteContext context) { - Slice slice = (Slice) constant.getValue(); - if (slice == null) { + Object value = constant.getValue(); + if (value == null) { + // TODO we could handle NULL values too return Optional.empty(); } - return Optional.of("'" + slice.toStringUtf8().replace("'", "''") + "'"); + return Optional.of(new ParameterizedExpression("?", ImmutableList.of(new QueryParameter(constant.getType(), Optional.of(value))))); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVariable.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVariable.java index a4c9697043a9..db01a1c6d649 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVariable.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVariable.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.jdbc.expression; +import com.google.common.collect.ImmutableList; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.expression.ConnectorExpressionRule; @@ -26,7 +27,7 @@ import static java.util.Objects.requireNonNull; public class RewriteVariable - implements ConnectorExpressionRule + implements ConnectorExpressionRule { private final Function identifierQuote; @@ -42,9 +43,9 @@ public Pattern getPattern() } @Override - public Optional rewrite(Variable variable, Captures captures, RewriteContext context) + public Optional rewrite(Variable variable, Captures captures, RewriteContext context) { JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(variable.getName()); - return Optional.of(identifierQuote.apply(columnHandle.getColumnName())); + return Optional.of(new ParameterizedExpression(identifierQuote.apply(columnHandle.getColumnName()), ImmutableList.of())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java index 0e26d23bd741..45d26de44ad8 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java @@ -27,6 +27,7 @@ import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.WriteFunction; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; @@ -147,7 +148,7 @@ public Optional implementAggregation(ConnectorSession session, A } @Override - public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) + public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) { return stats.getConvertPredicate().wrap(() -> delegate().convertPredicate(session, expression, assignments)); } @@ -178,7 +179,7 @@ public PreparedQuery prepareQuery( JdbcTableHandle table, Optional>> groupingSets, List columns, - Map columnExpressions) + Map columnExpressions) { return stats.getPrepareQuery().wrap(() -> delegate().prepareQuery(session, table, groupingSets, columns, columnExpressions)); } diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java index 2302c390ec9c..f0d955232337 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableMultiset; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multiset; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; @@ -599,7 +600,7 @@ public void testAggregation() TEST_TABLE, Optional.of(ImmutableList.of(ImmutableList.of(this.columns.get(2)))), projectedColumns, - Map.of("s", "sum(\"col_0\")"), + Map.of("s", new ParameterizedExpression("sum(\"col_0\")", List.of())), TupleDomain.all(), Optional.empty()); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(jdbcClient, SESSION, connection, preparedQuery)) { @@ -642,7 +643,7 @@ public void testAggregationWithFilter() TEST_TABLE, Optional.of(ImmutableList.of(ImmutableList.of(this.columns.get(2)))), projectedColumns, - Map.of("s", "sum(\"col_0\")"), + Map.of("s", new ParameterizedExpression("sum(\"col_0\")", List.of())), tupleDomain, Optional.empty()); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(jdbcClient, SESSION, connection, preparedQuery)) { diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestGenericRewrite.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestGenericRewrite.java index 805b0c7c80be..45357eb61735 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestGenericRewrite.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestGenericRewrite.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.jdbc.expression; +import com.google.common.collect.ImmutableList; import io.trino.matching.Match; import io.trino.plugin.base.expression.ConnectorExpressionRule.RewriteContext; import io.trino.spi.connector.ColumnHandle; @@ -48,8 +49,9 @@ public void testRewriteCall() new Variable("first", createDecimalType(10, 2)), new Variable("second", BIGINT))); - Optional rewritten = apply(rewrite, expression); - assertThat(rewritten).hasValue("(\"first\") + (\"second\")::decimal(21,2)"); + ParameterizedExpression rewritten = apply(rewrite, expression).orElseThrow(); + assertThat(rewritten.expression()).isEqualTo("(\"first\") + (\"second\")::decimal(21,2)"); + assertThat(rewritten.parameters()).isEqualTo(List.of()); } @Test @@ -63,8 +65,9 @@ public void testRewriteCallWithTypeClass() new FunctionName("add"), List.of( new Variable("first", INTEGER), - new Variable("second", BIGINT))))) - .hasValue("(\"first\") + (\"second\")"); + new Variable("second", BIGINT)))) + .orElseThrow().expression()) + .isEqualTo("(\"first\") + (\"second\")"); // argument type not in class assertThat(apply(rewrite, new Call( @@ -85,7 +88,7 @@ public void testRewriteCallWithTypeClass() .isEmpty(); } - private static Optional apply(GenericRewrite rewrite, ConnectorExpression expression) + private static Optional apply(GenericRewrite rewrite, ConnectorExpression expression) { Optional match = rewrite.getPattern().match(expression).collect(toOptional()); if (match.isEmpty()) { @@ -106,10 +109,10 @@ public ConnectorSession getSession() } @Override - public Optional defaultRewrite(ConnectorExpression expression1) + public Optional defaultRewrite(ConnectorExpression expression) { - if (expression1 instanceof Variable) { - return Optional.of("\"" + ((Variable) expression1).getName().replace("\"", "\"\"") + "\""); + if (expression instanceof Variable) { + return Optional.of(new ParameterizedExpression("\"" + ((Variable) expression).getName().replace("\"", "\"\"") + "\"", ImmutableList.of())); } return Optional.empty(); } diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java index c0cdbb263418..f48286452ca4 100644 --- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java +++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java @@ -48,6 +48,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementMinMax; import io.trino.plugin.jdbc.aggregation.ImplementSum; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; @@ -194,8 +195,8 @@ public class ClickHouseClient public static final int DEFAULT_DOMAIN_COMPACTION_THRESHOLD = 1_000; - private final ConnectorExpressionRewriter connectorExpressionRewriter; - private final AggregateFunctionRewriter aggregateFunctionRewriter; + private final ConnectorExpressionRewriter connectorExpressionRewriter; + private final AggregateFunctionRewriter aggregateFunctionRewriter; private final Type uuidType; private final Type ipAddressType; private final AtomicReference clickHouseVersion = new AtomicReference<>(); @@ -218,7 +219,7 @@ public ClickHouseClient( .build(); this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( this.connectorExpressionRewriter, - ImmutableSet.>builder() + ImmutableSet.>builder() .add(new ImplementCountAll(bigintTypeHandle)) .add(new ImplementCount(bigintTypeHandle)) .add(new ImplementMinMax(false)) // TODO: Revisit once https://github.com/trinodb/trino/issues/7100 is resolved diff --git a/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java b/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java index 7ff2fd2a5442..5aec879b78ad 100644 --- a/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java +++ b/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java @@ -30,6 +30,7 @@ import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.WriteFunction; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; @@ -371,7 +372,14 @@ private static boolean hasSecondPrecision(long epochMicros) } @Override - protected PreparedQuery prepareQuery(ConnectorSession session, Connection connection, JdbcTableHandle table, Optional>> groupingSets, List columns, Map columnExpressions, Optional split) + protected PreparedQuery prepareQuery( + ConnectorSession session, + Connection connection, + JdbcTableHandle table, + Optional>> groupingSets, + List columns, + Map columnExpressions, + Optional split) { return super.prepareQuery(session, connection, prepareTableHandleForQuery(table), groupingSets, columns, columnExpressions, split); } diff --git a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java index d61ea00a928b..92e12024793f 100644 --- a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java +++ b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java @@ -44,6 +44,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementMinMax; import io.trino.plugin.jdbc.aggregation.ImplementSum; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; @@ -148,8 +149,8 @@ public class IgniteClient private static final LocalDate MIN_DATE = LocalDate.parse("1970-01-01"); private static final LocalDate MAX_DATE = LocalDate.parse("9999-12-31"); - private final ConnectorExpressionRewriter connectorExpressionRewriter; - private final AggregateFunctionRewriter aggregateFunctionRewriter; + private final ConnectorExpressionRewriter connectorExpressionRewriter; + private final AggregateFunctionRewriter aggregateFunctionRewriter; @Inject public IgniteClient( @@ -169,7 +170,7 @@ public IgniteClient( .build(); this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( connectorExpressionRewriter, - ImmutableSet.>builder() + ImmutableSet.>builder() .add(new ImplementCountAll(bigintTypeHandle)) .add(new ImplementCount(bigintTypeHandle)) .add(new ImplementMinMax(true)) @@ -315,7 +316,7 @@ public Optional implementAggregation(ConnectorSession session, A } @Override - public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) + public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) { return connectorExpressionRewriter.rewrite(session, expression, assignments); } diff --git a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/ImplementAvgDecimal.java b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/ImplementAvgDecimal.java index 8483e4b12aeb..77dc199ccd1f 100644 --- a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/ImplementAvgDecimal.java +++ b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/ImplementAvgDecimal.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.DecimalType; @@ -38,7 +39,7 @@ * Implements {@code avg(decimal(p, s)} */ public class ImplementAvgDecimal - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -54,7 +55,7 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); @@ -63,7 +64,10 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap verify(aggregateFunction.getOutputType().equals(type)); // wait https://issues.apache.org/jira/browse/IGNITE-14948 to be solved. + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("CAST(sum(%s) / count(%1$s) AS decimal(%s, %s))", context.rewriteExpression(argument).orElseThrow(), type.getPrecision() + 1, type.getScale()), columnHandle.getJdbcTypeHandle())); + format("CAST(sum(%s) / count(%1$s) AS decimal(%s, %s))", rewrittenArgument.expression(), type.getPrecision() + 1, type.getScale()), + rewrittenArgument.parameters(), + columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java index 264eb7d1349b..cfb9de80968c 100644 --- a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java +++ b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java @@ -43,6 +43,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; @@ -155,7 +156,7 @@ public class MariaDbClient // MariaDB Error Codes https://mariadb.com/kb/en/mariadb-error-codes/ private static final int PARSE_ERROR = 1064; - private final AggregateFunctionRewriter aggregateFunctionRewriter; + private final AggregateFunctionRewriter aggregateFunctionRewriter; @Inject public MariaDbClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier) @@ -163,12 +164,12 @@ public MariaDbClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, super(config, "`", connectionFactory, queryBuilder, identifierMapping, queryModifier); JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); - ConnectorExpressionRewriter connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() + ConnectorExpressionRewriter connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) .build(); this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( connectorExpressionRewriter, - ImmutableSet.>builder() + ImmutableSet.>builder() .add(new ImplementCountAll(bigintTypeHandle)) .add(new ImplementCount(bigintTypeHandle)) .add(new ImplementMinMax(false)) diff --git a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java index 7423209b93b6..0ceece5bcd90 100644 --- a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java +++ b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java @@ -51,6 +51,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; @@ -194,8 +195,8 @@ public class MySqlClient private final Type jsonType; private final boolean statisticsEnabled; - private final ConnectorExpressionRewriter connectorExpressionRewriter; - private final AggregateFunctionRewriter aggregateFunctionRewriter; + private final ConnectorExpressionRewriter connectorExpressionRewriter; + private final AggregateFunctionRewriter aggregateFunctionRewriter; @Inject public MySqlClient( @@ -218,7 +219,7 @@ public MySqlClient( JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( this.connectorExpressionRewriter, - ImmutableSet.>builder() + ImmutableSet.>builder() .add(new ImplementCountAll(bigintTypeHandle)) .add(new ImplementCount(bigintTypeHandle)) .add(new ImplementMinMax(false)) diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java index aa155f66da10..a03948801eeb 100644 --- a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java @@ -49,6 +49,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; @@ -191,8 +192,8 @@ public class OracleClient .buildOrThrow(); private final boolean synonymsEnabled; - private final ConnectorExpressionRewriter connectorExpressionRewriter; - private final AggregateFunctionRewriter aggregateFunctionRewriter; + private final ConnectorExpressionRewriter connectorExpressionRewriter; + private final AggregateFunctionRewriter aggregateFunctionRewriter; @Inject public OracleClient( @@ -214,7 +215,7 @@ public OracleClient( JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(TRINO_BIGINT_TYPE, Optional.of("NUMBER"), Optional.of(0), Optional.of(0), Optional.empty(), Optional.empty()); this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( connectorExpressionRewriter, - ImmutableSet.>builder() + ImmutableSet.>builder() .add(new ImplementCountAll(bigintTypeHandle)) .add(new ImplementCount(bigintTypeHandle)) .add(new ImplementCountDistinct(bigintTypeHandle, true)) diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index e6c1e2c2e93d..795c74f4169d 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -65,6 +65,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.expression.RewriteComparison; import io.trino.plugin.jdbc.expression.RewriteIn; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; @@ -272,8 +273,8 @@ public class PostgreSqlClient private final MapType varcharMapType; private final List tableTypes; private final boolean statisticsEnabled; - private final ConnectorExpressionRewriter connectorExpressionRewriter; - private final AggregateFunctionRewriter aggregateFunctionRewriter; + private final ConnectorExpressionRewriter connectorExpressionRewriter; + private final AggregateFunctionRewriter aggregateFunctionRewriter; @Inject public PostgreSqlClient( @@ -323,7 +324,7 @@ public PostgreSqlClient( JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( this.connectorExpressionRewriter, - ImmutableSet.>builder() + ImmutableSet.>builder() .add(new ImplementCountAll(bigintTypeHandle)) .add(new ImplementMinMax(false)) .add(new ImplementCount(bigintTypeHandle)) @@ -749,7 +750,7 @@ public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHa } @Override - public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) + public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) { return connectorExpressionRewriter.rewrite(session, expression, assignments); } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index d2f5b1efafaa..9fb4eaff9688 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -24,6 +24,8 @@ import io.trino.plugin.jdbc.JdbcMetadataSessionProperties; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping; import io.trino.spi.connector.AggregateFunction; @@ -62,6 +64,7 @@ import java.util.stream.Stream; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.airlift.slice.Slices.utf8Slice; import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -218,66 +221,75 @@ private static void testImplementAggregation(AggregateFunction aggregateFunction @Test public void testConvertOr() { - assertThat(JDBC_CLIENT.convertPredicate( - SESSION, - translateToConnectorExpression( - new LogicalExpression( - LogicalExpression.Operator.OR, - List.of( - new ComparisonExpression( - ComparisonExpression.Operator.EQUAL, - new SymbolReference("c_bigint_symbol"), - LITERAL_ENCODER.toExpression(TEST_SESSION, 42L, BIGINT)), - new ComparisonExpression( - ComparisonExpression.Operator.EQUAL, - new SymbolReference("c_bigint_symbol_2"), - LITERAL_ENCODER.toExpression(TEST_SESSION, 415L, BIGINT)))), + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( + SESSION, + translateToConnectorExpression( + new LogicalExpression( + LogicalExpression.Operator.OR, + List.of( + new ComparisonExpression( + ComparisonExpression.Operator.EQUAL, + new SymbolReference("c_bigint_symbol"), + LITERAL_ENCODER.toExpression(TEST_SESSION, 42L, BIGINT)), + new ComparisonExpression( + ComparisonExpression.Operator.EQUAL, + new SymbolReference("c_bigint_symbol_2"), + LITERAL_ENCODER.toExpression(TEST_SESSION, 415L, BIGINT)))), + Map.of( + "c_bigint_symbol", BIGINT, + "c_bigint_symbol_2", BIGINT)), Map.of( - "c_bigint_symbol", BIGINT, - "c_bigint_symbol_2", BIGINT)), - Map.of( - "c_bigint_symbol", BIGINT_COLUMN, - "c_bigint_symbol_2", BIGINT_COLUMN))) - .hasValue("((\"c_bigint\") = (42)) OR ((\"c_bigint\") = (415))"); + "c_bigint_symbol", BIGINT_COLUMN, + "c_bigint_symbol_2", BIGINT_COLUMN)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("((\"c_bigint\") = (?)) OR ((\"c_bigint\") = (?))"); + assertThat(converted.parameters()).isEqualTo(List.of( + new QueryParameter(BIGINT, Optional.of(42L)), + new QueryParameter(BIGINT, Optional.of(415L)))); } @Test public void testConvertOrWithAnd() { - assertThat(JDBC_CLIENT.convertPredicate( - SESSION, - translateToConnectorExpression( - new LogicalExpression( - LogicalExpression.Operator.OR, - List.of( - new ComparisonExpression( - ComparisonExpression.Operator.EQUAL, - new SymbolReference("c_bigint_symbol"), - LITERAL_ENCODER.toExpression(TEST_SESSION, 42L, BIGINT)), - new LogicalExpression( - LogicalExpression.Operator.AND, - List.of( - new ComparisonExpression( - ComparisonExpression.Operator.EQUAL, - new SymbolReference("c_bigint_symbol"), - LITERAL_ENCODER.toExpression(TEST_SESSION, 43L, BIGINT)), - new ComparisonExpression( - ComparisonExpression.Operator.EQUAL, - new SymbolReference("c_bigint_symbol_2"), - LITERAL_ENCODER.toExpression(TEST_SESSION, 44L, BIGINT)))))), + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( + SESSION, + translateToConnectorExpression( + new LogicalExpression( + LogicalExpression.Operator.OR, + List.of( + new ComparisonExpression( + ComparisonExpression.Operator.EQUAL, + new SymbolReference("c_bigint_symbol"), + LITERAL_ENCODER.toExpression(TEST_SESSION, 42L, BIGINT)), + new LogicalExpression( + LogicalExpression.Operator.AND, + List.of( + new ComparisonExpression( + ComparisonExpression.Operator.EQUAL, + new SymbolReference("c_bigint_symbol"), + LITERAL_ENCODER.toExpression(TEST_SESSION, 43L, BIGINT)), + new ComparisonExpression( + ComparisonExpression.Operator.EQUAL, + new SymbolReference("c_bigint_symbol_2"), + LITERAL_ENCODER.toExpression(TEST_SESSION, 44L, BIGINT)))))), + Map.of( + "c_bigint_symbol", BIGINT, + "c_bigint_symbol_2", BIGINT)), Map.of( - "c_bigint_symbol", BIGINT, - "c_bigint_symbol_2", BIGINT)), - Map.of( - "c_bigint_symbol", BIGINT_COLUMN, - "c_bigint_symbol_2", BIGINT_COLUMN))) - .hasValue("((\"c_bigint\") = (42)) OR (((\"c_bigint\") = (43)) AND ((\"c_bigint\") = (44)))"); + "c_bigint_symbol", BIGINT_COLUMN, + "c_bigint_symbol_2", BIGINT_COLUMN)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("((\"c_bigint\") = (?)) OR (((\"c_bigint\") = (?)) AND ((\"c_bigint\") = (?)))"); + assertThat(converted.parameters()).isEqualTo(List.of( + new QueryParameter(BIGINT, Optional.of(42L)), + new QueryParameter(BIGINT, Optional.of(43L)), + new QueryParameter(BIGINT, Optional.of(44L)))); } @Test(dataProvider = "testConvertComparisonDataProvider") public void testConvertComparison(ComparisonExpression.Operator operator) { - Optional converted = JDBC_CLIENT.convertPredicate( + Optional converted = JDBC_CLIENT.convertPredicate( SESSION, translateToConnectorExpression( new ComparisonExpression( @@ -290,7 +302,9 @@ public void testConvertComparison(ComparisonExpression.Operator operator) switch (operator) { case EQUAL: case NOT_EQUAL: - assertThat(converted).hasValue(format("(\"c_bigint\") %s (42)", operator.getValue())); + assertThat(converted).isPresent(); + assertThat(converted.get().expression()).isEqualTo(format("(\"c_bigint\") %s (?)", operator.getValue())); + assertThat(converted.get().parameters()).isEqualTo(List.of(new QueryParameter(BIGINT, Optional.of(42L)))); return; case LESS_THAN: case LESS_THAN_OR_EQUAL: @@ -314,17 +328,19 @@ public static Object[][] testConvertComparisonDataProvider() @Test(dataProvider = "testConvertArithmeticBinaryDataProvider") public void testConvertArithmeticBinary(ArithmeticBinaryExpression.Operator operator) { - Optional converted = JDBC_CLIENT.convertPredicate( - SESSION, - translateToConnectorExpression( - new ArithmeticBinaryExpression( - operator, - new SymbolReference("c_bigint_symbol"), - LITERAL_ENCODER.toExpression(TEST_SESSION, 42L, BIGINT)), - Map.of("c_bigint_symbol", BIGINT)), - Map.of("c_bigint_symbol", BIGINT_COLUMN)); + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( + SESSION, + translateToConnectorExpression( + new ArithmeticBinaryExpression( + operator, + new SymbolReference("c_bigint_symbol"), + LITERAL_ENCODER.toExpression(TEST_SESSION, 42L, BIGINT)), + Map.of("c_bigint_symbol", BIGINT)), + Map.of("c_bigint_symbol", BIGINT_COLUMN)) + .orElseThrow(); - assertThat(converted).hasValue(format("(\"c_bigint\") %s (42)", operator.getValue())); + assertThat(converted.expression()).isEqualTo(format("(\"c_bigint\") %s (?)", operator.getValue())); + assertThat(converted.parameters()).isEqualTo(List.of(new QueryParameter(BIGINT, Optional.of(42L)))); } @DataProvider @@ -337,110 +353,131 @@ public static Object[][] testConvertArithmeticBinaryDataProvider() @Test public void testConvertArithmeticUnaryMinus() { - Optional converted = JDBC_CLIENT.convertPredicate( - SESSION, - translateToConnectorExpression( - new ArithmeticUnaryExpression( - ArithmeticUnaryExpression.Sign.MINUS, - new SymbolReference("c_bigint_symbol")), - Map.of("c_bigint_symbol", BIGINT)), - Map.of("c_bigint_symbol", BIGINT_COLUMN)); + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( + SESSION, + translateToConnectorExpression( + new ArithmeticUnaryExpression( + ArithmeticUnaryExpression.Sign.MINUS, + new SymbolReference("c_bigint_symbol")), + Map.of("c_bigint_symbol", BIGINT)), + Map.of("c_bigint_symbol", BIGINT_COLUMN)) + .orElseThrow(); - assertThat(converted).hasValue("-(\"c_bigint\")"); + assertThat(converted.expression()).isEqualTo("-(\"c_bigint\")"); + assertThat(converted.parameters()).isEqualTo(List.of()); } @Test public void testConvertLike() { // c_varchar LIKE '%pattern%' - assertThat(JDBC_CLIENT.convertPredicate(SESSION, - translateToConnectorExpression( - new LikePredicate( - new SymbolReference("c_varchar_symbol"), - new StringLiteral("%pattern%"), - Optional.empty()), - Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), - Map.of("c_varchar_symbol", VARCHAR_COLUMN))) - .hasValue("(\"c_varchar\") LIKE ('%pattern%')"); + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new LikePredicate( + new SymbolReference("c_varchar_symbol"), + new StringLiteral("%pattern%"), + Optional.empty()), + Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), + Map.of("c_varchar_symbol", VARCHAR_COLUMN)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("(\"c_varchar\") LIKE (?)"); + assertThat(converted.parameters()).isEqualTo(List.of( + new QueryParameter(createVarcharType(9), Optional.of(utf8Slice("%pattern%"))))); // c_varchar LIKE '%pattern\%' ESCAPE '\' - assertThat(JDBC_CLIENT.convertPredicate(SESSION, - translateToConnectorExpression( - new LikePredicate( - new SymbolReference("c_varchar"), - new StringLiteral("%pattern\\%"), - new StringLiteral("\\")), - Map.of("c_varchar", VARCHAR_COLUMN.getColumnType())), - Map.of(VARCHAR_COLUMN.getColumnName(), VARCHAR_COLUMN))) - .hasValue("(\"c_varchar\") LIKE ('%pattern\\%') ESCAPE ('\\')"); + converted = JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new LikePredicate( + new SymbolReference("c_varchar"), + new StringLiteral("%pattern\\%"), + new StringLiteral("\\")), + Map.of("c_varchar", VARCHAR_COLUMN.getColumnType())), + Map.of(VARCHAR_COLUMN.getColumnName(), VARCHAR_COLUMN)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("(\"c_varchar\") LIKE (?) ESCAPE (?)"); + assertThat(converted.parameters()).isEqualTo(List.of( + new QueryParameter(createVarcharType(10), Optional.of(utf8Slice("%pattern\\%"))), + new QueryParameter(createVarcharType(1), Optional.of(utf8Slice("\\"))))); } @Test public void testConvertIsNull() { // c_varchar IS NULL - assertThat(JDBC_CLIENT.convertPredicate(SESSION, - translateToConnectorExpression( - new IsNullPredicate( - new SymbolReference("c_varchar_symbol")), - Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), - Map.of("c_varchar_symbol", VARCHAR_COLUMN))) - .hasValue("(\"c_varchar\") IS NULL"); + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new IsNullPredicate( + new SymbolReference("c_varchar_symbol")), + Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), + Map.of("c_varchar_symbol", VARCHAR_COLUMN)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("(\"c_varchar\") IS NULL"); + assertThat(converted.parameters()).isEqualTo(List.of()); } @Test public void testConvertIsNotNull() { // c_varchar IS NOT NULL - assertThat(JDBC_CLIENT.convertPredicate(SESSION, - translateToConnectorExpression( - new IsNotNullPredicate( - new SymbolReference("c_varchar_symbol")), - Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), - Map.of("c_varchar_symbol", VARCHAR_COLUMN))) - .hasValue("(\"c_varchar\") IS NOT NULL"); + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new IsNotNullPredicate( + new SymbolReference("c_varchar_symbol")), + Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), + Map.of("c_varchar_symbol", VARCHAR_COLUMN)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("(\"c_varchar\") IS NOT NULL"); + assertThat(converted.parameters()).isEqualTo(List.of()); } @Test public void testConvertNullIf() { // nullif(a_varchar, b_varchar) - assertThat(JDBC_CLIENT.convertPredicate(SESSION, - translateToConnectorExpression( - new NullIfExpression( - new SymbolReference("a_varchar_symbol"), - new SymbolReference("b_varchar_symbol")), - ImmutableMap.of("a_varchar_symbol", VARCHAR_COLUMN.getColumnType(), "b_varchar_symbol", VARCHAR_COLUMN.getColumnType())), - ImmutableMap.of("a_varchar_symbol", VARCHAR_COLUMN, "b_varchar_symbol", VARCHAR_COLUMN))) - .hasValue("NULLIF((\"c_varchar\"), (\"c_varchar\"))"); + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new NullIfExpression( + new SymbolReference("a_varchar_symbol"), + new SymbolReference("b_varchar_symbol")), + ImmutableMap.of("a_varchar_symbol", VARCHAR_COLUMN.getColumnType(), "b_varchar_symbol", VARCHAR_COLUMN.getColumnType())), + ImmutableMap.of("a_varchar_symbol", VARCHAR_COLUMN, "b_varchar_symbol", VARCHAR_COLUMN)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("NULLIF((\"c_varchar\"), (\"c_varchar\"))"); + assertThat(converted.parameters()).isEqualTo(List.of()); } @Test public void testConvertNotExpression() { // NOT(expression) - assertThat(JDBC_CLIENT.convertPredicate(SESSION, - translateToConnectorExpression( - new NotExpression( - new IsNotNullPredicate( - new SymbolReference("c_varchar_symbol"))), - Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), - Map.of("c_varchar_symbol", VARCHAR_COLUMN))) - .hasValue("NOT ((\"c_varchar\") IS NOT NULL)"); + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new NotExpression( + new IsNotNullPredicate( + new SymbolReference("c_varchar_symbol"))), + Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), + Map.of("c_varchar_symbol", VARCHAR_COLUMN)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("NOT ((\"c_varchar\") IS NOT NULL)"); + assertThat(converted.parameters()).isEqualTo(List.of()); } @Test public void testConvertIn() { - assertThat(JDBC_CLIENT.convertPredicate( - SESSION, - translateToConnectorExpression( + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( + SESSION, + translateToConnectorExpression( new InPredicate( new SymbolReference("c_varchar"), new InListExpression(List.of(new StringLiteral("value1"), new StringLiteral("value2"), new SymbolReference("c_varchar2")))), Map.of("c_varchar", VARCHAR_COLUMN.getColumnType(), "c_varchar2", VARCHAR_COLUMN2.getColumnType())), - Map.of(VARCHAR_COLUMN.getColumnName(), VARCHAR_COLUMN, VARCHAR_COLUMN2.getColumnName(), VARCHAR_COLUMN2))) - .hasValue("(\"c_varchar\") IN ('value1', 'value2', \"c_varchar2\")"); + Map.of(VARCHAR_COLUMN.getColumnName(), VARCHAR_COLUMN, VARCHAR_COLUMN2.getColumnName(), VARCHAR_COLUMN2)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("(\"c_varchar\") IN (?, ?, \"c_varchar2\")"); + assertThat(converted.parameters()).isEqualTo(List.of( + new QueryParameter(createVarcharType(6), Optional.of(utf8Slice("value1"))), + new QueryParameter(createVarcharType(6), Optional.of(utf8Slice("value2"))))); } private ConnectorExpression translateToConnectorExpression(Expression expression, Map symbolTypes) diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java index 103258db12b7..4b8fa78c330d 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.DecimalType; @@ -36,7 +37,7 @@ import static java.lang.String.format; public class ImplementRedshiftAvgDecimal - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture INPUT = newCapture(); @@ -52,24 +53,28 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable input = captures.get(INPUT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); DecimalType type = (DecimalType) columnHandle.getColumnType(); verify(aggregateFunction.getOutputType().equals(type)); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(input).orElseThrow(); + // When decimal type has maximum precision we can get result that is not matching Presto avg semantics. if (type.getPrecision() == REDSHIFT_MAX_DECIMAL_PRECISION) { return Optional.of(new JdbcExpression( - format("avg(CAST(%s AS decimal(%s, %s)))", context.rewriteExpression(input).orElseThrow(), type.getPrecision(), type.getScale()), + format("avg(CAST(%s AS decimal(%s, %s)))", rewrittenArgument.expression(), type.getPrecision(), type.getScale()), + rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } // Redshift avg function rounds down resulting decimal. // To match Presto avg semantics, we extend scale by 1 and round result to target scale. return Optional.of(new JdbcExpression( - format("round(avg(CAST(%s AS decimal(%s, %s))), %s)", context.rewriteExpression(input).orElseThrow(), type.getPrecision() + 1, type.getScale() + 1, type.getScale()), + format("round(avg(CAST(%s AS decimal(%s, %s))), %s)", rewrittenArgument.expression(), type.getPrecision() + 1, type.getScale() + 1, type.getScale()), + rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index 61975c21a47d..d391e724a111 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -52,6 +52,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; @@ -226,7 +227,7 @@ public class RedshiftClient .toFormatter(); private static final OffsetDateTime REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ = OffsetDateTime.of(-4712, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC); - private final AggregateFunctionRewriter aggregateFunctionRewriter; + private final AggregateFunctionRewriter aggregateFunctionRewriter; private final boolean statisticsEnabled; private final RedshiftTableStatisticsReader statisticsReader; private final boolean legacyTypeMapping; @@ -243,7 +244,7 @@ public RedshiftClient( { super(config, "\"", connectionFactory, queryBuilder, identifierMapping, queryModifier); this.legacyTypeMapping = redshiftConfig.isLegacyTypeMapping(); - ConnectorExpressionRewriter connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() + ConnectorExpressionRewriter connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) .build(); @@ -251,7 +252,7 @@ public RedshiftClient( aggregateFunctionRewriter = new AggregateFunctionRewriter<>( connectorExpressionRewriter, - ImmutableSet.>builder() + ImmutableSet.>builder() .add(new ImplementCountAll(bigintTypeHandle)) .add(new ImplementCount(bigintTypeHandle)) .add(new ImplementCountDistinct(bigintTypeHandle, true)) diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerCountBig.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerCountBig.java index 9c11f39073b7..7844d2414128 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerCountBig.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerCountBig.java @@ -18,6 +18,7 @@ import io.trino.matching.Pattern; import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; @@ -37,7 +38,7 @@ * Implements specialized version of {@code count(x)} that returns bigint in SQL Server. */ public class ImplementSqlServerCountBig - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -50,13 +51,15 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); verify(aggregateFunction.getOutputType() == BIGINT); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("count_big(%s)", context.rewriteExpression(argument).orElseThrow()), + format("count_big(%s)", rewrittenArgument.expression()), + rewrittenArgument.parameters(), BIGINT_TYPE)); } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerCountBigAll.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerCountBigAll.java index 90318e27779f..6d29acf9ec10 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerCountBigAll.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerCountBigAll.java @@ -13,10 +13,12 @@ */ package io.trino.plugin.sqlserver; +import com.google.common.collect.ImmutableList; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import java.util.List; @@ -33,7 +35,7 @@ * Implements specialized version of {@code count(*)} that returns bigint in SQL Server. */ public class ImplementSqlServerCountBigAll - implements AggregateFunctionRule + implements AggregateFunctionRule { @Override public Pattern getPattern() @@ -44,9 +46,12 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { verify(aggregateFunction.getOutputType() == BIGINT); - return Optional.of(new JdbcExpression("count_big(*)", BIGINT_TYPE)); + return Optional.of(new JdbcExpression( + "count_big(*)", + ImmutableList.of(), + BIGINT_TYPE)); } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStddevPop.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStddevPop.java index 16bc8c5a4bfe..1463fc57d4a2 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStddevPop.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStddevPop.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.DoubleType; @@ -36,7 +37,7 @@ import static java.lang.String.format; public class ImplementSqlServerStddevPop - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -52,15 +53,17 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); verify(columnHandle.getColumnType().equals(DOUBLE)); verify(aggregateFunction.getOutputType().equals(DOUBLE)); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("STDEVP(%s)", context.rewriteExpression(argument).orElseThrow()), + format("STDEVP(%s)", rewrittenArgument.expression()), + rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStdev.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStdev.java index 4ce740d2f4e9..f153ac4d935e 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStdev.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStdev.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.DoubleType; @@ -36,7 +37,7 @@ import static java.lang.String.format; public class ImplementSqlServerStdev - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -52,15 +53,17 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); verify(columnHandle.getColumnType().equals(DOUBLE)); verify(aggregateFunction.getOutputType().equals(DOUBLE)); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("STDEV(%s)", context.rewriteExpression(argument).orElseThrow()), + format("STDEV(%s)", rewrittenArgument.expression()), + rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariance.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariance.java index f626629e0589..9333b4ca13a2 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariance.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariance.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.DoubleType; @@ -36,7 +37,7 @@ import static java.lang.String.format; public class ImplementSqlServerVariance - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -52,15 +53,17 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); verify(columnHandle.getColumnType().equals(DOUBLE)); verify(aggregateFunction.getOutputType().equals(DOUBLE)); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("VAR(%s)", context.rewriteExpression(argument).orElseThrow()), + format("VAR(%s)", rewrittenArgument.expression()), + rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariancePop.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariancePop.java index 201e92736f4c..18ccf9c1dd24 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariancePop.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariancePop.java @@ -19,6 +19,7 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.DoubleType; @@ -36,7 +37,7 @@ import static java.lang.String.format; public class ImplementSqlServerVariancePop - implements AggregateFunctionRule + implements AggregateFunctionRule { private static final Capture ARGUMENT = newCapture(); @@ -52,15 +53,17 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); verify(columnHandle.getColumnType().equals(DOUBLE)); verify(aggregateFunction.getOutputType().equals(DOUBLE)); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of(new JdbcExpression( - format("VARP(%s)", context.rewriteExpression(argument).orElseThrow()), + format("VARP(%s)", rewrittenArgument.expression()), + rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/RewriteUnicodeVarcharConstant.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/RewriteUnicodeVarcharConstant.java deleted file mode 100644 index b31aa6bfa0c8..000000000000 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/RewriteUnicodeVarcharConstant.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.sqlserver; - -import com.google.common.base.CharMatcher; -import io.airlift.slice.Slice; -import io.trino.matching.Captures; -import io.trino.matching.Pattern; -import io.trino.plugin.base.expression.ConnectorExpressionRule; -import io.trino.spi.expression.Constant; -import io.trino.spi.type.VarcharType; - -import java.util.Optional; - -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.constant; -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; - -public class RewriteUnicodeVarcharConstant - implements ConnectorExpressionRule -{ - private static final Pattern PATTERN = constant().with(type().matching(VarcharType.class::isInstance)); - private static final CharMatcher UNICODE_CHARACTER_MATCHER = CharMatcher.ascii().negate().precomputed(); - - @Override - public Pattern getPattern() - { - return PATTERN; - } - - @Override - public Optional rewrite(Constant constant, Captures captures, RewriteContext context) - { - Slice slice = (Slice) constant.getValue(); - if (slice == null) { - return Optional.empty(); - } - - String sliceUtf8String = slice.toStringUtf8(); - boolean isUnicodeString = UNICODE_CHARACTER_MATCHER.matchesAnyOf(sliceUtf8String); - - if (isUnicodeString) { - return Optional.of("N'" + sliceUtf8String.replace("'", "''") + "'"); - } - - return Optional.of("'" + sliceUtf8String.replace("'", "''") + "'"); - } -} diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index 4d666570add2..c7358505f9b1 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -57,6 +57,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementMinMax; import io.trino.plugin.jdbc.aggregation.ImplementSum; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.expression.RewriteComparison; import io.trino.plugin.jdbc.expression.RewriteIn; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; @@ -215,8 +216,8 @@ public class SqlServerClient private final boolean statisticsEnabled; - private final ConnectorExpressionRewriter connectorExpressionRewriter; - private final AggregateFunctionRewriter aggregateFunctionRewriter; + private final ConnectorExpressionRewriter connectorExpressionRewriter; + private final AggregateFunctionRewriter aggregateFunctionRewriter; private static final int MAX_SUPPORTED_TEMPORAL_PRECISION = 7; @@ -254,9 +255,6 @@ public SqlServerClient( this.statisticsEnabled = statisticsConfig.isEnabled(); this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() - // Only SqlServer requires N prefix for unicode characters (SQL-92 standard), - // so we add this rule to support such cases for pushdowns - .add(new RewriteUnicodeVarcharConstant()) .addStandardRules(this::quoted) .add(new RewriteComparison(ImmutableSet.of(RewriteComparison.ComparisonOperator.EQUAL, RewriteComparison.ComparisonOperator.NOT_EQUAL))) .add(new RewriteIn()) @@ -277,7 +275,7 @@ public SqlServerClient( this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( this.connectorExpressionRewriter, - ImmutableSet.>builder() + ImmutableSet.>builder() .add(new ImplementSqlServerCountBigAll()) .add(new ImplementSqlServerCountBig()) .add(new ImplementMinMax(false)) @@ -404,7 +402,7 @@ public void setColumnType(ConnectorSession session, JdbcTableHandle handle, Jdbc } @Override - public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) + public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) { return connectorExpressionRewriter.rewrite(session, expression, assignments); }